diff --git a/.cirrus.yml b/.cirrus.yml new file mode 100644 index 0000000000..c055e33e0b --- /dev/null +++ b/.cirrus.yml @@ -0,0 +1,34 @@ +env: + CIRRUS_CLONE_DEPTH: 1 + +freebsd_task: + matrix: + - name: FreeBSD 13.0 (GCC 10 from packages) + # RunCatchingExceptionsOtherException fails on 13.0 with system Clang + # 11.0 and ports Clang 10/11 as well, so GCC 10 is used instead. + freebsd_instance: + image_family: freebsd-13-0 + preinstall_script: + # Stock clang11 fails some exception unit tests + pkg install -y gcc10 + env: + CC: gcc10 + CXX: g++10 + - name: FreeBSD 12.2 (System Clang 10) + freebsd_instance: + image_family: freebsd-12-2 + - name: FreeBSD 11.4 (System Clang 10) + freebsd_instance: + image_family: freebsd-11-4 + install_script: + pkg install -y automake autoconf libtool + compiler_version_script: + ${CXX:-"c++"} --version + autoreconf_script: + - cd c++ && autoreconf -i + configure_script: + - cd c++ && ./configure + build_script: + - make -C c++ + test_script: + - make -C c++ check diff --git a/.github/workflows/quick-test.yml b/.github/workflows/quick-test.yml new file mode 100644 index 0000000000..e8ad4b1252 --- /dev/null +++ b/.github/workflows/quick-test.yml @@ -0,0 +1,207 @@ +name: Quick Tests + +on: + pull_request: + paths-ignore: + - 'doc/**' + push: + branches: + - master + - 'release-*' + +jobs: + Linux-musl: + runs-on: ubuntu-20.04 + # We depend on both clang and libc++. Alpine Linux 3.17 seems to be the first version to include + # a libc++ package (based on LLVM 15), but building capnproto failed when I tried it. + # Alpine Linux 3.18's libc++ package is from LLVM 16, however, and worked out-of-the-box, so + # Clang 16 appears to be the earliest Clang version we can run easily on Alpine Linux. + container: alpine:3.18.2 + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: apk add autoconf automake bash build-base cmake libtool libucontext-dev linux-headers openssl-dev clang16 libc++-dev + - name: super-test + run: ./super-test.sh quick clang-16 + Linux-old: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + compiler: [g++-7, clang-6.0] + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git zlib1g-dev cmake libssl-dev ${{ matrix.compiler }} + - name: super-test + run: | + ./super-test.sh quick ${{ matrix.compiler }} + Linux: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + compiler: [g++-12, clang-14] + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git zlib1g-dev cmake libssl-dev ${{ matrix.compiler }} + - name: super-test + run: | + ./super-test.sh quick ${{ matrix.compiler }} + Linux-lock-tracking: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + compiler: [clang-9] + features: ["-DKJ_TRACK_LOCK_BLOCKING=1 -DKJ_SAVE_ACQUIRED_LOCK_INFO=1 -DKJ_CONTENTION_WARNING_THRESHOLD=200"] + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git zlib1g-dev cmake libssl-dev ${{ matrix.compiler }} + - name: super-test + run: | + # librt is used for timer_create in the unit tests for lock tracking (mutex-test.c++). + ./super-test.sh quick ${{ matrix.compiler }} cpp-features "${{matrix.features}}" extra-libs "-lrt" + ManyLinux: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + include: + - cross-compiler: manylinux2014-x64 + - cross-compiler: manylinux2014-x86 + docker-run-args: --platform linux/386 + steps: + - uses: actions/checkout@v2 + - name: install dockcross + run: | + docker run ${{ matrix.docker-run-args }} --rm dockcross/${{ matrix.cross-compiler }} > ./dockcross + chmod +x ./dockcross + - name: super-test + run: | + ./dockcross ./super-test.sh quick g++ + MacOS: + runs-on: macos-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + brew install autoconf automake libtool pkg-config + - name: super-test + run: | + ./super-test.sh quick + MSVC: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ['windows-2019', 'windows-latest'] + include: + - os: windows-2019 + target: 'Visual Studio 16 2019' + arch: -A x64 + - os: windows-latest + target: 'Visual Studio 17 2022' + arch: -A x64 + steps: + - uses: actions/checkout@v2 + - name: Include $CONDA in $PATH + run: | + echo "$Env:CONDA\condabin" >> $env:GITHUB_PATH + - name: Install dependencies via Conda + run: | + conda update -n base -c defaults -q conda + conda install -n base -c defaults -q ninja openssl zlib + - name: Build and test + shell: cmd + run: | + echo "Activate conda base environment" + call activate base + echo "Building Cap'n Proto with ${{ matrix.target }}" + cmake -Hc++ -Bbuild-output ${{ matrix.arch }} -G "${{ matrix.target }}" -DCMAKE_BUILD_TYPE=debug -DCMAKE_PREFIX_PATH="%CONDA_PREFIX%" -DCMAKE_INSTALL_PREFIX=%CD%\capnproto-c++-install + cmake --build build-output --config debug --target install + + echo "Building Cap'n Proto samples with ${{ matrix.target }}" + cmake -Hc++/samples -Bbuild-output-samples ${{ matrix.arch }} -G "${{ matrix.target }}" -DCMAKE_BUILD_TYPE=debug -DCMAKE_PREFIX_PATH=%CD%\capnproto-c++-install + cmake --build build-output-samples --config debug + + cd build-output\src + ctest -V -C debug + MinGW: + runs-on: windows-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v2 + - name: Build and test + shell: cmd + run: | + echo "Deleting broken Postgres install until https://github.com/actions/virtual-environments/issues/1089 is fixed..." + rmdir /s /q C:\PROGRA~1\POSTGR~1 + + echo "Building Cap'n Proto with MinGW" + cmake -Hc++ -Bbuild-output -G "MinGW Makefiles" -DCMAKE_BUILD_TYPE=debug -DCMAKE_INSTALL_PREFIX=%CD%\capnproto-c++-install -DCMAKE_SH="CMAKE_SH-NOTFOUND" -DCMAKE_CXX_STANDARD_LIBRARIES="-static-libgcc -static-libstdc++" + cmake --build build-output --target install -- -j2 + + echo "Building Cap'n Proto samples with MinGW" + cmake -Hc++/samples -Bbuild-output-samples -G "MinGW Makefiles" -DCMAKE_BUILD_TYPE=debug -DCMAKE_PREFIX_PATH=%CD%\capnproto-c++-install -DCMAKE_SH="CMAKE_SH-NOTFOUND" + cmake --build build-output-samples + + cd build-output\src + ctest -V -C debug + # Cygwin: + # runs-on: windows-latest + # strategy: + # fail-fast: false + # steps: + # - run: git config --global core.autocrlf false + # - uses: actions/checkout@v2 + # # TODO(someday): If we could cache the Cygwin installation we wouldn't have to spend three + # # minutes installing it for every build. Unfortuntaley, actions/cache@v1 does not preserve + # # DOS file attributes, which corrupts the Cygwin install. In particular, Cygwin marks + # # symlinks with the "DOS SYSTEM" attribute. We could cache just the downloaded packages, + # # but it turns out that only saves a couple seconds; most of the time is spend unpacking. + # - name: Install Cygwin + # run: | + # choco config get cacheLocation + # choco install --no-progress cygwin + # - name: Install Cygwin additional packages + # shell: cmd + # run: | + # C:\tools\cygwin\cygwinsetup.exe -qgnNdO -R C:/tools/cygwin -l C:/tools/cygwin/packages -s http://mirrors.kernel.org/sourceware/cygwin/ -P autoconf,automake,libtool,gcc,gcc-g++,binutils,libssl-devel,make,zlib-devel,pkg-config,cmake,xxd + # - name: Build and test + # shell: cmd + # run: | + # C:\tools\cygwin\bin\bash -lc 'export PATH=/usr/local/bin:/usr/bin:/bin; cd /cygdrive/d/a/capnproto/capnproto; ./super-test.sh quick' + Linux-bazel-clang: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + clang_version: [16] + steps: + - uses: actions/checkout@v3 + - uses: bazelbuild/setup-bazelisk@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git + # todo: replace with apt-get when clang-16 is part of ubuntu lts + - name: install clang + uses: egor-tensin/setup-clang@v1 + with: + version: ${{ matrix.clang_version }} + - name: super-test + run: | + cd c++ + bazel test --verbose_failures --test_output=errors //... diff --git a/.github/workflows/release-test.yml b/.github/workflows/release-test.yml new file mode 100644 index 0000000000..bd30a93422 --- /dev/null +++ b/.github/workflows/release-test.yml @@ -0,0 +1,130 @@ +name: Release Tests + +on: + push: + branches: + - master + - 'release-*' + - 'fix-release*' + +jobs: + Linux: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + # We can only run extended tests with the default version of g++, because it has to match + # the version of g++-multilib for 32-bit cross-compilation, and alternate versions of + # g++-multilib generally aren't available. Clang is more lenient, but we might as well be + # consistent. The quick tests should be able to catch issues with older and newer compiler + # versions. + compiler: [g++, clang] + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get update + sudo apt-get install -y build-essential git zlib1g-dev cmake libssl-dev valgrind gcc-multilib g++-multilib ${{ matrix.compiler }} + - name: super-test + run: | + ./super-test.sh ${{ matrix.compiler }} + MacOS: + runs-on: macos-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + brew install autoconf automake libtool pkg-config + - name: super-test + run: | + ./super-test.sh + MinGW-Wine: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v2 + # See: https://github.com/actions/virtual-environments/issues/4589#issuecomment-1100899313 + # GitHub's Ubuntu image installs all kinds of stuff from non-Ubuntu repositories which cause + # conflicts with Ubuntu packages ultimately preventing installation of wine32. Let's try to + # fix that... + - name: remove unwanted packages and repositories + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libgd3/focal libpcre2-8-0/focal libpcre2-16-0/focal libpcre2-32-0/focal libpcre2-posix2/focal + sudo apt-get purge -yqq libmono* moby* mono* php* libgdiplus libpcre2-posix3 libzip4 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo dpkg --add-architecture i386 + sudo apt-get update + sudo apt-get install -y build-essential git cmake mingw-w64 wine-stable wine64 wine32 wine-binfmt + sudo update-binfmts --import wine + - name: 64-bit Build and Test + run: | + ./super-test.sh mingw x86_64-w64-mingw32 + - name: 32-bit Build and Test + run: | + ./super-test.sh mingw i686-w64-mingw32 + cmake-packaging: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git cmake + - name: autotools-shared + run: | + ./super-test.sh cmake-package autotools-shared + - name: autotools-static + run: | + ./super-test.sh cmake-package autotools-static + - name: cmake-shared + run: | + ./super-test.sh cmake-package cmake-shared + - name: cmake-static + run: | + ./super-test.sh cmake-package cmake-static + Android: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v2 + - name: install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get install -y build-essential git + - name: fetch Android tools + if: steps.cache-android-sdk.outputs.cache-hit != 'true' + run: | + # The installed Android SDK is broken. + unset ANDROID_SDK_ROOT + unset ANDROID_HOME + + mkdir android-sdk + cd android-sdk + curl -o commandlinetools.zip https://dl.google.com/android/repository/commandlinetools-linux-6200805_latest.zip + unzip commandlinetools.zip + (yes || true) | tools/bin/sdkmanager --sdk_root=$PWD platform-tools 'platforms;android-25' 'system-images;android-25;google_apis;armeabi-v7a' emulator 'build-tools;25.0.2' ndk-bundle + - name: 32-bit Build and Test + run: | + # The installed Android SDK is broken. + unset ANDROID_SDK_ROOT + unset ANDROID_HOME + + echo | android-sdk/tools/bin/avdmanager create avd -n capnp -k 'system-images;android-25;google_apis;armeabi-v7a' -b google_apis/armeabi-v7a + + # avdmanager seems to set image.sysdir.1 incorrectly in the AVD's config.ini, which + # causes the emulator to fail. I don't know why. I don't know how to fix it, other than + # to patch the config like so. + sed -i -re 's,^image\.sysdir\.1=android-sdk/,image.sysdir.1=,g' $HOME/.android/avd/capnp.avd/config.ini + + ./super-test.sh android $PWD/android-sdk arm-linux-androideabi armv7a-linux-androideabi24 diff --git a/.gitignore b/.gitignore index f828d4c419..0a8f9e6c08 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ # Ekam build artifacts. /c++/tmp/ /c++/bin/ +/c++/deps/ # setup-ekam.sh /c++/.ekam @@ -59,7 +60,9 @@ /c++/build-aux/ /c++/capnp /c++/capnp-evolution-test -/c++/*.pc +/c++/cmake/CapnProtoConfig.cmake +/c++/cmake/CapnProtoConfigVersion.cmake +/c++/pkgconfig/*.pc /c++/capnp-test /c++/capnpc-c++ /c++/capnpc-capnp @@ -72,6 +75,13 @@ /c++/m4/ltversion.m4 /c++/m4/lt~obsolete.m4 /c++/samples/addressbook +/c++/.cache/ # editor artefacts *~ + +# cross-compiling / glibc testing +/dockcross + +# bazel output +bazel-* diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 61aeeeec3d..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,24 +0,0 @@ -branches: - only: - - master - - /release-.*/ -language: cpp -os: - - linux - - osx -compiler: - - gcc - - clang -dist: trusty -sudo: false -addons: - apt: - packages: - - automake - - autoconf - - libtool - - pkg-config -before_install: - - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew update; fi - - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install automake autoconf libtool; fi -script: ./super-test.sh -j2 quick # limit parallelism due to limited memory on Travis diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000..4f4ce16a03 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,5 @@ +cmake_minimum_required(VERSION 3.4) +project("Cap'n Proto Root" CXX) +include(CTest) + +add_subdirectory(c++) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 31fceb940e..24902a9404 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -2,7 +2,7 @@ The following people have made large code contributions to this repository. Those contributions are copyright the respective authors and licensed by them under the same MIT license terms as the rest of the library. -Kenton Varda : Primary Author +Kenton Varda : Primary Author Jason Choy : kj/threadlocal.h and other iOS tweaks, `name` annotation in C++ code generator Remy Blank (contributions copyright Google Inc.): KJ Timers Joshua Warner : cmake build, AnyStruct/AnyList, other stuff @@ -17,6 +17,7 @@ Harris Hancock : MSVC support Branislav Katreniak : JSON decode Matthew Maurer : Canonicalization Support David Renshaw : bugfixes and miscellaneous maintenance +Ingvar Stepanyan : Custom handlers for JSON decode This file does not list people who maintain their own Cap'n Proto implementations as separate projects. Those people are awesome too! :) diff --git a/LICENSE b/LICENSE index d5b6230ad8..1eabc941f6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,7 @@ -Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +Copyright (c) 2013-2017 Sandstorm Development Group, Inc.; Cloudflare, Inc.; +and other contributors. Each commit is copyright by its respective author or +author's employer. + Licensed under the MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/README.md b/README.md index 3c5c27aef3..895233f191 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,10 @@ - +[![Quick Tests](https://github.com/capnproto/capnproto/workflows/Quick%20Tests/badge.svg?branch=master&event=push)](https://github.com/capnproto/capnproto/actions?query=workflow%3A%22Quick+Tests%22) +[![Release Tests](https://github.com/capnproto/capnproto/workflows/Release%20Tests/badge.svg?branch=master&event=push)](https://github.com/capnproto/capnproto/actions?query=workflow%3A%22Release+Tests%22) + + Cap'n Proto is an insanely fast data interchange format and capability-based RPC system. Think -JSON, except binary. Or think [Protocol Buffers](http://protobuf.googlecode.com), except faster. +JSON, except binary. Or think [Protocol Buffers](https://github.com/google/protobuf), except faster. In fact, in benchmarks, Cap'n Proto is INFINITY TIMES faster than Protocol Buffers. -[Read more...](http://kentonv.github.com/capnproto/) +[Read more...](http://kentonv.github.io/capnproto/) diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 33a1fd7ac8..0000000000 --- a/appveyor.yml +++ /dev/null @@ -1,87 +0,0 @@ -# Cap'n Proto AppVeyor configuration -# -# See https://www.appveyor.com/docs/appveyor-yml/ for configuration options. -# -# This script configures AppVeyor to: -# - Download and unzip MinGW-w64 4.8.5 for x86_64, the minimum gcc version Cap'n Proto advertises -# support for. -# - Use CMake to ... -# build Cap'n Proto with MinGW. -# build Cap'n Proto with VS2015 and VS2017. -# build Cap'n Proto samples with VS2015 and VS2017. - -version: "{build}" - -branches: - only: - - master - - /release-.*/ -# Don't build non-master branches (unless they open a pull request). - -image: Visual Studio 2017 -# AppVeyor build worker image (VM template). - -shallow_clone: true -# Fetch repository as zip archive. - -cache: - - x86_64-4.8.5-release-win32-seh-rt_v4-rev0.7z - -environment: - MINGW_DIR: mingw64 - MINGW_URL: https://sourceforge.net/projects/mingw-w64/files/Toolchains%20targetting%20Win64/Personal%20Builds/mingw-builds/4.8.5/threads-win32/seh/x86_64-4.8.5-release-win32-seh-rt_v4-rev0.7z/download - MINGW_ARCHIVE: x86_64-4.8.5-release-win32-seh-rt_v4-rev0.7z - BUILD_TYPE: debug - - matrix: - # TODO(someday): Add MSVC x64 builds, MinGW x86 build? - - - CMAKE_GENERATOR: Visual Studio 15 2017 - BUILD_NAME: vs2017 - EXTRA_BUILD_FLAGS: # /maxcpucount - # TODO(someday): Right now /maxcpucount occasionally expresses a filesystem-related race: - # capnp-capnpc++ complains that it can't create test.capnp.h. - - - CMAKE_GENERATOR: Visual Studio 14 2015 - BUILD_NAME: vs2015 - EXTRA_BUILD_FLAGS: # /maxcpucount - - - CMAKE_GENERATOR: MinGW Makefiles - BUILD_NAME: mingw - EXTRA_BUILD_FLAGS: -j2 - -install: - - if not exist "%MINGW_ARCHIVE%" appveyor DownloadFile "%MINGW_URL%" -FileName "%MINGW_ARCHIVE%" - - 7z x -y "%MINGW_ARCHIVE%" > nul - - ps: Get-Command sh.exe -All | Remove-Item - # CMake refuses to generate MinGW Makefiles if sh.exe is in the PATH - -before_build: - - set PATH=%CD%\%MINGW_DIR%\bin;%PATH% - - set BUILD_DIR=build-%BUILD_NAME% - - set INSTALL_PREFIX=%CD%\capnproto-c++-%BUILD_NAME% - - cmake --version - -build_script: - - echo "Building Cap'n Proto with %CMAKE_GENERATOR%" - - >- - cmake -Hc++ -B%BUILD_DIR% -G "%CMAKE_GENERATOR%" - -DCMAKE_BUILD_TYPE=%BUILD_TYPE% - -DCMAKE_INSTALL_PREFIX=%INSTALL_PREFIX% - - cmake --build %BUILD_DIR% --config %BUILD_TYPE% --target install -- %EXTRA_BUILD_FLAGS% - # MinGW wants the build type at configure-time while MSVC wants the build type at build-time. We - # can satisfy both by passing the build type to both cmake invocations. We have to suffer a - # warning, but both generators will work. - - - echo "Building Cap'n Proto samples with %CMAKE_GENERATOR%" - - >- - cmake -Hc++/samples -B%BUILD_DIR%-samples -G "%CMAKE_GENERATOR%" - -DCMAKE_BUILD_TYPE=%BUILD_TYPE% - -DCMAKE_PREFIX_PATH=%INSTALL_PREFIX% - - cmake --build %BUILD_DIR%-samples --config %BUILD_TYPE% - -test_script: - - timeout /t 2 - # Sleep a little to prevent interleaving test output with build output. - - cd %BUILD_DIR%\src - - ctest -V -C %BUILD_TYPE% diff --git a/c++/.bazelignore b/c++/.bazelignore new file mode 100644 index 0000000000..b67bdf49d4 --- /dev/null +++ b/c++/.bazelignore @@ -0,0 +1 @@ +ekam-provider \ No newline at end of file diff --git a/c++/.bazelrc b/c++/.bazelrc new file mode 100644 index 0000000000..cab5c8c13c --- /dev/null +++ b/c++/.bazelrc @@ -0,0 +1,26 @@ +common --enable_platform_specific_config + +build:unix --cxxopt='-std=c++14' --host_cxxopt='-std=c++14' --force_pic --verbose_failures +build:unix --cxxopt='-Wall' --host_cxxopt='-Wall' +build:unix --cxxopt='-Wextra' --host_cxxopt='-Wextra' +build:unix --cxxopt='-Wno-strict-aliasing' --host_cxxopt='-Wno-strict-aliasing' +build:unix --cxxopt='-Wno-sign-compare' --host_cxxopt='-Wno-sign-compare' +build:unix --cxxopt='-Wno-unused-parameter' --host_cxxopt='-Wno-unused-parameter' + +build:linux --config=unix +build:macos --config=unix + +# See https://bazel.build/configure/windows#symlink +startup --windows_enable_symlinks +# We use LLVM's MSVC-compatible compiler driver to compile our code on Windows +# under Bazel. MSVC is natively supported when using CMake builds. +build:windows --compiler=clang-cl + +build:windows --cxxopt='/std:c++14' --host_cxxopt='/std:c++14' --verbose_failures +build:windows --cxxopt='/wo4503' --host_cxxopt='/wo4503' +# The `/std:c++14` argument is unused during boringssl compilation and we don't +# want a warning when compiling each file. +build:windows --cxxopt='-Wno-unused-command-line-argument' --host_cxxopt='-Wno-unused-command-line-argument' + +# build with ssl, zlib and bazel by default +build --//src/kj:openssl=True --//src/kj:zlib=True --//src/kj:brotli=True diff --git a/c++/.bazelversion b/c++/.bazelversion new file mode 100644 index 0000000000..5e3254243a --- /dev/null +++ b/c++/.bazelversion @@ -0,0 +1 @@ +6.1.2 diff --git a/c++/BUILD.bazel b/c++/BUILD.bazel new file mode 100644 index 0000000000..e69de29bb2 diff --git a/c++/CMakeLists.txt b/c++/CMakeLists.txt index a0ebd85898..02e7da7a11 100644 --- a/c++/CMakeLists.txt +++ b/c++/CMakeLists.txt @@ -1,18 +1,19 @@ +cmake_minimum_required(VERSION 3.6) project("Cap'n Proto" CXX) -cmake_minimum_required(VERSION 3.1) -set(VERSION 0.6.1) +set(VERSION 1.1-dev) -set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +include(CTest) include(CheckIncludeFileCXX) include(GNUInstallDirs) if(MSVC) - check_include_file_cxx(initializer_list HAS_CXX11) + check_include_file_cxx(initializer_list HAS_CXX14) else() - check_include_file_cxx(initializer_list HAS_CXX11 "-std=gnu++0x") + check_include_file_cxx(initializer_list HAS_CXX14 "-std=gnu++1y") endif() -if(NOT HAS_CXX11) - message(SEND_ERROR "Requires a C++11 compiler and standard library.") +if(NOT HAS_CXX14) + message(SEND_ERROR "Requires a C++14 compiler and standard library.") endif() # these arguments are passed to all install(TARGETS) calls @@ -25,7 +26,6 @@ set(INSTALL_TARGETS_DEFAULT_ARGS # Options ====================================================================== -option(BUILD_TESTING "Build unit tests and enable CTest 'check' target." ON) option(EXTERNAL_CAPNP "Use the system capnp binary, or the one specified in $CAPNP, instead of using the compiled one." OFF) option(CAPNP_LITE "Compile Cap'n Proto in 'lite mode', in which all reflection APIs (schema.h, dynamic.h, etc.) are not included. Produces a smaller library at the cost of features. All programs built against the library must be compiled with -DCAPNP_LITE. Requires EXTERNAL_CAPNP." OFF) @@ -41,6 +41,95 @@ else() set(CAPNP_LITE_FLAG) endif() +set(WITH_OPENSSL "AUTO" CACHE STRING + "Whether or not to build libkj-tls by linking against openssl") +# define list of values GUI will offer for the variable +set_property(CACHE WITH_OPENSSL PROPERTY STRINGS AUTO ON OFF) + +set(WITH_ZLIB "AUTO" CACHE STRING + "Whether or not to build libkj-gzip by linking against zlib") +set_property(CACHE WITH_ZLIB PROPERTY STRINGS AUTO ON OFF) + +# shadow cache variable original value with ON/OFF, +# so from now on OpenSSL-specific code just has to check: +# if (WITH_OPENSSL) +# ... +# endif() +if (CAPNP_LITE) + set(WITH_OPENSSL OFF) +elseif (WITH_OPENSSL STREQUAL "AUTO") + find_package(OpenSSL COMPONENTS Crypto SSL) + if (OPENSSL_FOUND) + set(WITH_OPENSSL ON) + else() + set(WITH_OPENSSL OFF) + endif() +elseif (WITH_OPENSSL) + find_package(OpenSSL REQUIRED COMPONENTS Crypto SSL) +endif() + +# shadow cache variable original value with ON/OFF, +# so from now on ZLIB-specific code just has to check: +# if (WITH_ZLIB) +# ... +# endif() +if(CAPNP_LITE) + set(WITH_ZLIB OFF) +elseif (WITH_ZLIB STREQUAL "AUTO") + find_package(ZLIB) + if(ZLIB_FOUND) + set(WITH_ZLIB ON) + else() + set(WITH_ZLIB OFF) + endif() +elseif (WITH_ZLIB) + find_package(ZLIB REQUIRED) +endif() + +set(WITH_FIBERS "AUTO" CACHE STRING + "Whether or not to build libkj-async with fibers") +# define list of values GUI will offer for the variable +set_property(CACHE WITH_FIBERS PROPERTY STRINGS AUTO ON OFF) + +# CapnProtoConfig.cmake.in needs this variable. +set(_WITH_LIBUCONTEXT OFF) + +if (WITH_FIBERS OR WITH_FIBERS STREQUAL "AUTO") + set(_capnp_fibers_found OFF) + if (WIN32 OR CYGWIN) + set(_capnp_fibers_found ON) + else() + # Fibers need makecontext, setcontext, getcontext, swapcontext that may be in libc, + # or in libucontext (e.g. for musl). + # We assume that makecontext implies that the others are present. + include(CheckLibraryExists) + check_library_exists(c makecontext "" HAVE_UCONTEXT_LIBC) + if (HAVE_UCONTEXT_LIBC) + set(_capnp_fibers_found ON) + else() + # Try with libucontext + find_package(PkgConfig) + if (PKG_CONFIG_FOUND) + pkg_check_modules(libucontext IMPORTED_TARGET libucontext) + if (libucontext_FOUND) + set(_WITH_LIBUCONTEXT ON) + set(_capnp_fibers_found ON) + endif() + else() + set(_capnp_fibers_found OFF) + endif() + endif() + endif() + + if (_capnp_fibers_found) + set(WITH_FIBERS ON) + elseif(WITH_FIBERS STREQUAL "AUTO") + set(WITH_FIBERS OFF) + else() + message(FATAL_ERROR "Missing 'makecontext', 'getcontext', 'setcontext' or 'swapcontext' symbol in libc and no libucontext found: KJ won't be able to build with fibers. Disable fibers (-DWITH_FIBERS=OFF).") + endif() +endif() + if(MSVC) # TODO(cleanup): Enable higher warning level in MSVC, but make sure to test # build with that warning level and clean out false positives. @@ -81,11 +170,17 @@ add_subdirectory(src) # Install ====================================================================== include(CMakePackageConfigHelpers) -write_basic_package_version_file( - "${CMAKE_CURRENT_BINARY_DIR}/cmake/CapnProtoConfigVersion.cmake" - VERSION ${VERSION} - COMPATIBILITY AnyNewerVersion -) + +# We used to use write_basic_package_version_file(), but since the autotools build needs to install +# a config version script as well, I copied the AnyNewerVersion template from my CMake Modules +# directory to Cap'n Proto's cmake/ directory (alternatively, we could make the autotools build +# depend on CMake). +# +# We might as well use the local copy of the template. In the future we can modify the project's +# version compatibility policy just by changing that file. +set(PACKAGE_VERSION ${VERSION}) +configure_file(cmake/CapnProtoConfigVersion.cmake.in cmake/CapnProtoConfigVersion.cmake @ONLY) + set(CONFIG_PACKAGE_LOCATION ${CMAKE_INSTALL_LIBDIR}/cmake/CapnProto) configure_package_config_file(cmake/CapnProtoConfig.cmake.in @@ -120,23 +215,30 @@ if(NOT MSVC) # Don't install pkg-config files when building with MSVC set(PTHREAD_CFLAGS "-pthread") set(STDLIB_FLAG) # TODO: Unsupported - configure_file(kj.pc.in "${CMAKE_CURRENT_BINARY_DIR}/kj.pc" @ONLY) - install(FILES "${CMAKE_CURRENT_BINARY_DIR}/kj.pc" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") - - configure_file(capnp.pc.in "${CMAKE_CURRENT_BINARY_DIR}/capnp.pc" @ONLY) - install(FILES "${CMAKE_CURRENT_BINARY_DIR}/capnp.pc" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + set(CAPNP_PKG_CONFIG_FILES + pkgconfig/kj.pc + pkgconfig/capnp.pc + pkgconfig/capnpc.pc + ) if(NOT CAPNP_LITE) - configure_file(kj-async.pc.in "${CMAKE_CURRENT_BINARY_DIR}/kj-async.pc" @ONLY) - install(FILES "${CMAKE_CURRENT_BINARY_DIR}/kj-async.pc" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") - - configure_file(capnp-rpc.pc.in "${CMAKE_CURRENT_BINARY_DIR}/capnp-rpc.pc" @ONLY) - install(FILES "${CMAKE_CURRENT_BINARY_DIR}/capnp-rpc.pc" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") - - configure_file(capnp-json.pc.in "${CMAKE_CURRENT_BINARY_DIR}/capnp-json.pc" @ONLY) - install(FILES "${CMAKE_CURRENT_BINARY_DIR}/capnp-json.pc" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + list(APPEND CAPNP_PKG_CONFIG_FILES + pkgconfig/kj-async.pc + pkgconfig/kj-gzip.pc + pkgconfig/kj-http.pc + pkgconfig/kj-test.pc + pkgconfig/kj-tls.pc + pkgconfig/capnp-rpc.pc + pkgconfig/capnp-websocket.pc + pkgconfig/capnp-json.pc + ) endif() + foreach(pcfile ${CAPNP_PKG_CONFIG_FILES}) + configure_file(${pcfile}.in "${CMAKE_CURRENT_BINARY_DIR}/${pcfile}" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/${pcfile}" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + endforeach() + unset(STDLIB_FLAG) unset(PTHREAD_CFLAGS) unset(includedir) diff --git a/c++/Makefile.am b/c++/Makefile.am index 4f595ae54c..1567491d4d 100644 --- a/c++/Makefile.am +++ b/c++/Makefile.am @@ -2,7 +2,11 @@ ACLOCAL_AMFLAGS = -I m4 -AUTOMAKE_OPTIONS = foreign subdir-objects +# We use serial-tests so that test output will be written directly to stdout +# which is much preferred in CI environments where the test logs may be hard +# to get at after the fact. Most of our tests are bundled into a single +# executable anyway so cannot easily be parallelized. +AUTOMAKE_OPTIONS = foreign subdir-objects serial-tests # When running distcheck, verify that we've included all the files needed by # the cmake build. @@ -22,19 +26,27 @@ EXTRA_DIST = \ src/capnp/compiler/capnp-test.sh \ src/capnp/testdata/segmented-packed \ src/capnp/testdata/errors.capnp.nobuild \ + src/capnp/testdata/errors2.capnp.nobuild \ + src/capnp/testdata/no-file-id.capnp.nobuild \ src/capnp/testdata/short.txt \ src/capnp/testdata/flat \ src/capnp/testdata/binary \ src/capnp/testdata/errors.txt \ + src/capnp/testdata/errors2.txt \ src/capnp/testdata/segmented \ src/capnp/testdata/packed \ src/capnp/testdata/pretty.txt \ src/capnp/testdata/lists.binary \ src/capnp/testdata/packedflat \ + src/capnp/testdata/pretty.json \ + src/capnp/testdata/short.json \ + src/capnp/testdata/annotated-json.binary \ + src/capnp/testdata/annotated.json \ CMakeLists.txt \ - cmake/FindCapnProto.cmake \ - cmake/CapnProtoConfig.cmake.in \ cmake/CapnProtoMacros.cmake \ + cmake/CapnProtoTargets.cmake \ + cmake/CapnProtoConfig.cmake.in \ + cmake/CapnProtoConfigVersion.cmake.in \ src/CMakeLists.txt \ src/kj/CMakeLists.txt \ src/capnp/CMakeLists.txt @@ -76,21 +88,24 @@ maintainer-clean-local: public_capnpc_inputs = \ src/capnp/c++.capnp \ src/capnp/schema.capnp \ + src/capnp/stream.capnp \ src/capnp/rpc.capnp \ src/capnp/rpc-twoparty.capnp \ - src/capnp/persistent.capnp \ - src/capnp/compat/json.capnp + src/capnp/persistent.capnp capnpc_inputs = \ $(public_capnpc_inputs) \ src/capnp/compiler/lexer.capnp \ - src/capnp/compiler/grammar.capnp + src/capnp/compiler/grammar.capnp \ + src/capnp/compat/json.capnp capnpc_outputs = \ src/capnp/c++.capnp.c++ \ src/capnp/c++.capnp.h \ src/capnp/schema.capnp.c++ \ src/capnp/schema.capnp.h \ + src/capnp/stream.capnp.c++ \ + src/capnp/stream.capnp.h \ src/capnp/rpc.capnp.c++ \ src/capnp/rpc.capnp.h \ src/capnp/rpc-twoparty.capnp.c++ \ @@ -112,12 +127,19 @@ includekjstddir = $(includekjdir)/std includekjcompatdir = $(includekjdir)/compat dist_includecapnp_DATA = $(public_capnpc_inputs) +dist_includecapnpcompat_DATA = src/capnp/compat/json.capnp pkgconfigdir = $(libdir)/pkgconfig -pkgconfig_DATA = capnp.pc capnp-rpc.pc kj.pc kj-async.pc +pkgconfig_DATA = $(CAPNP_PKG_CONFIG_FILES) + +cmakeconfigdir = $(libdir)/cmake/CapnProto +cmakeconfig_DATA = $(CAPNP_CMAKE_CONFIG_FILES) \ + cmake/CapnProtoMacros.cmake \ + cmake/CapnProtoTargets.cmake noinst_HEADERS = \ - src/kj/miniposix.h + src/kj/miniposix.h \ + src/kj/async-io-internal.h includekj_HEADERS = \ src/kj/common.h \ @@ -125,9 +147,14 @@ includekj_HEADERS = \ src/kj/memory.h \ src/kj/refcount.h \ src/kj/array.h \ + src/kj/list.h \ src/kj/vector.h \ src/kj/string.h \ src/kj/string-tree.h \ + src/kj/hash.h \ + src/kj/table.h \ + src/kj/map.h \ + src/kj/encoding.h \ src/kj/exception.h \ src/kj/debug.h \ src/kj/arena.h \ @@ -136,17 +163,23 @@ includekj_HEADERS = \ src/kj/one-of.h \ src/kj/function.h \ src/kj/mutex.h \ + src/kj/source-location.h \ src/kj/thread.h \ src/kj/threadlocal.h \ + src/kj/filesystem.h \ src/kj/async-prelude.h \ src/kj/async.h \ src/kj/async-inl.h \ src/kj/time.h \ + src/kj/timer.h \ src/kj/async-unix.h \ src/kj/async-win32.h \ src/kj/async-io.h \ + src/kj/cidr.h \ + src/kj/async-queue.h \ src/kj/main.h \ src/kj/test.h \ + src/kj/win32-api-version.h \ src/kj/windows-sanity.h includekjparse_HEADERS = \ @@ -158,7 +191,11 @@ includekjstd_HEADERS = \ includekjcompat_HEADERS = \ src/kj/compat/gtest.h \ - src/kj/compat/http.h + src/kj/compat/url.h \ + src/kj/compat/http.h \ + src/kj/compat/gzip.h \ + src/kj/compat/readiness-io.h \ + src/kj/compat/tls.h includecapnp_HEADERS = \ src/capnp/c++.capnp.h \ @@ -173,6 +210,7 @@ includecapnp_HEADERS = \ src/capnp/capability.h \ src/capnp/membrane.h \ src/capnp/schema.capnp.h \ + src/capnp/stream.capnp.h \ src/capnp/schema-lite.h \ src/capnp/schema.h \ src/capnp/schema-loader.h \ @@ -196,34 +234,61 @@ includecapnp_HEADERS = \ includecapnpcompat_HEADERS = \ src/capnp/compat/json.h \ - src/capnp/compat/json.capnp.h + src/capnp/compat/json.capnp.h \ + src/capnp/compat/std-iterator.h \ + src/capnp/compat/websocket-rpc.h + +if BUILD_KJ_TLS +MAYBE_KJ_TLS_LA=libkj-tls.la +MAYBE_KJ_TLS_TESTS= \ + src/kj/compat/readiness-io-test.c++ \ + src/kj/compat/tls-test.c++ +else +MAYBE_KJ_TLS_LA= +MAYBE_KJ_TLS_TESTS= +endif + +if BUILD_KJ_GZIP +MAYBE_KJ_GZIP_LA=libkj-gzip.la +MAYBE_KJ_GZIP_TESTS= \ + src/kj/compat/gzip-test.c++ +else +MAYBE_KJ_TLS_LA= +MAYBE_KJ_TLS_TESTS= +endif if LITE_MODE lib_LTLIBRARIES = libkj.la libkj-test.la libcapnp.la else -lib_LTLIBRARIES = libkj.la libkj-test.la libkj-async.la libkj-http.la libcapnp.la libcapnp-rpc.la libcapnp-json.la libcapnpc.la +lib_LTLIBRARIES = libkj.la libkj-test.la libkj-async.la libkj-http.la $(MAYBE_KJ_TLS_LA) $(MAYBE_KJ_GZIP_LA) libcapnp.la libcapnp-rpc.la libcapnp-json.la libcapnp-websocket.la libcapnpc.la endif -# Don't include security release in soname -- we want to replace old binaries -# in this case. -SO_VERSION = $(shell echo $(VERSION) | sed -e 's/^\([0-9]*[.][0-9]*[.][0-9]*\)\([.][0-9]*\)*\(-.*\)*$$/\1\3/g') - libkj_la_LIBADD = $(PTHREAD_LIBS) libkj_la_LDFLAGS = -release $(SO_VERSION) -no-undefined libkj_la_SOURCES= \ + src/kj/cidr.c++ \ src/kj/common.c++ \ src/kj/units.c++ \ src/kj/memory.c++ \ src/kj/refcount.c++ \ src/kj/array.c++ \ + src/kj/list.c++ \ src/kj/string.c++ \ src/kj/string-tree.c++ \ + src/kj/source-location.c++ \ + src/kj/hash.c++ \ + src/kj/table.c++ \ + src/kj/encoding.c++ \ src/kj/exception.c++ \ src/kj/debug.c++ \ src/kj/arena.c++ \ src/kj/io.c++ \ src/kj/mutex.c++ \ src/kj/thread.c++ \ + src/kj/time.c++ \ + src/kj/filesystem.c++ \ + src/kj/filesystem-disk-unix.c++ \ + src/kj/filesystem-disk-win32.c++ \ src/kj/test-helpers.c++ \ src/kj/main.c++ \ src/kj/parse/char.c++ @@ -242,12 +307,33 @@ libkj_async_la_SOURCES= \ src/kj/async-io.c++ \ src/kj/async-io-unix.c++ \ src/kj/async-io-win32.c++ \ - src/kj/time.c++ + src/kj/timer.c++ +if BUILD_KJ_GZIP +libkj_http_la_LIBADD = libkj-async.la libkj.la -lz $(ASYNC_LIBS) $(PTHREAD_LIBS) +libkj_http_la_LDFLAGS = -release $(SO_VERSION) -no-undefined +libkj_http_la_SOURCES= \ + src/kj/compat/url.c++ \ + src/kj/compat/http.c++ +else libkj_http_la_LIBADD = libkj-async.la libkj.la $(ASYNC_LIBS) $(PTHREAD_LIBS) libkj_http_la_LDFLAGS = -release $(SO_VERSION) -no-undefined libkj_http_la_SOURCES= \ + src/kj/compat/url.c++ \ src/kj/compat/http.c++ +endif + +libkj_tls_la_LIBADD = libkj-async.la libkj.la -lssl -lcrypto $(ASYNC_LIBS) $(PTHREAD_LIBS) +libkj_tls_la_LDFLAGS = -release $(SO_VERSION) -no-undefined +libkj_tls_la_SOURCES= \ + src/kj/compat/readiness-io.c++ \ + src/kj/compat/tls.c++ + +libkj_gzip_la_LIBADD = libkj-async.la libkj.la -lz $(ASYNC_LIBS) $(PTHREAD_LIBS) +libkj_gzip_la_LDFLAGS = -release $(SO_VERSION) -no-undefined +libkj_gzip_la_SOURCES= \ + src/kj/compat/gzip.c++ + endif !LITE_MODE if !LITE_MODE @@ -270,6 +356,7 @@ libcapnp_la_SOURCES= \ src/capnp/any.c++ \ src/capnp/message.c++ \ src/capnp/schema.capnp.c++ \ + src/capnp/stream.capnp.c++ \ src/capnp/serialize.c++ \ src/capnp/serialize-packed.c++ \ $(heavy_sources) @@ -296,11 +383,16 @@ libcapnp_json_la_SOURCES= \ src/capnp/compat/json.c++ \ src/capnp/compat/json.capnp.c++ +libcapnp_websocket_la_LIBADD = libcapnp.la libcapnp-rpc.la libkj.la libkj-async.la libkj-http.la $(PTHREAD_LIBS) +libcapnp_websocket_la_LDFLAGS = -release $(SO_VERSION) -no-undefined +libcapnp_websocket_la_SOURCES= \ + src/capnp/compat/websocket-rpc.c++ + libcapnpc_la_LIBADD = libcapnp.la libkj.la $(PTHREAD_LIBS) libcapnpc_la_LDFLAGS = -release $(SO_VERSION) -no-undefined libcapnpc_la_SOURCES= \ - src/capnp/compiler/md5.h \ - src/capnp/compiler/md5.c++ \ + src/capnp/compiler/type-id.h \ + src/capnp/compiler/type-id.c++ \ src/capnp/compiler/error-reporter.h \ src/capnp/compiler/error-reporter.c++ \ src/capnp/compiler/lexer.capnp.h \ @@ -311,6 +403,9 @@ libcapnpc_la_SOURCES= \ src/capnp/compiler/grammar.capnp.c++ \ src/capnp/compiler/parser.h \ src/capnp/compiler/parser.c++ \ + src/capnp/compiler/resolver.h \ + src/capnp/compiler/generics.h \ + src/capnp/compiler/generics.c++ \ src/capnp/compiler/node-translator.h \ src/capnp/compiler/node-translator.c++ \ src/capnp/compiler/compiler.h \ @@ -320,7 +415,7 @@ libcapnpc_la_SOURCES= \ bin_PROGRAMS = capnp capnpc-capnp capnpc-c++ -capnp_LDADD = libcapnpc.la libcapnp.la libkj.la $(PTHREAD_LIBS) +capnp_LDADD = libcapnpc.la libcapnp-json.la libcapnp.la libkj.la $(PTHREAD_LIBS) capnp_SOURCES = \ src/capnp/compiler/module-loader.h \ src/capnp/compiler/module-loader.c++ \ @@ -338,10 +433,15 @@ capnpc_c___SOURCES = src/capnp/compiler/capnpc-c++.c++ # Also attempt to run ldconfig, because otherwise users get confused. If # it fails (e.g. because the platform doesn't have it, or because the # user doesn't have root privileges), don't worry about it. +# +# We need to specify the path for OpenBSD. install-exec-hook: ln -sf capnp $(DESTDIR)$(bindir)/capnpc - ldconfig < /dev/null > /dev/null 2>&1 || true - + if [ `uname` == 'OpenBSD' ]; then \ + (ldconfig /usr/local/lib /usr/lib /usr/X11R6/lib > /dev/null 2>&1 || true); \ + else \ + ldconfig < /dev/null > /dev/null 2>&1 || true; \ + fi uninstall-hook: rm -f $(DESTDIR)$(bindir)/capnpc @@ -362,7 +462,8 @@ endif LITE_MODE test_capnpc_inputs = \ src/capnp/test.capnp \ src/capnp/test-import.capnp \ - src/capnp/test-import2.capnp + src/capnp/test-import2.capnp \ + src/capnp/compat/json-test.capnp test_capnpc_outputs = \ src/capnp/test.capnp.c++ \ @@ -370,20 +471,22 @@ test_capnpc_outputs = \ src/capnp/test-import.capnp.c++ \ src/capnp/test-import.capnp.h \ src/capnp/test-import2.capnp.c++ \ - src/capnp/test-import2.capnp.h + src/capnp/test-import2.capnp.h \ + src/capnp/compat/json-test.capnp.c++ \ + src/capnp/compat/json-test.capnp.h if USE_EXTERNAL_CAPNP test_capnpc_middleman: $(test_capnpc_inputs) @$(MKDIR_P) src - $(CAPNP) compile --src-prefix=$(srcdir)/src -o$(CAPNPC_CXX):src -I$(srcdir)/src $^ + $(CAPNP) compile --src-prefix=$(srcdir)/src -o$(CAPNPC_CXX):src -I$(srcdir)/src $$(for FILE in $(test_capnpc_inputs); do echo $(srcdir)/$$FILE; done) touch test_capnpc_middleman else test_capnpc_middleman: capnp$(EXEEXT) capnpc-c++$(EXEEXT) $(test_capnpc_inputs) @$(MKDIR_P) src - echo $^ | (read CAPNP CAPNPC_CXX SOURCES && ./$$CAPNP compile --src-prefix=$(srcdir)/src -o./$$CAPNPC_CXX:src -I$(srcdir)/src $$SOURCES) + ./capnp$(EXEEXT) compile --src-prefix=$(srcdir)/src -o./capnpc-c++$(EXEEXT):src -I$(srcdir)/src $$(for FILE in $(test_capnpc_inputs); do echo $(srcdir)/$$FILE; done) touch test_capnpc_middleman endif @@ -407,15 +510,26 @@ capnp_test_LDADD = libcapnp-test.a libcapnp.la libkj-test.la libkj.la else !LITE_MODE check_PROGRAMS = capnp-test capnp-evolution-test capnp-afl-testcase +if HAS_FUZZING_ENGINE + check_PROGRAMS += capnp-llvm-fuzzer-testcase +endif heavy_tests = \ src/kj/async-test.c++ \ + src/kj/async-xthread-test.c++ \ + src/kj/async-coroutine-test.c++ \ src/kj/async-unix-test.c++ \ + src/kj/async-unix-xthread-test.c++ \ src/kj/async-win32-test.c++ \ + src/kj/async-win32-xthread-test.c++ \ src/kj/async-io-test.c++ \ + src/kj/async-queue-test.c++ \ src/kj/parse/common-test.c++ \ src/kj/parse/char-test.c++ \ src/kj/std/iostream-test.c++ \ + src/kj/compat/url-test.c++ \ src/kj/compat/http-test.c++ \ + $(MAYBE_KJ_GZIP_TESTS) \ + $(MAYBE_KJ_TLS_TESTS) \ src/capnp/canonicalize-test.c++ \ src/capnp/capability-test.c++ \ src/capnp/membrane-test.c++ \ @@ -430,18 +544,24 @@ heavy_tests = \ src/capnp/rpc-twoparty-test.c++ \ src/capnp/ez-rpc-test.c++ \ src/capnp/compat/json-test.c++ \ + src/capnp/compat/websocket-rpc-test.c++ \ src/capnp/compiler/lexer-test.c++ \ - src/capnp/compiler/md5-test.c++ + src/capnp/compiler/type-id-test.c++ capnp_test_LDADD = \ libcapnp-test.a \ libcapnpc.la \ libcapnp-rpc.la \ + libcapnp-websocket.la \ libcapnp-json.la \ libcapnp.la \ libkj-http.la \ + $(MAYBE_KJ_GZIP_LA) \ + $(MAYBE_KJ_TLS_LA) \ libkj-async.la \ libkj-test.la \ - libkj.la + libkj.la \ + $(ASYNC_LIBS) \ + $(PTHREAD_LIBS) endif !LITE_MODE @@ -451,8 +571,12 @@ capnp_test_SOURCES = \ src/kj/memory-test.c++ \ src/kj/refcount-test.c++ \ src/kj/array-test.c++ \ + src/kj/list-test.c++ \ src/kj/string-test.c++ \ src/kj/string-tree-test.c++ \ + src/kj/table-test.c++ \ + src/kj/map-test.c++ \ + src/kj/encoding-test.c++ \ src/kj/exception-test.c++ \ src/kj/debug-test.c++ \ src/kj/arena-test.c++ \ @@ -462,8 +586,10 @@ capnp_test_SOURCES = \ src/kj/function-test.c++ \ src/kj/io-test.c++ \ src/kj/mutex-test.c++ \ + src/kj/time-test.c++ \ src/kj/threadlocal-test.c++ \ - src/kj/threadlocal-pthread-test.c++ \ + src/kj/filesystem-test.c++ \ + src/kj/filesystem-disk-test.c++ \ src/kj/test-test.c++ \ src/capnp/common-test.c++ \ src/capnp/blob-test.c++ \ @@ -486,6 +612,12 @@ capnp_evolution_test_SOURCES = src/capnp/compiler/evolution-test.c++ capnp_afl_testcase_LDADD = libcapnp-test.a libcapnp-rpc.la libcapnp.la libkj.la libkj-async.la capnp_afl_testcase_SOURCES = src/capnp/afl-testcase.c++ + +if HAS_FUZZING_ENGINE + capnp_llvm_fuzzer_testcase_LDADD = libcapnp-test.a libcapnp-rpc.la libcapnp.la libkj.la libkj-async.la + capnp_llvm_fuzzer_testcase_SOURCES = src/capnp/llvm-fuzzer-testcase.c++ + capnp_llvm_fuzzer_testcase_LDFLAGS = $(LIB_FUZZING_ENGINE) +endif endif !LITE_MODE if LITE_MODE diff --git a/c++/Makefile.ekam b/c++/Makefile.ekam index ece8d4a96d..04c4220512 100644 --- a/c++/Makefile.ekam +++ b/c++/Makefile.ekam @@ -16,19 +16,19 @@ all: echo "You probably accidentally told Eclipse to build. Stopping." once: - CXXFLAGS="$(EXTRA_FLAG) -std=c++11 -O2 -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 + CXXFLAGS="$(EXTRA_FLAG) -std=c++14 -O2 -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 continuous: - CXXFLAGS="$(EXTRA_FLAG) -std=c++11 -g -DCAPNP_DEBUG_TYPES=1 -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 + CXXFLAGS="$(EXTRA_FLAG) -std=c++14 -g -DCAPNP_DEBUG_TYPES=1 -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 continuous-opt: - CXXFLAGS="$(EXTRA_FLAG) -std=c++11 -O2 -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 + CXXFLAGS="$(EXTRA_FLAG) -std=c++14 -O2 -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 continuous-opt3: - CXXFLAGS="$(EXTRA_FLAG) -std=c++11 -O3 -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 + CXXFLAGS="$(EXTRA_FLAG) -std=c++14 -O3 -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 continuous-opts: - CXXFLAGS="$(EXTRA_FLAG) -std=c++11 -Os -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 + CXXFLAGS="$(EXTRA_FLAG) -std=c++14 -Os -DNDEBUG -Wall" LIBS='-lz -pthread' $(EKAM) -j6 -c -n :51315 clean: rm -rf bin lib tmp diff --git a/c++/WORKSPACE b/c++/WORKSPACE new file mode 100644 index 0000000000..d94a279ef3 --- /dev/null +++ b/c++/WORKSPACE @@ -0,0 +1,54 @@ +workspace(name = "capnp-cpp") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//:build/load_br.bzl", "load_brotli") + +http_archive( + name = "bazel_skylib", + sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", + ], +) + +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") + +bazel_skylib_workspace() + +http_archive( + name = "ssl", + sha256 = "873ec711658f65192e9c58554ce058d1cfa4e57e13ab5366ee16f76d1c757efc", + strip_prefix = "google-boringssl-ed2e74e", + type = "tgz", + # from master-with-bazel branch + urls = ["https://github.com/google/boringssl/tarball/ed2e74e737dc802ed9baad1af62c1514430a70d6"], +) + +# Based on https://github.com/bazelbuild/bazel/blob/master/third_party/zlib/BUILD. +_zlib_build = """ +cc_library( + name = "zlib", + srcs = glob(["*.c"]), + hdrs = glob(["*.h"]), + # Temporary workaround for zlib warnings and mac compilation, should no longer be needed with next release https://github.com/madler/zlib/issues/633 + copts = [ + "-w", + "-Dverbose=-1", + ] + select({ + "@platforms//os:macos": [ "-std=c90" ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], +) +""" + +http_archive( + name = "zlib", + build_file_content = _zlib_build, + sha256 = "d14c38e313afc35a9a8760dadf26042f51ea0f5d154b0630a31da0540107fb98", + strip_prefix = "zlib-1.2.13", + urls = ["https://zlib.net/zlib-1.2.13.tar.xz"], +) + +load_brotli() diff --git a/c++/build/configure.bzl b/c++/build/configure.bzl new file mode 100644 index 0000000000..bafee637eb --- /dev/null +++ b/c++/build/configure.bzl @@ -0,0 +1,103 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "int_flag") + +def kj_configure(): + """Generates set of flag, settings for kj configuration. + + Creates kj-defines cc_library with all necessary preprocessor defines. + """ + + # Flags to configure KJ library build. + bool_flag( + name = "openssl", + build_setting_default = False, + ) + + bool_flag( + name = "zlib", + build_setting_default = False, + ) + + bool_flag( + name = "brotli", + build_setting_default = False, + ) + + bool_flag( + name = "libdl", + build_setting_default = False, + ) + + bool_flag( + name = "save_acquired_lock_info", + build_setting_default = False, + ) + + bool_flag( + name = "track_lock_blocking", + build_setting_default = False, + ) + + bool_flag( + name = "coroutines", + build_setting_default = False, + ) + + # Settings to use in select() expressions + native.config_setting( + name = "use_openssl", + flag_values = {"openssl": "True"}, + visibility = ["//visibility:public"], + ) + + native.config_setting( + name = "use_zlib", + flag_values = {"zlib": "True"}, + ) + + native.config_setting( + name = "use_brotli", + flag_values = {"brotli": "True"}, + ) + + native.config_setting( + name = "use_libdl", + flag_values = {"libdl": "True"}, + ) + + native.config_setting( + name = "use_coroutines", + flag_values = {"coroutines": "True"}, + ) + + native.config_setting( + name = "use_save_acquired_lock_info", + flag_values = {"save_acquired_lock_info": "True"}, + ) + + native.config_setting( + name = "use_track_lock_blocking", + flag_values = {"track_lock_blocking": "True"}, + ) + + native.cc_library( + name = "kj-defines", + defines = select({ + "//src/kj:use_openssl": ["KJ_HAS_OPENSSL"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_zlib": ["KJ_HAS_ZLIB"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_brotli": ["KJ_HAS_BROTLI"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_libdl": ["KJ_HAS_LIBDL"], + "//conditions:default": [], + }) + select({ + "//src/kj:use_save_acquired_lock_info": ["KJ_SAVE_ACQUIRED_LOCK_INFO=1"], + "//conditions:default": ["KJ_SAVE_ACQUIRED_LOCK_INFO=0"], + }) + select({ + "//src/kj:use_track_lock_blocking": ["KJ_TRACK_LOCK_BLOCKING=1"], + "//conditions:default": ["KJ_TRACK_LOCK_BLOCKING=0"], + }), + ) diff --git a/c++/build/load_br.bzl b/c++/build/load_br.bzl new file mode 100644 index 0000000000..fe6fdfedd6 --- /dev/null +++ b/c++/build/load_br.bzl @@ -0,0 +1,12 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# Defined in a bzl file to allow dependents to pull in brotli via capnproto. Using latest brotli +# commit due to macOS compile issues with v1.0.9, switch to a release version later +def load_brotli(): + http_archive( + name = "brotli", + sha256 = "e33f397d86aaa7f3e786bdf01a7b5cff4101cfb20041c04b313b149d34332f64", + strip_prefix = "google-brotli-ed1995b", + type = "tgz", + urls = ["https://github.com/google/brotli/tarball/ed1995b6bda19244070ab5d331111f16f67c8054"], + ) diff --git a/c++/cmake/CapnProtoConfig.cmake.in b/c++/cmake/CapnProtoConfig.cmake.in index a8c90d0248..4b8ac96db4 100644 --- a/c++/cmake/CapnProtoConfig.cmake.in +++ b/c++/cmake/CapnProtoConfig.cmake.in @@ -1,19 +1,47 @@ +# Cap'n Proto CMake Package Configuration +# +# When configured and installed, this file enables client projects to find Cap'n Proto using +# CMake's find_package() command. It adds imported targets in the CapnProto:: namespace, such as +# CapnProto::kj, CapnProto::capnp, etc. (one target for each file in pkgconfig/*.pc.in), defines +# the capnp_generate_cpp() function, and exposes some variables for compatibility with the original +# FindCapnProto.cmake module. +# # Example usage: # find_package(CapnProto) # capnp_generate_cpp(CAPNP_SRCS CAPNP_HDRS schema.capnp) -# include_directories(${CMAKE_CURRENT_BINARY_DIR}) # add_executable(foo main.cpp ${CAPNP_SRCS}) -# target_link_libraries(foo CapnProto::capnp) +# target_link_libraries(foo PRIVATE CapnProto::capnp) +# target_include_directories(foo PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +# +# If you are using RPC features, use 'CapnProto::capnp-rpc' in the target_link_libraries() call. # -# If you are using RPC features, use 'CapnProto::capnp-rpc' -# in target_link_libraries call. +# Paths to `capnp` and `capnpc-c++` are exposed in the following variables: +# CAPNP_EXECUTABLE +# Path to the `capnp` tool (can be set to override). +# CAPNPC_CXX_EXECUTABLE +# Path to the `capnpc-c++` tool (can be set to override). +# +# For FindCapnProto.cmake compatibility, the following variables are also provided. Please prefer +# using the imported targets in new CMake code. +# CAPNP_INCLUDE_DIRS +# Include directories for the library's headers. +# CANP_LIBRARIES +# The Cap'n Proto library paths. +# CAPNP_LIBRARIES_LITE +# Paths to only the 'lite' libraries. +# CAPNP_DEFINITIONS +# Compiler definitions required for building with the library. +# CAPNP_FOUND +# Set if the libraries have been located (prefer using CapnProto_FOUND in new code). # @PACKAGE_INIT@ set(CapnProto_VERSION @VERSION@) -set(CAPNP_EXECUTABLE $) -set(CAPNPC_CXX_EXECUTABLE $) +set(CAPNP_EXECUTABLE $ + CACHE FILEPATH "Location of capnp executable") +set(CAPNPC_CXX_EXECUTABLE $ + CACHE FILEPATH "Location of capnpc-c++ executable") set(CAPNP_INCLUDE_DIRECTORY "@PACKAGE_CMAKE_INSTALL_FULL_INCLUDEDIR@") # work around http://public.kitware.com/Bug/view.php?id=15258 @@ -21,7 +49,75 @@ if(NOT _IMPORT_PREFIX) set(_IMPORT_PREFIX ${PACKAGE_PREFIX_DIR}) endif() +if (@WITH_OPENSSL@) # WITH_OPENSSL + include(CMakeFindDependencyMacro) + if (CMAKE_VERSION VERSION_LESS 3.9) + # find_dependency() did not support COMPONENTS until CMake 3.9 + # + # in practice, this call can be erroneous + # if the user has only libcrypto installed, but not libssl + find_dependency(OpenSSL) + else() + find_dependency(OpenSSL COMPONENTS Crypto SSL) + endif() +endif() + +if (@WITH_ZLIB@) # WITH_ZLIB + include(CMakeFindDependencyMacro) + find_dependency(ZLIB) +endif() + +if (@_WITH_LIBUCONTEXT@) # _WITH_LIBUCONTEXT + set(forwarded_config_flags) + if(CapnProto_FIND_QUIETLY) + list(APPEND forwarded_config_flags QUIET) + endif() + if(CapnProto_FIND_REQUIRED) + list(APPEND forwarded_config_flags REQUIRED) + endif() + # If the consuming project called find_package(CapnProto) with the QUIET or REQUIRED flags, forward + # them to calls to find_package(PkgConfig) and pkg_check_modules(). Note that find_dependency() + # would do this for us in the former case, but there is no such forwarding wrapper for + # pkg_check_modules(). + + find_package(PkgConfig ${forwarded_config_flags}) + if(NOT ${PkgConfig_FOUND}) + # If we're here, the REQUIRED flag must not have been passed, else we would have had a fatal + # error. Nevertheless, a diagnostic for this case is probably nice. + if(NOT CapnProto_FIND_QUIETLY) + message(WARNING "pkg-config cannot be found") + endif() + set(CapnProto_FOUND OFF) + return() + endif() + if (CMAKE_VERSION VERSION_LESS 3.6) + # CMake >= 3.6 required due to the use of IMPORTED_TARGET + message(SEND_ERROR "libucontext support requires CMake >= 3.6.") + endif() + + pkg_check_modules(libucontext IMPORTED_TARGET ${forwarded_config_flags} libucontext) +endif() include("${CMAKE_CURRENT_LIST_DIR}/CapnProtoTargets.cmake") include("${CMAKE_CURRENT_LIST_DIR}/CapnProtoMacros.cmake") + + +# FindCapnProto.cmake provides dependency information via several CAPNP_-prefixed variables. New +# code should not rely on these variables, but prefer linking directly to the imported targets we +# now provide. However, we should still set these variables to ease the transition for projects +# which currently depend on the find-module. + +set(CAPNP_INCLUDE_DIRS ${CAPNP_INCLUDE_DIRECTORY}) + +# No need to list all libraries, just the leaves of the dependency tree. +set(CAPNP_LIBRARIES_LITE CapnProto::capnp) +set(CAPNP_LIBRARIES CapnProto::capnp-rpc CapnProto::capnp-json + CapnProto::kj-http) + +set(CAPNP_DEFINITIONS) +if(TARGET CapnProto::capnp AND NOT TARGET CapnProto::capnp-rpc) + set(CAPNP_DEFINITIONS -DCAPNP_LITE) +endif() + +set(CAPNP_FOUND ${CapnProto_FOUND}) diff --git a/c++/cmake/CapnProtoConfigVersion.cmake.in b/c++/cmake/CapnProtoConfigVersion.cmake.in new file mode 100644 index 0000000000..0773a4f111 --- /dev/null +++ b/c++/cmake/CapnProtoConfigVersion.cmake.in @@ -0,0 +1,36 @@ +# This is a copy of /usr/share/cmake-3.5/Modules/BasicConfigVersion-AnyNewerVersion.cmake.in, with +# the following change: +# - @CVF_VERSION renamed to @PACKAGE_VERSION@. Autoconf defines a PACKAGE_VERSION +# output variable for us, so might as well take advantage of that. + +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. +# The variable PACKAGE_VERSION must be set before calling configure_file(). + +set(PACKAGE_VERSION "@PACKAGE_VERSION@") + +if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "@CMAKE_SIZEOF_VOID_P@" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "@CMAKE_SIZEOF_VOID_P@") + math(EXPR installedBits "@CMAKE_SIZEOF_VOID_P@ * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/c++/cmake/CapnProtoMacros.cmake b/c++/cmake/CapnProtoMacros.cmake index 3d58927b02..e44b66ea0f 100644 --- a/c++/cmake/CapnProtoMacros.cmake +++ b/c++/cmake/CapnProtoMacros.cmake @@ -3,19 +3,19 @@ # Example usage: # find_package(CapnProto) # capnp_generate_cpp(CAPNP_SRCS CAPNP_HDRS schema.capnp) -# include_directories(${CMAKE_CURRENT_BINARY_DIR}) # add_executable(foo main.cpp ${CAPNP_SRCS}) -# target_link_libraries(foo CapnProto::capnp-rpc) +# target_link_libraries(foo PRIVATE CapnProto::capnp-rpc) +# target_include_directories(foo PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) # -# If you are using not using the RPC features you can use -# 'CapnProto::capnp' in target_link_libraries call +# If you are not using the RPC features you can use 'CapnProto::capnp' in the +# target_link_libraries call # # Configuration variables (optional): # CAPNPC_OUTPUT_DIR # Directory to place compiled schema sources (default: CMAKE_CURRENT_BINARY_DIR). # CAPNPC_IMPORT_DIRS # List of additional include directories for the schema compiler. -# (CMAKE_CURRENT_SOURCE_DIR and CAPNP_INCLUDE_DIRECTORY are always included.) +# (CAPNPC_SRC_PREFIX and CAPNP_INCLUDE_DIRECTORY are always included.) # CAPNPC_SRC_PREFIX # Schema file source prefix (default: CMAKE_CURRENT_SOURCE_DIR). # CAPNPC_FLAGS @@ -30,9 +30,15 @@ function(CAPNP_GENERATE_CPP SOURCES HEADERS) set(tool_depends ${EMPTY_STRING}) #Use cmake targets available if(TARGET capnp_tool) - set(CAPNP_EXECUTABLE capnp_tool) - GET_TARGET_PROPERTY(CAPNPC_CXX_EXECUTABLE capnpc_cpp CAPNPC_CXX_EXECUTABLE) - GET_TARGET_PROPERTY(CAPNP_INCLUDE_DIRECTORY capnp_tool CAPNP_INCLUDE_DIRECTORY) + if(NOT CAPNP_EXECUTABLE) + set(CAPNP_EXECUTABLE $) + endif() + if(NOT CAPNPC_CXX_EXECUTABLE) + get_target_property(CAPNPC_CXX_EXECUTABLE capnpc_cpp CAPNPC_CXX_EXECUTABLE) + endif() + if(NOT CAPNP_INCLUDE_DIRECTORY) + get_target_property(CAPNP_INCLUDE_DIRECTORY capnp_tool CAPNP_INCLUDE_DIRECTORY) + endif() list(APPEND tool_depends capnp_tool capnpc_cpp) endif() if(NOT CAPNP_EXECUTABLE) @@ -45,17 +51,6 @@ function(CAPNP_GENERATE_CPP SOURCES HEADERS) message(SEND_ERROR "Could not locate capnp header files (CAPNP_INCLUDE_DIRECTORY).") endif() - # Default compiler includes - set(include_path -I ${CMAKE_CURRENT_SOURCE_DIR} -I ${CAPNP_INCLUDE_DIRECTORY}) - - if(DEFINED CAPNPC_IMPORT_DIRS) - # Append each directory as a series of '-I' flags in ${include_path} - foreach(directory ${CAPNPC_IMPORT_DIRS}) - get_filename_component(absolute_path "${directory}" ABSOLUTE) - list(APPEND include_path -I ${absolute_path}) - endforeach() - endif() - if(DEFINED CAPNPC_OUTPUT_DIR) # Prepend a ':' to get the format for the '-o' flag right set(output_dir ":${CAPNPC_OUTPUT_DIR}") @@ -68,14 +63,30 @@ function(CAPNP_GENERATE_CPP SOURCES HEADERS) endif() get_filename_component(CAPNPC_SRC_PREFIX "${CAPNPC_SRC_PREFIX}" ABSOLUTE) + # Default compiler includes. Note that in capnp's own test usage of capnp_generate_cpp(), these + # two variables will end up evaluating to the same directory. However, it's difficult to + # deduplicate them because if CAPNP_INCLUDE_DIRECTORY came from the capnp_tool target property, + # then it must be a generator expression in order to handle usages in both the build tree and the + # install tree. This vastly overcomplicates duplication detection, so the duplication doesn't seem + # worth fixing. + set(include_path -I "${CAPNPC_SRC_PREFIX}" -I "${CAPNP_INCLUDE_DIRECTORY}") + + if(DEFINED CAPNPC_IMPORT_DIRS) + # Append each directory as a series of '-I' flags in ${include_path} + foreach(directory ${CAPNPC_IMPORT_DIRS}) + get_filename_component(absolute_path "${directory}" ABSOLUTE) + list(APPEND include_path -I "${absolute_path}") + endforeach() + endif() + set(${SOURCES}) set(${HEADERS}) foreach(schema_file ${ARGN}) - if(NOT EXISTS "${CAPNPC_SRC_PREFIX}/${schema_file}") - message(FATAL_ERROR "Cap'n Proto schema file '${CAPNPC_SRC_PREFIX}/${schema_file}' does not exist!") - endif() get_filename_component(file_path "${schema_file}" ABSOLUTE) get_filename_component(file_dir "${file_path}" PATH) + if(NOT EXISTS "${file_path}") + message(FATAL_ERROR "Cap'n Proto schema file '${file_path}' does not exist!") + endif() # Figure out where the output files will go if (NOT DEFINED CAPNPC_OUTPUT_DIR) diff --git a/c++/cmake/CapnProtoTargets.cmake b/c++/cmake/CapnProtoTargets.cmake new file mode 100644 index 0000000000..879c2c0288 --- /dev/null +++ b/c++/cmake/CapnProtoTargets.cmake @@ -0,0 +1,221 @@ +# This CMake script adds imported targets for each shared library and executable distributed by +# Cap'n Proto's autotools build. +# +# This file IS NOT USED by the CMake build! The CMake build generates its own version of this script +# from its set of exported targets. I used such a generated script as a reference when writing this +# one. +# +# The set of library targets provided by this script is automatically generated from the list of .pc +# files maintained in configure.ac. The set of executable targets is hard-coded in this file. +# +# You can request that this script print debugging information by invoking cmake with: +# +# -DCapnProto_DEBUG=ON +# +# TODO(someday): Distinguish between debug and release builds. I.e., set IMPORTED_LOCATION_RELEASE +# rather than IMPORTED_LOCATION, etc., if this installation was configured as a release build. But +# how do we tell? grep for -g in CXXFLAGS? + +if(CMAKE_VERSION VERSION_LESS 3.1) + message(FATAL_ERROR "CMake >= 3.1 required") +endif() + +set(forwarded_config_flags) +if(CapnProto_FIND_QUIETLY) + list(APPEND forwarded_config_flags QUIET) +endif() +if(CapnProto_FIND_REQUIRED) + list(APPEND forwarded_config_flags REQUIRED) +endif() +# If the consuming project called find_package(CapnProto) with the QUIET or REQUIRED flags, forward +# them to calls to find_package(PkgConfig) and pkg_check_modules(). Note that find_dependency() +# would do this for us in the former case, but there is no such forwarding wrapper for +# pkg_check_modules(). + +find_package(PkgConfig ${forwarded_config_flags}) +if(NOT ${PkgConfig_FOUND}) + # If we're here, the REQUIRED flag must not have been passed, else we would have had a fatal + # error. Nevertheless, a diagnostic for this case is probably nice. + if(NOT CapnProto_FIND_QUIETLY) + message(WARNING "pkg-config cannot be found") + endif() + set(CapnProto_FOUND OFF) + return() +endif() + +function(_capnp_import_pkg_config_target target) + # Add an imported library target named CapnProto::${target}, using the output of various + # invocations of `pkg-config ${target}`. The generated imported library target tries to mimic the + # behavior of a real CMake-generated imported target as closely as possible. + # + # Usage: _capnp_import_pkg_config_target(target ) + + set(all_targets ${ARGN}) + + pkg_check_modules(${target} ${forwarded_config_flags} ${target}) + + if(NOT ${${target}_FOUND}) + if(NOT CapnProto_FIND_QUIETLY) + message(WARNING "CapnProtoConfig.cmake was configured to search for ${target}.pc, but pkg-config cannot find it. Ignoring this target.") + endif() + return() + endif() + + if(CapnProto_DEBUG) + # Dump the information pkg-config discovered. + foreach(var VERSION LIBRARY_DIRS LIBRARIES LDFLAGS_OTHER INCLUDE_DIRS CFLAGS_OTHER) + message(STATUS "${target}_${var} = ${${target}_${var}}") + endforeach() + endif() + + if(NOT ${${target}_VERSION} VERSION_EQUAL ${CapnProto_VERSION}) + if(NOT CapnProto_FIND_QUIETLY) + message(WARNING "CapnProtoConfig.cmake was configured to search for version ${CapnProto_VERSION}, but ${target} version ${${target}_VERSION} was found. Ignoring this target.") + endif() + return() + endif() + + # Make an educated guess as to what the target's .so and .a filenames must be. + set(target_name_shared + ${CMAKE_SHARED_LIBRARY_PREFIX}${target}-${CapnProto_VERSION}${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(target_name_static + ${CMAKE_STATIC_LIBRARY_PREFIX}${target}${CMAKE_STATIC_LIBRARY_SUFFIX}) + + # Find the actual target's file. find_library() sets a cache variable, so I made the variable name + # unique-ish. + find_library(CapnProto_${target}_IMPORTED_LOCATION + NAMES ${target_name_shared} ${target_name_static} # prefer libfoo-version.so over libfoo.a + PATHS ${${target}_LIBRARY_DIRS} + NO_DEFAULT_PATH + ) + # If the installed version of Cap'n Proto is in a system location, pkg-config will not have filled + # in ${target}_LIBRARY_DIRS. To account for this, fall back to a regular search. + find_library(CapnProto_${target}_IMPORTED_LOCATION + NAMES ${target_name_shared} ${target_name_static} # prefer libfoo-version.so over libfoo.a + ) + + if(NOT CapnProto_${target}_IMPORTED_LOCATION) + # Not an error if the library doesn't exist -- we may have found a lite mode installation. + if(CapnProto_DEBUG) + message(STATUS "${target} library does not exist") + endif() + return() + endif() + + # Record some information about this target -- shared versus static, location and soname -- which + # we'll use to build our imported target later. + + set(target_location ${CapnProto_${target}_IMPORTED_LOCATION}) + get_filename_component(target_name "${target_location}" NAME) + + set(target_type STATIC) + set(imported_soname_property) + if(target_name STREQUAL ${target_name_shared}) + set(target_type SHARED) + set(imported_soname_property IMPORTED_SONAME ${target_name}) + endif() + + # Each library dependency of the target is either the target itself, a sibling Cap'n Proto + # library, or a system library. We ignore the first case by removing this target from the + # dependencies. The remaining dependencies are either passed through or, if they are a sibling + # Cap'n Proto library, prefixed with `CapnProto::`. + set(dependencies ${${target}_LIBRARIES}) + list(REMOVE_ITEM dependencies ${target}) + set(target_interface_libs) + foreach(dependency ${dependencies}) + list(FIND all_targets ${dependency} target_index) + # TODO(cleanup): CMake >= 3.3 lets us write: `if(NOT ${dependency} IN_LIST all_targets)` + if(target_index EQUAL -1) + list(APPEND target_interface_libs ${dependency}) + else() + list(APPEND target_interface_libs CapnProto::${dependency}) + endif() + endforeach() + + add_library(CapnProto::${target} ${target_type} IMPORTED) + set_target_properties(CapnProto::${target} PROPERTIES + ${imported_soname_property} + IMPORTED_LOCATION "${target_location}" + # TODO(cleanup): Use cxx_std_14 once it's safe to require cmake 3.8. + INTERFACE_COMPILE_FEATURES "cxx_generic_lambdas" + INTERFACE_COMPILE_OPTIONS "${${target}_CFLAGS_OTHER}" + INTERFACE_INCLUDE_DIRECTORIES "${${target}_INCLUDE_DIRS}" + + # I'm dumping LDFLAGS_OTHER in with the libraries because there exists no + # INTERFACE_LINK_OPTIONS. See https://gitlab.kitware.com/cmake/cmake/issues/16543. + INTERFACE_LINK_LIBRARIES "${target_interface_libs};${${target}_LDFLAGS_OTHER}" + ) + + if(CapnProto_DEBUG) + # Dump all the properties we generated for the imported target. + foreach(prop + IMPORTED_LOCATION + IMPORTED_SONAME + INTERFACE_COMPILE_FEATURES + INTERFACE_COMPILE_OPTIONS + INTERFACE_INCLUDE_DIRECTORIES + INTERFACE_LINK_LIBRARIES) + get_target_property(value CapnProto::${target} ${prop}) + message(STATUS "CapnProto::${target} ${prop} = ${value}") + endforeach() + endif() +endfunction() + +# ======================================================================================== +# Imported library targets + +# Build a list of targets to search for from the list of .pc files. +# I.e. [somewhere/foo.pc, somewhere/bar.pc] -> [foo, bar] +set(library_targets) +foreach(filename ${CAPNP_PKG_CONFIG_FILES}) + get_filename_component(target ${filename} NAME_WE) + list(APPEND library_targets ${target}) +endforeach() + +# Try to add an imported library target CapnProto::foo for each foo.pc distributed with Cap'n Proto. +foreach(target ${library_targets}) + _capnp_import_pkg_config_target(${target} ${library_targets}) +endforeach() + +# Handle lite-mode and no libraries found cases. It is tempting to set a CapnProto_LITE variable +# here, but the real CMake-generated implementation does no such thing -- we'd need to set it in +# CapnProtoConfig.cmake.in itself. +if(TARGET CapnProto::capnp AND TARGET CapnProto::kj) + if(NOT TARGET CapnProto::capnp-rpc) + if(NOT CapnProto_FIND_QUIETLY) + message(STATUS "Found an installation of Cap'n Proto lite. Executable and library targets beyond libkj and libcapnp will be unavailable.") + endif() + # Lite mode doesn't include the executables, so return here. + return() + endif() +else() + # If we didn't even find capnp or kj, then we didn't find anything usable. + set(CapnProto_FOUND OFF) + return() +endif() + +# ======================================================================================== +# Imported executable targets + +get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) + +# Add executable targets for the capnp compiler and plugins. This list must be kept manually in sync +# with the rest of the project. + +add_executable(CapnProto::capnp_tool IMPORTED) +set_target_properties(CapnProto::capnp_tool PROPERTIES + IMPORTED_LOCATION "${_IMPORT_PREFIX}/bin/capnp${CMAKE_EXECUTABLE_SUFFIX}" +) + +add_executable(CapnProto::capnpc_cpp IMPORTED) +set_target_properties(CapnProto::capnpc_cpp PROPERTIES + IMPORTED_LOCATION "${_IMPORT_PREFIX}/bin/capnpc-c++${CMAKE_EXECUTABLE_SUFFIX}" +) + +add_executable(CapnProto::capnpc_capnp IMPORTED) +set_target_properties(CapnProto::capnpc_capnp PROPERTIES + IMPORTED_LOCATION "${_IMPORT_PREFIX}/bin/capnpc-capnp${CMAKE_EXECUTABLE_SUFFIX}" +) diff --git a/c++/cmake/FindCapnProto.cmake b/c++/cmake/FindCapnProto.cmake deleted file mode 100644 index bacd549a60..0000000000 --- a/c++/cmake/FindCapnProto.cmake +++ /dev/null @@ -1,197 +0,0 @@ -# -# Finds the Cap'n Proto libraries, and compiles schema files. -# -# Configuration variables (optional): -# CAPNPC_OUTPUT_DIR -# Directory to place compiled schema sources (default: the same directory as the schema file). -# CAPNPC_IMPORT_DIRS -# List of additional include directories for the schema compiler. -# (CMAKE_CURRENT_SOURCE_DIR and CAPNP_INCLUDE_DIRS are always included.) -# CAPNPC_SRC_PREFIX -# Schema file source prefix (default: CMAKE_CURRENT_SOURCE_DIR). -# CAPNPC_FLAGS -# Additional flags to pass to the schema compiler. -# -# Variables that are discovered: -# CAPNP_EXECUTABLE -# Path to the `capnp` tool (can be set to override). -# CAPNPC_CXX_EXECUTABLE -# Path to the `capnpc-c++` tool (can be set to override). -# CAPNP_INCLUDE_DIRS -# Include directories for the library's headers (can be set to override). -# CANP_LIBRARIES -# The Cap'n Proto library paths. -# CAPNP_LIBRARIES_LITE -# Paths to only the 'lite' libraries. -# CAPNP_DEFINITIONS -# Compiler definitions required for building with the library. -# CAPNP_FOUND -# Set if the libraries have been located. -# -# Example usage: -# -# find_package(CapnProto REQUIRED) -# include_directories(${CAPNP_INCLUDE_DIRS}) -# add_definitions(${CAPNP_DEFINITIONS}) -# -# capnp_generate_cpp(CAPNP_SRCS CAPNP_HDRS schema.capnp) -# add_executable(a a.cc ${CAPNP_SRCS} ${CAPNP_HDRS}) -# target_link_library(a ${CAPNP_LIBRARIES}) -# -# For out-of-source builds: -# -# set(CAPNPC_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) -# include_directories(${CAPNPC_OUTPUT_DIR}) -# capnp_generate_cpp(...) -# - -# CAPNP_GENERATE_CPP =========================================================== - -function(CAPNP_GENERATE_CPP SOURCES HEADERS) - if(NOT ARGN) - message(SEND_ERROR "CAPNP_GENERATE_CPP() called without any source files.") - endif() - if(NOT CAPNP_EXECUTABLE) - message(SEND_ERROR "Could not locate capnp executable (CAPNP_EXECUTABLE).") - endif() - if(NOT CAPNPC_CXX_EXECUTABLE) - message(SEND_ERROR "Could not locate capnpc-c++ executable (CAPNPC_CXX_EXECUTABLE).") - endif() - if(NOT CAPNP_INCLUDE_DIRS) - message(SEND_ERROR "Could not locate capnp header files (CAPNP_INCLUDE_DIRS).") - endif() - - # Default compiler includes - set(include_path -I ${CMAKE_CURRENT_SOURCE_DIR} -I ${CAPNP_INCLUDE_DIRS}) - - if(DEFINED CAPNPC_IMPORT_DIRS) - # Append each directory as a series of '-I' flags in ${include_path} - foreach(directory ${CAPNPC_IMPORT_DIRS}) - get_filename_component(absolute_path "${directory}" ABSOLUTE) - list(APPEND include_path -I ${absolute_path}) - endforeach() - endif() - - if(DEFINED CAPNPC_OUTPUT_DIR) - # Prepend a ':' to get the format for the '-o' flag right - set(output_dir ":${CAPNPC_OUTPUT_DIR}") - else() - set(output_dir ":.") - endif() - - if(NOT DEFINED CAPNPC_SRC_PREFIX) - set(CAPNPC_SRC_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}") - endif() - get_filename_component(CAPNPC_SRC_PREFIX "${CAPNPC_SRC_PREFIX}" ABSOLUTE) - - set(${SOURCES}) - set(${HEADERS}) - foreach(schema_file ${ARGN}) - get_filename_component(file_path "${schema_file}" ABSOLUTE) - get_filename_component(file_dir "${file_path}" PATH) - - # Figure out where the output files will go - if (NOT DEFINED CAPNPC_OUTPUT_DIR) - set(output_base "${file_path}") - else() - # Output files are placed in CAPNPC_OUTPUT_DIR, at a location as if they were - # relative to CAPNPC_SRC_PREFIX. - string(LENGTH "${CAPNPC_SRC_PREFIX}" prefix_len) - string(SUBSTRING "${file_path}" 0 ${prefix_len} output_prefix) - if(NOT "${CAPNPC_SRC_PREFIX}" STREQUAL "${output_prefix}") - message(SEND_ERROR "Could not determine output path for '${schema_file}' ('${file_path}') with source prefix '${CAPNPC_SRC_PREFIX}' into '${CAPNPC_OUTPUT_DIR}'.") - endif() - - string(SUBSTRING "${file_path}" ${prefix_len} -1 output_path) - set(output_base "${CAPNPC_OUTPUT_DIR}${output_path}") - endif() - - add_custom_command( - OUTPUT "${output_base}.c++" "${output_base}.h" - COMMAND "${CAPNP_EXECUTABLE}" - ARGS compile - -o ${CAPNPC_CXX_EXECUTABLE}${output_dir} - --src-prefix ${CAPNPC_SRC_PREFIX} - ${include_path} - ${CAPNPC_FLAGS} - ${file_path} - DEPENDS "${schema_file}" - COMMENT "Compiling Cap'n Proto schema ${schema_file}" - VERBATIM - ) - list(APPEND ${SOURCES} "${output_base}.c++") - list(APPEND ${HEADERS} "${output_base}.h") - endforeach() - - set_source_files_properties(${${SOURCES}} ${${HEADERS}} PROPERTIES GENERATED TRUE) - set(${SOURCES} ${${SOURCES}} PARENT_SCOPE) - set(${HEADERS} ${${HEADERS}} PARENT_SCOPE) -endfunction() - -# Find Libraries/Paths ========================================================= - -# Use pkg-config to get path hints and definitions -find_package(PkgConfig QUIET) -pkg_check_modules(PKGCONFIG_CAPNP capnp) -pkg_check_modules(PKGCONFIG_CAPNP_RPC capnp-rpc QUIET) -pkg_check_modules(PKGCONFIG_CAPNP_JSON capnp-json QUIET) - -find_library(CAPNP_LIB_KJ kj - HINTS "${PKGCONFIG_CAPNP_LIBDIR}" ${PKGCONFIG_CAPNP_LIBRARY_DIRS} -) -find_library(CAPNP_LIB_KJ-ASYNC kj-async - HINTS "${PKGCONFIG_CAPNP_RPC_LIBDIR}" ${PKGCONFIG_CAPNP_RPC_LIBRARY_DIRS} -) -find_library(CAPNP_LIB_CAPNP capnp - HINTS "${PKGCONFIG_CAPNP_LIBDIR}" ${PKGCONFIG_CAPNP_LIBRARY_DIRS} -) -find_library(CAPNP_LIB_CAPNP-RPC capnp-rpc - HINTS "${PKGCONFIG_CAPNP_RPC_LIBDIR}" ${PKGCONFIG_CAPNP_RPC_LIBRARY_DIRS} -) -find_library(CAPNP_LIB_CAPNP-JSON capnp-json - HINTS "${PKGCONFIG_CAPNP_JSON_LIBDIR}" ${PKGCONFIG_CAPNP_JSON_LIBRARY_DIRS} -) -mark_as_advanced(CAPNP_LIB_KJ CAPNP_LIB_KJ-ASYNC CAPNP_LIB_CAPNP CAPNP_LIB_CAPNP-RPC CAPNP_LIB_CAPNP-JSON) -set(CAPNP_LIBRARIES_LITE - ${CAPNP_LIB_CAPNP} - ${CAPNP_LIB_KJ} -) -set(CAPNP_LIBRARIES - ${CAPNP_LIB_CAPNP-JSON} - ${CAPNP_LIB_CAPNP-RPC} - ${CAPNP_LIB_CAPNP} - ${CAPNP_LIB_KJ-ASYNC} - ${CAPNP_LIB_KJ} -) - -# Was only the 'lite' library found? -if(CAPNP_LIB_CAPNP AND NOT CAPNP_LIB_CAPNP-RPC) - set(CAPNP_DEFINITIONS -DCAPNP_LITE) -else() - set(CAPNP_DEFINITIONS) -endif() - -find_path(CAPNP_INCLUDE_DIRS capnp/generated-header-support.h - HINTS "${PKGCONFIG_CAPNP_INCLUDEDIR}" ${PKGCONFIG_CAPNP_INCLUDE_DIRS} -) - -find_program(CAPNP_EXECUTABLE - NAMES capnp - DOC "Cap'n Proto Command-line Tool" - HINTS "${PKGCONFIG_CAPNP_PREFIX}/bin" -) - -find_program(CAPNPC_CXX_EXECUTABLE - NAMES capnpc-c++ - DOC "Capn'n Proto C++ Compiler" - HINTS "${PKGCONFIG_CAPNP_PREFIX}/bin" -) - -# Only *require* the include directory, libkj, and libcapnp. If compiling with -# CAPNP_LITE, nothing else will be found. -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(CAPNP DEFAULT_MSG - CAPNP_INCLUDE_DIRS - CAPNP_LIB_KJ - CAPNP_LIB_CAPNP -) diff --git a/c++/compile_flags.txt b/c++/compile_flags.txt new file mode 100644 index 0000000000..a9e8a51665 --- /dev/null +++ b/c++/compile_flags.txt @@ -0,0 +1,15 @@ +-std=c++20 +-Isrc +-Itmp +-isystem/usr/local/include +-isystem/usr/include/x86_64-linux-gnu +-isystem/usr/include +-DKJ_HEADER_WARNINGS +-DCAPNP_HEADER_WARNINGS +-DCAPNP_DEBUG_TYPES +-DKJ_HAS_OPENSSL +-DKJ_HAS_LIBDL +-DKJ_HAS_ZLIB +-DKJ_HAS_BROTLI +-DKJ_BENCHMARK_MALLOC +-xc++ diff --git a/c++/configure.ac b/c++/configure.ac index 2015d362b1..60b9c3de0b 100644 --- a/c++/configure.ac +++ b/c++/configure.ac @@ -1,6 +1,6 @@ ## Process this file with autoconf to produce configure. -AC_INIT([Capn Proto],[0.6.1],[capnproto@googlegroups.com],[capnproto-c++]) +AC_INIT([Capn Proto],[1.1-dev],[capnproto@googlegroups.com],[capnproto-c++]) AC_CONFIG_SRCDIR([src/capnp/layout.c++]) AC_CONFIG_AUX_DIR([build-aux]) @@ -22,6 +22,21 @@ AC_ARG_WITH([external-capnp], one (useful for cross-compiling)])], [external_capnp=yes],[external_capnp=no]) +AC_ARG_WITH([zlib], + [AS_HELP_STRING([--with-zlib], + [build libkj-gzip by linking against zlib @<:@default=check@:>@])], + [],[with_zlib=check]) + +AC_ARG_WITH([openssl], + [AS_HELP_STRING([--with-openssl], + [build libkj-tls by linking against openssl @<:@default=check@:>@])], + [],[with_openssl=check]) + +AC_ARG_WITH([fibers], + [AS_HELP_STRING([--with-fibers], + [build libkj-async with fibers @<:@default=check@:>@])], + [],[with_fibers=check]) + AC_ARG_ENABLE([reflection], [ AS_HELP_STRING([--disable-reflection], [ compile Cap'n Proto in "lite mode", in which all reflection APIs (schema.h, dynamic.h, etc.) @@ -50,7 +65,7 @@ AC_ARG_ENABLE([reflection], [ AC_PROG_CC AC_PROG_CXX AC_LANG([C++]) -AX_CXX_COMPILE_STDCXX_11 +AX_CXX_COMPILE_STDCXX_14 AS_CASE("${host_os}", *mingw*, [ # We don't use pthreads on MinGW. @@ -105,5 +120,153 @@ AC_SUBST([STDLIB_FLAG]) LIBS="$PTHREAD_LIBS $LIBS" CXXFLAGS="$CXXFLAGS $PTHREAD_CFLAGS" -AC_CONFIG_FILES([Makefile capnp.pc capnp-rpc.pc capnp-json.pc kj.pc kj-async.pc]) +AC_DEFUN([CAPNP_PKG_CONFIG_FILES], [ \ + pkgconfig/capnp.pc \ + pkgconfig/capnpc.pc \ + pkgconfig/capnp-rpc.pc \ + pkgconfig/capnp-json.pc \ + pkgconfig/capnp-websocket.pc \ + pkgconfig/kj.pc \ + pkgconfig/kj-async.pc \ + pkgconfig/kj-http.pc \ + pkgconfig/kj-gzip.pc \ + pkgconfig/kj-tls.pc \ + pkgconfig/kj-test.pc \ +]) +AC_DEFUN([CAPNP_CMAKE_CONFIG_FILES], [ \ + cmake/CapnProtoConfig.cmake \ + cmake/CapnProtoConfigVersion.cmake \ +]) + +[CAPNP_PKG_CONFIG_FILES]="CAPNP_PKG_CONFIG_FILES" +[CAPNP_CMAKE_CONFIG_FILES]="CAPNP_CMAKE_CONFIG_FILES" +AC_SUBST([CAPNP_PKG_CONFIG_FILES]) +AC_SUBST([CAPNP_CMAKE_CONFIG_FILES]) + +# Don't include security release in soname -- we want to replace old binaries +# in this case. +SO_VERSION=$(echo $VERSION | sed -e 's/^\([0-9]*[.][0-9]*[.][0-9]*\)\([.][0-9]*\)*\(-.*\)*$/\1\3/g') +AC_SUBST([SO_VERSION]) + +# CapnProtoConfig.cmake.in needs these PACKAGE_* output variables. +PACKAGE_INIT="set([CAPNP_PKG_CONFIG_FILES] CAPNP_PKG_CONFIG_FILES)" +PACKAGE_CMAKE_INSTALL_FULL_INCLUDEDIR="\${CMAKE_CURRENT_LIST_DIR}/../../../include" +AC_SUBST([PACKAGE_INIT]) +AC_SUBST([PACKAGE_CMAKE_INSTALL_FULL_INCLUDEDIR]) + +# CapnProtoConfigVersion.cmake.in needs PACKAGE_VERSION (already defined by AC_INIT) and +# CMAKE_SIZEOF_VOID_P output variables. +AC_CHECK_SIZEOF([void *]) +AC_SUBST(CMAKE_SIZEOF_VOID_P, $ac_cv_sizeof_void_p) + +# Detect presence of zlib, if it was not specified explicitly. +AS_IF([test "$with_zlib" = check], [ + AC_CHECK_LIB(z, deflate, [:], [ + with_zlib=no + ]) + AC_CHECK_HEADER([zlib.h], [:], [ + with_zlib=no + ]) + AS_IF([test "$with_zlib" = no], [ + AC_MSG_WARN("could not find zlib -- won't build libkj-gzip") + ], [ + with_zlib=yes + ]) +]) +AS_IF([test "$with_zlib" != no], [ + CXXFLAGS="$CXXFLAGS -DKJ_HAS_ZLIB" +]) +AM_CONDITIONAL([BUILD_KJ_GZIP], [test "$with_zlib" != no]) + +# Detect presence of OpenSSL, if it was not specified explicitly. +AS_IF([test "$with_openssl" = check], [ + AC_CHECK_LIB(crypto, CRYPTO_new_ex_data, [:], [ + with_openssl=no + ]) + AC_CHECK_LIB(ssl, OPENSSL_init_ssl, [:], [ + with_openssl=no + ], [-lcrypto]) + AC_CHECK_HEADER([openssl/ssl.h], [:], [ + with_openssl=no + ]) + AS_IF([test "$with_openssl" = no], [ + AC_MSG_WARN("could not find OpenSSL -- won't build libkj-tls") + ], [ + with_openssl=yes + ]) +]) +AS_IF([test "$with_openssl" != no], [ + CXXFLAGS="$CXXFLAGS -DKJ_HAS_OPENSSL" +]) +AM_CONDITIONAL([BUILD_KJ_TLS], [test "$with_openssl" != no]) + +# Fibers don't work if exceptions are disabled, so default off in that case. +AS_IF([test "$with_fibers" != no], [ + AC_MSG_CHECKING([if exceptions are enabled]) + AC_COMPILE_IFELSE([void foo() { throw 1; }], [ + AC_MSG_RESULT([yes]) + ], [ + AS_IF([test "$with_fibers" = check], [ + AC_MSG_RESULT([no -- therefore, disabling fibers]) + with_fibers=no + ], [ + AC_MSG_RESULT([no]) + AC_MSG_ERROR([Fibers require exceptions, but your compiler flags disable exceptions. Please either enable exceptions or disable fibers (--without-fibers).]) + ]) + ]) +]) + +# Check for library support necessary for fibers. +AS_IF([test "$with_fibers" != no], [ + case "${host_os}" in + cygwin* | mingw* ) + # Fibers always work on Windows, where there's an explicit API for them. + with_fibers=yes + ;; + * ) + # Fibers need the symbols getcontext, setcontext, swapcontext and makecontext. + # We assume that makecontext implies the rest. + libc_supports_fibers=yes + AC_SEARCH_LIBS([makecontext], [], [], [ + libc_supports_fibers=no + ]) + + AS_IF([test "$libc_supports_fibers" = yes], [ + with_fibers=yes + ], [ + # If getcontext does not exist in libc, try with libucontext + ucontext_supports_fibers=yes + AC_CHECK_LIB(ucontext, [makecontext], [], [ + ucontext_supports_fibers=no + ]) + AS_IF([test "$ucontext_supports_fibers" = yes], [ + ASYNC_LIBS="$ASYNC_LIBS -lucontext" + with_fibers=yes + ], [ + AS_IF([test "$with_fibers" = yes], [ + AC_MSG_ERROR([Missing symbols required for fibers (makecontext, setcontext, ...). Disable fibers (--without-fibers) or install libucontext]) + ], [ + AC_MSG_WARN([could not find required symbols (makecontext, setcontext, ...) -- won't build with fibers]) + with_fibers=no + ]) + ]) + ]) + ;; + esac +]) +AS_IF([test "$with_fibers" = yes], [ + CXXFLAGS="$CXXFLAGS -DKJ_USE_FIBERS" +], [ + CXXFLAGS="$CXXFLAGS -DKJ_USE_FIBERS=0" +]) + +# CapnProtoConfig.cmake.in needs these variables, +# we force them to NO because we don't need the CMake dependency for them, +# the dependencies are provided by the .pc files. +AC_SUBST(WITH_OPENSSL, NO) +AC_SUBST(_WITH_LIBUCONTEXT, NO) + +AM_CONDITIONAL([HAS_FUZZING_ENGINE], [test "x$LIB_FUZZING_ENGINE" != "x"]) + +AC_CONFIG_FILES([Makefile] CAPNP_PKG_CONFIG_FILES CAPNP_CMAKE_CONFIG_FILES) AC_OUTPUT diff --git a/c++/ekam-build.sh b/c++/ekam-build.sh new file mode 100755 index 0000000000..eff565a44c --- /dev/null +++ b/c++/ekam-build.sh @@ -0,0 +1,70 @@ +#! /bin/bash +# +# This file builds Cap'n Proto using Ekam. + +set -euo pipefail + +NPROC=$(nproc) + +if [ ! -e deps/ekam ]; then + mkdir -p deps + git clone https://github.com/capnproto/ekam.git deps/ekam +fi + +if [ ! -e deps/ekam/deps/capnproto ]; then + mkdir -p deps/ekam/deps + ln -s ../../../.. deps/ekam/deps/capnproto +fi + +if [ ! -e deps/ekam/ekam ]; then + (cd deps/ekam && make -j$NPROC) +fi + +OPT_CXXFLAGS= +EXTRA_LIBS= +EKAM_FLAGS= + +while [ $# -gt 0 ]; do + case $1 in + dbg | debug ) + OPT_CXXFLAGS="-g -DCAPNP_DEBUG_TYPES " + ;; + opt | release ) + OPT_CXXFLAGS="-DNDEBUG -O2 -g" + ;; + prof | profile ) + OPT_CXXFLAGS="-DNDEBUG -O2 -g" + EXTRA_LIBS="$EXTRA_LIBS -lprofiler" + ;; + tcmalloc ) + EXTRA_LIBS="$EXTRA_LIBS -ltcmalloc" + ;; + continuous ) + EKAM_FLAGS="-c -n :41315" + ;; + * ) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac + shift +done + +CLANG_CXXFLAGS="-std=c++20 -stdlib=libc++ -pthread -Wall -Wextra -Werror -Wno-strict-aliasing -Wno-sign-compare -Wno-unused-parameter -Wimplicit-fallthrough -Wno-error=unused-command-line-argument -Wno-missing-field-initializers -DKJ_HEADER_WARNINGS -DCAPNP_HEADER_WARNINGS -DKJ_HAS_OPENSSL -DKJ_HAS_LIBDL -DKJ_HAS_ZLIB -DKJ_BENCHMARK_MALLOC" + +export CXX=${CXX:-clang++} +export CC=${CC:-clang} +export LIBS="-lz -ldl -lcrypto -lssl -stdlib=libc++ $EXTRA_LIBS -pthread" +export CXXFLAGS=${CXXFLAGS:-$OPT_CXXFLAGS $CLANG_CXXFLAGS} + +# TODO(someday): Get the protobuf benchmarks working. For now these settings will prevent build +# errors in the benchmarks directory. Note that it's tricky to link against an installed copy +# of libprotobuf because we have to use compatible C++ standard libraries. We either need to +# build libprotobuf from source using libc++, or we need to switch back to libstdc++ when +# enabling libprotobuf. Arguably building from source would be more fair so we can match compiler +# flags for performance comparison purposes, but we'll have to see if ekam is still able to build +# libprotobuf these days... +CXXFLAGS="$CXXFLAGS -DCAPNP_NO_PROTOBUF_BENCHMARK" +export PROTOC=/bin/true + +exec deps/ekam/bin/ekam $EKAM_FLAGS -j$NPROC diff --git a/c++/m4/ax_cxx_compile_stdcxx_11.m4 b/c++/m4/ax_cxx_compile_stdcxx_14.m4 similarity index 80% rename from c++/m4/ax_cxx_compile_stdcxx_11.m4 rename to c++/m4/ax_cxx_compile_stdcxx_14.m4 index e669ceab0f..ec41463eaa 100644 --- a/c++/m4/ax_cxx_compile_stdcxx_11.m4 +++ b/c++/m4/ax_cxx_compile_stdcxx_14.m4 @@ -1,19 +1,20 @@ # ============================================================================ # http://www.gnu.org/software/autoconf-archive/ax_cxx_compile_stdcxx_11.html # Additionally modified to detect -stdlib by Kenton Varda. +# Further modified for C++14 by Kenton Varda. # ============================================================================ # # SYNOPSIS # -# AX_CXX_COMPILE_STDCXX_11([ext|noext]) +# AX_CXX_COMPILE_STDCXX_14([ext|noext]) # # DESCRIPTION # -# Check for baseline language coverage in the compiler for the C++11 +# Check for baseline language coverage in the compiler for the C++14 # standard; if necessary, add switches to CXXFLAGS to enable support. -# Errors out if no mode that supports C++11 baseline syntax can be found. +# Errors out if no mode that supports C++14 baseline syntax can be found. # The argument, if specified, indicates whether you insist on an extended -# mode (e.g. -std=gnu++11) or a strict conformance mode (e.g. -std=c++11). +# mode (e.g. -std=gnu++14) or a strict conformance mode (e.g. -std=c++14). # If neither is specified, you get whatever works, with preference for an # extended mode. # @@ -42,7 +43,7 @@ #serial 1 -m4_define([_AX_CXX_COMPILE_STDCXX_11_testbody], [ +m4_define([_AX_CXX_COMPILE_STDCXX_14_testbody], [[ template struct check { @@ -78,7 +79,18 @@ m4_define([_AX_CXX_COMPILE_STDCXX_11_testbody], [ #include #endif #endif -]) + + // C++14 stuff + auto deduceReturnType(int i) { return i; } + + auto genericLambda = [](auto x, auto y) { return x + y; }; + auto captureExpressions = [x = 123]() { return x; }; + + // Avoid unused variable warnings. + int foo() { + return genericLambda(1, 2) + captureExpressions(); + } +]]) m4_define([_AX_CXX_COMPILE_STDCXX_11_testbody_lib], [ #include @@ -87,31 +99,31 @@ m4_define([_AX_CXX_COMPILE_STDCXX_11_testbody_lib], [ #include ]) -AC_DEFUN([AX_CXX_COMPILE_STDCXX_11], [dnl +AC_DEFUN([AX_CXX_COMPILE_STDCXX_14], [dnl m4_if([$1], [], [], [$1], [ext], [], [$1], [noext], [], - [m4_fatal([invalid argument `$1' to AX_CXX_COMPILE_STDCXX_11])])dnl + [m4_fatal([invalid argument `$1' to AX_CXX_COMPILE_STDCXX_14])])dnl AC_LANG_ASSERT([C++])dnl ac_success=no - AC_CACHE_CHECK(whether $CXX supports C++11 features by default, - ax_cv_cxx_compile_cxx11, - [AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], - [ax_cv_cxx_compile_cxx11=yes], - [ax_cv_cxx_compile_cxx11=no])]) - if test x$ax_cv_cxx_compile_cxx11 = xyes; then + AC_CACHE_CHECK(whether $CXX supports C++14 features by default, + ax_cv_cxx_compile_cxx14, + [AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_14_testbody])], + [ax_cv_cxx_compile_cxx14=yes], + [ax_cv_cxx_compile_cxx14=no])]) + if test x$ax_cv_cxx_compile_cxx14 = xyes; then ac_success=yes fi m4_if([$1], [noext], [], [dnl if test x$ac_success = xno; then - for switch in -std=gnu++11 -std=gnu++0x; do - cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx11_$switch]) - AC_CACHE_CHECK(whether $CXX supports C++11 features with $switch, + for switch in -std=gnu++14 -std=gnu++1y; do + cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx14_$switch]) + AC_CACHE_CHECK(whether $CXX supports C++14 features with $switch, $cachevar, [ac_save_CXX="$CXX" CXX="$CXX $switch" - AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], + AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_14_testbody])], [eval $cachevar=yes], [eval $cachevar=no]) CXX="$ac_save_CXX"]) @@ -125,13 +137,13 @@ AC_DEFUN([AX_CXX_COMPILE_STDCXX_11], [dnl m4_if([$1], [ext], [], [dnl if test x$ac_success = xno; then - for switch in -std=c++11 -std=c++0x; do - cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx11_$switch]) - AC_CACHE_CHECK(whether $CXX supports C++11 features with $switch, + for switch in -std=c++14 -std=c++1y; do + cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx14_$switch]) + AC_CACHE_CHECK(whether $CXX supports C++14 features with $switch, $cachevar, [ac_save_CXX="$CXX" CXX="$CXX $switch" - AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], + AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_14_testbody])], [eval $cachevar=yes], [eval $cachevar=no]) CXX="$ac_save_CXX"]) @@ -144,7 +156,7 @@ AC_DEFUN([AX_CXX_COMPILE_STDCXX_11], [dnl fi]) if test x$ac_success = xno; then - AC_MSG_ERROR([*** A compiler with support for C++11 language features is required.]) + AC_MSG_ERROR([*** A compiler with support for C++14 language features is required.]) else ac_success=no AC_CACHE_CHECK(whether $CXX supports C++11 library features by default, diff --git a/c++/capnp-json.pc.in b/c++/pkgconfig/capnp-json.pc.in similarity index 100% rename from c++/capnp-json.pc.in rename to c++/pkgconfig/capnp-json.pc.in diff --git a/c++/capnp-rpc.pc.in b/c++/pkgconfig/capnp-rpc.pc.in similarity index 100% rename from c++/capnp-rpc.pc.in rename to c++/pkgconfig/capnp-rpc.pc.in diff --git a/c++/pkgconfig/capnp-websocket.pc.in b/c++/pkgconfig/capnp-websocket.pc.in new file mode 100644 index 0000000000..e64a28be15 --- /dev/null +++ b/c++/pkgconfig/capnp-websocket.pc.in @@ -0,0 +1,11 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Cap'n Proto WebSocket RPC +Description: WebSocket MessageStream for Cap'n Proto +Version: @VERSION@ +Libs: -L${libdir} -lcapnp-websocket +Requires: capnp = @VERSION@ capnp-rpc = @VERSION@ kj = @VERSION@ kj-async = @VERSION@ kj-http = @VERSION@ +Cflags: -I${includedir} diff --git a/c++/capnp.pc.in b/c++/pkgconfig/capnp.pc.in similarity index 100% rename from c++/capnp.pc.in rename to c++/pkgconfig/capnp.pc.in diff --git a/c++/pkgconfig/capnpc.pc.in b/c++/pkgconfig/capnpc.pc.in new file mode 100644 index 0000000000..4e62944b4b --- /dev/null +++ b/c++/pkgconfig/capnpc.pc.in @@ -0,0 +1,12 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Cap'n Proto +Description: Insanely fast serialization system compiler library +Version: @VERSION@ +Libs: -L${libdir} -lcapnpc @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ +Libs.private: @LIBS@ +Requires: kj = @VERSION@ +Cflags: -I${includedir} @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ diff --git a/c++/kj-async.pc.in b/c++/pkgconfig/kj-async.pc.in similarity index 55% rename from c++/kj-async.pc.in rename to c++/pkgconfig/kj-async.pc.in index 765197f34b..49d5ff6996 100644 --- a/c++/kj-async.pc.in +++ b/c++/pkgconfig/kj-async.pc.in @@ -6,6 +6,6 @@ includedir=@includedir@ Name: KJ Async Framework Library Description: Basic utility library called KJ (async part) Version: @VERSION@ -Libs: -L${libdir} -lkj-async @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ +Libs: -L${libdir} -lkj-async @ASYNC_LIBS@ @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ Requires: kj = @VERSION@ -Cflags: -I${includedir} @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ +Cflags: -I${includedir} @ASYNC_LIBS@ @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ diff --git a/c++/pkgconfig/kj-gzip.pc.in b/c++/pkgconfig/kj-gzip.pc.in new file mode 100644 index 0000000000..cc999a08c5 --- /dev/null +++ b/c++/pkgconfig/kj-gzip.pc.in @@ -0,0 +1,11 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: KJ Gzip Adapters +Description: Basic utility library called KJ (gzip part) +Version: @VERSION@ +Libs: -L${libdir} -lkj-gzip @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ +Requires: kj-async = @VERSION@ +Cflags: -I${includedir} @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ diff --git a/c++/pkgconfig/kj-http.pc.in b/c++/pkgconfig/kj-http.pc.in new file mode 100644 index 0000000000..63b7143684 --- /dev/null +++ b/c++/pkgconfig/kj-http.pc.in @@ -0,0 +1,11 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: KJ HTTP Library +Description: Basic utility library called KJ (HTTP part) +Version: @VERSION@ +Libs: -L${libdir} -lkj-http @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ +Requires: kj-async = @VERSION@ +Cflags: -I${includedir} @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ diff --git a/c++/pkgconfig/kj-test.pc.in b/c++/pkgconfig/kj-test.pc.in new file mode 100644 index 0000000000..eed62f4592 --- /dev/null +++ b/c++/pkgconfig/kj-test.pc.in @@ -0,0 +1,11 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: KJ Test Framework +Description: Basic utility library called KJ (test part) +Version: @VERSION@ +Libs: -L${libdir} -lkj-test @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ +Requires: kj = @VERSION@ +Cflags: -I${includedir} @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ diff --git a/c++/pkgconfig/kj-tls.pc.in b/c++/pkgconfig/kj-tls.pc.in new file mode 100644 index 0000000000..421255efbe --- /dev/null +++ b/c++/pkgconfig/kj-tls.pc.in @@ -0,0 +1,11 @@ +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: KJ TLS Adapters +Description: Basic utility library called KJ (TLS part) +Version: @VERSION@ +Libs: -L${libdir} -lkj-tls @PTHREAD_CFLAGS@ @PTHREAD_LIBS@ @STDLIB_FLAG@ +Requires: kj-async = @VERSION@ +Cflags: -I${includedir} @PTHREAD_CFLAGS@ @STDLIB_FLAG@ @CAPNP_LITE_FLAG@ diff --git a/c++/kj.pc.in b/c++/pkgconfig/kj.pc.in similarity index 100% rename from c++/kj.pc.in rename to c++/pkgconfig/kj.pc.in diff --git a/c++/regenerate-bootstraps.sh b/c++/regenerate-bootstraps.sh index e0f350a750..d806b835ea 100755 --- a/c++/regenerate-bootstraps.sh +++ b/c++/regenerate-bootstraps.sh @@ -5,7 +5,7 @@ set -euo pipefail export PATH=$PWD/bin:$PWD:$PATH capnp compile -Isrc --no-standard-import --src-prefix=src -oc++:src \ - src/capnp/c++.capnp src/capnp/schema.capnp \ + src/capnp/c++.capnp src/capnp/schema.capnp src/capnp/stream.capnp \ src/capnp/compiler/lexer.capnp src/capnp/compiler/grammar.capnp \ src/capnp/rpc.capnp src/capnp/rpc-twoparty.capnp src/capnp/persistent.capnp \ src/capnp/compat/json.capnp diff --git a/c++/samples/CMakeLists.txt b/c++/samples/CMakeLists.txt index b9ccba05e1..6a36b17520 100644 --- a/c++/samples/CMakeLists.txt +++ b/c++/samples/CMakeLists.txt @@ -23,13 +23,16 @@ find_package(CapnProto CONFIG REQUIRED) capnp_generate_cpp(addressbookSources addressbookHeaders addressbook.capnp) add_executable(addressbook addressbook.c++ ${addressbookSources}) -target_link_libraries(addressbook CapnProto::capnp) +target_link_libraries(addressbook PRIVATE CapnProto::capnp) target_include_directories(addressbook PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -capnp_generate_cpp(calculatorSources calculatorHeaders calculator.capnp) -add_executable(calculator-client calculator-client.c++ ${calculatorSources}) -add_executable(calculator-server calculator-server.c++ ${calculatorSources}) -target_link_libraries(calculator-client CapnProto::capnp-rpc) -target_link_libraries(calculator-server CapnProto::capnp-rpc) -target_include_directories(calculator-client PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -target_include_directories(calculator-server PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +# Don't build the rpc sample if find_package() found an installation of Cap'n Proto lite. +if(TARGET CapnProto::capnp-rpc) + capnp_generate_cpp(calculatorSources calculatorHeaders calculator.capnp) + add_executable(calculator-client calculator-client.c++ ${calculatorSources}) + add_executable(calculator-server calculator-server.c++ ${calculatorSources}) + target_link_libraries(calculator-client PRIVATE CapnProto::capnp-rpc) + target_link_libraries(calculator-server PRIVATE CapnProto::capnp-rpc) + target_include_directories(calculator-client PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + target_include_directories(calculator-server PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +endif() diff --git a/c++/samples/addressbook.c++ b/c++/samples/addressbook.c++ index f71e595039..df93b3b6d2 100644 --- a/c++/samples/addressbook.c++ +++ b/c++/samples/addressbook.c++ @@ -23,12 +23,12 @@ // // If Cap'n Proto is installed, build the sample like: // capnp compile -oc++ addressbook.capnp -// c++ -std=c++11 -Wall addressbook.c++ addressbook.capnp.c++ `pkg-config --cflags --libs capnp` -o addressbook +// c++ -std=c++14 -Wall addressbook.c++ addressbook.capnp.c++ `pkg-config --cflags --libs capnp` -o addressbook // // If Cap'n Proto is not installed, but the source is located at $SRC and has been // compiled in $BUILD (often both are simply ".." from here), you can do: // $BUILD/capnp compile -I$SRC/src -o$BUILD/capnpc-c++ addressbook.capnp -// c++ -std=c++11 -Wall addressbook.c++ addressbook.capnp.c++ -I$SRC/src -L$BUILD/.libs -lcapnp -lkj -o addressbook +// c++ -std=c++14 -Wall addressbook.c++ addressbook.capnp.c++ -I$SRC/src -L$BUILD/.libs -lcapnp -lkj -o addressbook // // Run like: // ./addressbook write | ./addressbook read @@ -113,6 +113,8 @@ void printAddressBook(int fd) { } } +#if !CAPNP_LITE + #include "addressbook.capnp.h" #include #include @@ -260,8 +262,9 @@ void dynamicPrintMessage(int fd, StructSchema schema) { std::cout << std::endl; } +#endif // !CAPNP_LITE + int main(int argc, char* argv[]) { - StructSchema schema = Schema::from(); if (argc != 2) { std::cerr << "Missing arg." << std::endl; return 1; @@ -269,10 +272,14 @@ int main(int argc, char* argv[]) { writeAddressBook(1); } else if (strcmp(argv[1], "read") == 0) { printAddressBook(0); +#if !CAPNP_LITE } else if (strcmp(argv[1], "dwrite") == 0) { + StructSchema schema = Schema::from(); dynamicWriteAddressBook(1, schema); } else if (strcmp(argv[1], "dread") == 0) { + StructSchema schema = Schema::from(); dynamicPrintMessage(0, schema); +#endif } else { std::cerr << "Invalid arg: " << argv[1] << std::endl; return 1; diff --git a/c++/samples/test.sh b/c++/samples/test.sh index 9d335aab54..59e90b3f98 100755 --- a/c++/samples/test.sh +++ b/c++/samples/test.sh @@ -6,16 +6,16 @@ set -exuo pipefail capnpc -oc++ addressbook.capnp -c++ -std=c++11 -Wall addressbook.c++ addressbook.capnp.c++ \ +c++ -std=c++14 -Wall addressbook.c++ addressbook.capnp.c++ \ $(pkg-config --cflags --libs capnp) -o addressbook ./addressbook write | ./addressbook read ./addressbook dwrite | ./addressbook dread rm addressbook addressbook.capnp.c++ addressbook.capnp.h capnpc -oc++ calculator.capnp -c++ -std=c++11 -Wall calculator-client.c++ calculator.capnp.c++ \ +c++ -std=c++14 -Wall calculator-client.c++ calculator.capnp.c++ \ $(pkg-config --cflags --libs capnp-rpc) -o calculator-client -c++ -std=c++11 -Wall calculator-server.c++ calculator.capnp.c++ \ +c++ -std=c++14 -Wall calculator-server.c++ calculator.capnp.c++ \ $(pkg-config --cflags --libs capnp-rpc) -o calculator-server rm -f /tmp/capnp-calculator-example-$$ ./calculator-server unix:/tmp/capnp-calculator-example-$$ & diff --git a/c++/src/CMakeLists.txt b/c++/src/CMakeLists.txt index 8cbea8d877..3035301436 100644 --- a/c++/src/CMakeLists.txt +++ b/c++/src/CMakeLists.txt @@ -4,33 +4,31 @@ if(BUILD_TESTING) include(CTest) - if (EXTERNAL_CAPNP) - # Setup CAPNP_GENERATE_CPP for compiling test schemas + if(EXTERNAL_CAPNP) + # Set up CAPNP_GENERATE_CPP for compiling test schemas find_package(CapnProto CONFIG QUIET) - if (NOT CapnProto_FOUND) - # Try and find the executables from an autotools-based installation - # Setup paths to the schema compiler for generating ${test_capnp_files} - if(NOT EXTERNAL_CAPNP AND NOT CAPNP_LITE) - set(CAPNP_EXECUTABLE $) - set(CAPNPC_CXX_EXECUTABLE $) - else() - # Allow paths to tools to be set with either environment variables or find_program() - if (NOT CAPNP_EXECUTABLE) - if (DEFINED ENV{CAPNP}) - set(CAPNP_EXECUTABLE "$ENV{CAPNP}") - else() - find_program(CAPNP_EXECUTABLE capnp) - endif() + if(NOT CapnProto_FOUND) + # No working installation of Cap'n Proto found, so fall back to searching the environment. + # + # We search for the external capnp compiler binaries via $CAPNP, $CAPNPC_CXX, and + # find_program(). find_program() will use various paths in its search, among them + # ${CMAKE_PREFIX_PATH}/bin and $PATH. + + if(NOT CAPNP_EXECUTABLE) + if(DEFINED ENV{CAPNP}) + set(CAPNP_EXECUTABLE "$ENV{CAPNP}") + else() + find_program(CAPNP_EXECUTABLE capnp) endif() + endif() - if(NOT CAPNPC_CXX_EXECUTABLE) - if (DEFINED ENV{CAPNPC_CXX}) - set(CAPNPC_CXX_EXECUTABLE "$ENV{CAPNPC_CXX}") - else() - # Also search in the same directory that `capnp` was found in - get_filename_component(capnp_dir "${CAPNP_EXECUTABLE}" DIRECTORY) - find_program(CAPNPC_CXX_EXECUTABLE capnpc-c++ HINTS "${capnp_dir}") - endif() + if(NOT CAPNPC_CXX_EXECUTABLE) + if(DEFINED ENV{CAPNPC_CXX}) + set(CAPNPC_CXX_EXECUTABLE "$ENV{CAPNPC_CXX}") + else() + # Also search in the same directory that `capnp` was found in + get_filename_component(capnp_dir "${CAPNP_EXECUTABLE}" DIRECTORY) + find_program(CAPNPC_CXX_EXECUTABLE capnpc-c++ HINTS "${capnp_dir}") endif() endif() endif() diff --git a/c++/src/benchmark/capnproto-common.h b/c++/src/benchmark/capnproto-common.h index 1705f913db..8200328c11 100644 --- a/c++/src/benchmark/capnproto-common.h +++ b/c++/src/benchmark/capnproto-common.h @@ -19,8 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_BENCHMARK_CAPNP_COMMON_H_ -#define CAPNP_BENCHMARK_CAPNP_COMMON_H_ +#pragma once #if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) #pragma GCC system_header @@ -419,5 +418,3 @@ struct BenchmarkTypes { } // namespace capnp } // namespace benchmark } // namespace capnp - -#endif // CAPNP_BENCHMARK_CAPNP_COMMON_H_ diff --git a/c++/src/benchmark/common.h b/c++/src/benchmark/common.h index aa3ac3cad8..a42c722e47 100644 --- a/c++/src/benchmark/common.h +++ b/c++/src/benchmark/common.h @@ -19,8 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_BENCHMARK_COMMON_H_ -#define CAPNP_BENCHMARK_COMMON_H_ +#pragma once #if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) #pragma GCC system_header @@ -293,5 +292,3 @@ int benchmarkMain(int argc, char* argv[]) { } // namespace capnp } // namespace benchmark - -#endif // CAPNP_BENCHMARK_COMMON_H_ diff --git a/c++/src/benchmark/protobuf-carsales.c++ b/c++/src/benchmark/protobuf-carsales.c++ index 40477097ab..7190251a0f 100644 --- a/c++/src/benchmark/protobuf-carsales.c++ +++ b/c++/src/benchmark/protobuf-carsales.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if !CAPNP_NO_PROTOBUF_BENCHMARK + #include "carsales.pb.h" #include "protobuf-common.h" @@ -139,3 +141,5 @@ int main(int argc, char* argv[]) { capnp::benchmark::protobuf::BenchmarkTypes, capnp::benchmark::protobuf::CarSalesTestCase>(argc, argv); } + +#endif // !CAPNP_NO_PROTOBUF_BENCHMARK diff --git a/c++/src/benchmark/protobuf-catrank.c++ b/c++/src/benchmark/protobuf-catrank.c++ index a3036b237d..648bdb7cd4 100644 --- a/c++/src/benchmark/protobuf-catrank.c++ +++ b/c++/src/benchmark/protobuf-catrank.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if !CAPNP_NO_PROTOBUF_BENCHMARK + #include "catrank.pb.h" #include "protobuf-common.h" @@ -128,3 +130,5 @@ int main(int argc, char* argv[]) { capnp::benchmark::protobuf::BenchmarkTypes, capnp::benchmark::protobuf::CatRankTestCase>(argc, argv); } + +#endif // !CAPNP_NO_PROTOBUF_BENCHMARK diff --git a/c++/src/benchmark/protobuf-eval.c++ b/c++/src/benchmark/protobuf-eval.c++ index db27b7378a..b197a0ea71 100644 --- a/c++/src/benchmark/protobuf-eval.c++ +++ b/c++/src/benchmark/protobuf-eval.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if !CAPNP_NO_PROTOBUF_BENCHMARK + #include "eval.pb.h" #include "protobuf-common.h" @@ -116,3 +118,5 @@ int main(int argc, char* argv[]) { capnp::benchmark::protobuf::BenchmarkTypes, capnp::benchmark::protobuf::ExpressionTestCase>(argc, argv); } + +#endif // !CAPNP_NO_PROTOBUF_BENCHMARK diff --git a/c++/src/benchmark/runner.c++ b/c++/src/benchmark/runner.c++ index 5ff07567d1..155324921a 100644 --- a/c++/src/benchmark/runner.c++ +++ b/c++/src/benchmark/runner.c++ @@ -186,7 +186,7 @@ TestResult runTest(Product product, TestCase testCase, Mode mode, Reuse reuse, } char itersStr[64]; - sprintf(itersStr, "%llu", (long long unsigned int)iters); + snprintf(itersStr, sizeof(itersStr), "%llu", (long long unsigned int)iters); argv[4] = itersStr; argv[5] = nullptr; diff --git a/c++/src/capnp/BUILD.bazel b/c++/src/capnp/BUILD.bazel new file mode 100644 index 0000000000..f11a7a24aa --- /dev/null +++ b/c++/src/capnp/BUILD.bazel @@ -0,0 +1,278 @@ +load("@capnp-cpp//src/capnp:cc_capnp_library.bzl", "cc_capnp_library") + +cc_library( + name = "capnp", + srcs = [ + "any.c++", + "arena.c++", + "blob.c++", + "c++.capnp.c++", + "dynamic.c++", + "layout.c++", + "list.c++", + "message.c++", + "schema.c++", + "schema.capnp.c++", + "schema-loader.c++", + "serialize.c++", + "serialize-packed.c++", + "stream.capnp.c++", + "stringify.c++", + ], + hdrs = [ + "any.h", + "arena.h", + "blob.h", + "c++.capnp.h", + "capability.h", + "common.h", + "dynamic.h", + "endian.h", + "generated-header-support.h", + "layout.h", + "list.h", + "membrane.h", + "message.h", + "orphan.h", + "pointer-helpers.h", + "pretty-print.h", + "raw-schema.h", + "schema.capnp.h", + "schema.h", + "schema-lite.h", + "schema-loader.h", + "schema-parser.h", + "serialize.h", + "serialize-async.h", + "serialize-packed.h", + "serialize-text.h", + "stream.capnp.h", + ], + include_prefix = "capnp", + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + ], +) + +cc_library( + name = "capnp-rpc", + srcs = [ + "capability.c++", + "dynamic-capability.c++", + "ez-rpc.c++", + "membrane.c++", + "persistent.capnp.c++", + "reconnect.c++", + "rpc.c++", + "rpc.capnp.c++", + "rpc-twoparty.c++", + "rpc-twoparty.capnp.c++", + "serialize-async.c++", + ], + hdrs = [ + "ez-rpc.h", + "persistent.capnp.h", + "reconnect.h", + "rpc.capnp.h", + "rpc.h", + "rpc-prelude.h", + "rpc-twoparty.capnp.h", + "rpc-twoparty.h", + ], + include_prefix = "capnp", + visibility = ["//visibility:public"], + deps = [ + ":capnp", + ], +) + +cc_library( + name = "capnpc", + srcs = [ + "compiler/compiler.c++", + "compiler/error-reporter.c++", + "compiler/generics.c++", + "compiler/grammar.capnp.c++", + "compiler/lexer.c++", + "compiler/lexer.capnp.c++", + "compiler/node-translator.c++", + "compiler/parser.c++", + "compiler/type-id.c++", + "schema-parser.c++", + "serialize-text.c++", + ], + hdrs = [ + "compiler/compiler.h", + "compiler/error-reporter.h", + "compiler/generics.h", + "compiler/grammar.capnp.h", + "compiler/lexer.capnp.h", + "compiler/lexer.h", + "compiler/module-loader.h", + "compiler/node-translator.h", + "compiler/parser.h", + "compiler/resolver.h", + "compiler/type-id.h", + ], + include_prefix = "capnp", + visibility = ["//visibility:public"], + deps = [ + ":capnp", + ], +) + +cc_binary( + name = "capnp_tool", + srcs = [ + "compiler/capnp.c++", + "compiler/module-loader.c++", + ], + visibility = ["//visibility:public"], + deps = [ + ":capnpc", + "//src/capnp/compat:json", + ], +) + +cc_binary( + name = "capnpc-c++", + srcs = [ + "compiler/capnpc-c++.c++", + ], + visibility = ["//visibility:public"], + deps = [ + ":capnpc", + ], +) + +cc_binary( + name = "capnpc-capnp", + srcs = [ + "compiler/capnpc-capnp.c++", + ], + visibility = ["//visibility:public"], + deps = [ + ":capnpc", + ], +) + +# capnp files that are implicitly available for import to any .capnp. +filegroup( + name = "capnp_system_library", + srcs = [ + "c++.capnp", + "schema.capnp", + "stream.capnp", + "//src/capnp/compat:json.capnp", + ], + visibility = ["//visibility:public"], +) + +# library to link with every cc_capnp_library +cc_library( + name = "capnp_runtime", + visibility = ["//visibility:public"], + # include json since it is not exposed as cc_capnp_library + deps = [ + ":capnp", + "//src/capnp/compat:json", + ], +) + +filegroup( + name = "testdata", + srcs = glob(["testdata/**/*"]), +) + +cc_capnp_library( + name = "capnp_test", + srcs = [ + "test.capnp", + "test-import.capnp", + "test-import2.capnp", + ], + data = [ + "c++.capnp", + "schema.capnp", + "stream.capnp", + ":testdata", + ], + include_prefix = "capnp", + src_prefix = "src", +) + +cc_library( + name = "capnp-test", + srcs = ["test-util.c++"], + hdrs = ["test-util.h"], + deps = [ + ":capnp-rpc", + ":capnp_test", + ":capnpc", + "//src/kj:kj-test", + ], + visibility = [":__subpackages__" ] +) + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [":capnp-test"], +) for f in [ + "any-test.c++", + "blob-test.c++", + "canonicalize-test.c++", + "common-test.c++", + "capability-test.c++", + "compiler/lexer-test.c++", + "compiler/type-id-test.c++", + "dynamic-test.c++", + "encoding-test.c++", + "endian-test.c++", + "ez-rpc-test.c++", + "layout-test.c++", + "membrane-test.c++", + "message-test.c++", + "orphan-test.c++", + "reconnect-test.c++", + "rpc-test.c++", + "rpc-twoparty-test.c++", + "schema-test.c++", + "schema-loader-test.c++", + "schema-parser-test.c++", + "serialize-async-test.c++", + "serialize-packed-test.c++", + "serialize-test.c++", + "serialize-text-test.c++", + "stringify-test.c++", +]] + +cc_test( + name = "endian-reverse-test", + srcs = ["endian-reverse-test.c++"], + deps = [":capnp-test"], + target_compatible_with = select({ + "@platforms//os:windows": ["@platforms//:incompatible"], + "//conditions:default": [], + }), +) + +cc_library( + name = "endian-test-base", + hdrs = ["endian-test.c++"], + deps = [":capnp-test"], +) + +cc_test( + name = "endian-fallback-test", + srcs = ["endian-fallback-test.c++"], + deps = [":endian-test-base"], +) + +cc_test( + name = "fuzz-test", + size = "large", + srcs = ["fuzz-test.c++"], + deps = [":capnp-test"], +) diff --git a/c++/src/capnp/CMakeLists.txt b/c++/src/capnp/CMakeLists.txt index a71d510b32..9980fde617 100644 --- a/c++/src/capnp/CMakeLists.txt +++ b/c++/src/capnp/CMakeLists.txt @@ -10,6 +10,7 @@ set(capnp_sources_lite any.c++ message.c++ schema.capnp.c++ + stream.capnp.c++ serialize.c++ serialize-packed.c++ ) @@ -40,6 +41,7 @@ set(capnp_headers dynamic.h schema.h schema.capnp.h + stream.capnp.h schema-lite.h schema-loader.h schema-parser.h @@ -52,20 +54,28 @@ set(capnp_headers generated-header-support.h raw-schema.h ) +set(capnp_compat_headers + compat/std-iterator.h +) set(capnp_schemas c++.capnp schema.capnp + stream.capnp ) add_library(capnp ${capnp_sources}) add_library(CapnProto::capnp ALIAS capnp) target_link_libraries(capnp PUBLIC kj) #make sure external consumers don't need to manually set the include dirs +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) target_include_directories(capnp INTERFACE - $ + $ $ ) +# Ensure the library has a version set to match autotools build +set_target_properties(capnp PROPERTIES VERSION ${VERSION}) install(TARGETS capnp ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${capnp_headers} ${capnp_schemas} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/capnp") +install(FILES ${capnp_compat_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/capnp/compat") set(capnp-rpc_sources serialize-async.c++ @@ -97,6 +107,8 @@ if(NOT CAPNP_LITE) add_library(capnp-rpc ${capnp-rpc_sources}) add_library(CapnProto::capnp-rpc ALIAS capnp-rpc) target_link_libraries(capnp-rpc PUBLIC capnp kj-async kj) + # Ensure the library has a version set to match autotools build + set_target_properties(capnp-rpc PROPERTIES VERSION ${VERSION}) install(TARGETS capnp-rpc ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${capnp-rpc_headers} ${capnp-rpc_schemas} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/capnp") endif() @@ -118,19 +130,40 @@ if(NOT CAPNP_LITE) add_library(capnp-json ${capnp-json_sources}) add_library(CapnProto::capnp-json ALIAS capnp-json) target_link_libraries(capnp-json PUBLIC capnp kj-async kj) + # Ensure the library has a version set to match autotools build + set_target_properties(capnp-json PROPERTIES VERSION ${VERSION}) install(TARGETS capnp-json ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${capnp-json_headers} ${capnp-json_schemas} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/capnp/compat") endif() +# capnp-websocket ======================================================================== + +set(capnp-websocket_sources + compat/websocket-rpc.c++ +) +set(capnp-websocket_headers + compat/websocket-rpc.h +) +if(NOT CAPNP_LITE) + add_library(capnp-websocket ${capnp-websocket_sources}) + add_library(CapnProto::capnp-websocket ALIAS capnp-websocket) + target_link_libraries(capnp-websocket PUBLIC capnp capnp-rpc kj-http kj-async kj) + # Ensure the library has a version set to match autotools build + set_target_properties(capnp-websocket PROPERTIES VERSION ${VERSION}) + install(TARGETS capnp-websocket ${INSTALL_TARGETS_DEFAULT_ARGS}) + install(FILES ${capnp-websocket_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/capnp/compat") +endif() + # Tools/Compilers ============================================================== set(capnpc_sources - compiler/md5.c++ + compiler/type-id.c++ compiler/error-reporter.c++ compiler/lexer.capnp.c++ compiler/lexer.c++ compiler/grammar.capnp.c++ compiler/parser.c++ + compiler/generics.c++ compiler/node-translator.c++ compiler/compiler.c++ schema-parser.c++ @@ -139,6 +172,8 @@ set(capnpc_sources if(NOT CAPNP_LITE) add_library(capnpc ${capnpc_sources}) target_link_libraries(capnpc PUBLIC capnp kj) + # Ensure the library has a version set to match autotools build + set_target_properties(capnpc PROPERTIES VERSION ${VERSION}) install(TARGETS capnpc ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${capnpc_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/capnp") endif() @@ -148,11 +183,15 @@ if(NOT CAPNP_LITE) compiler/module-loader.c++ compiler/capnp.c++ ) - target_link_libraries(capnp_tool capnpc capnp kj) + target_link_libraries(capnp_tool capnpc capnp-json capnp kj) set_target_properties(capnp_tool PROPERTIES OUTPUT_NAME capnp) set_target_properties(capnp_tool PROPERTIES CAPNP_INCLUDE_DIRECTORY $,$> ) + target_compile_definitions(capnp_tool PRIVATE + "CAPNP_INCLUDE_DIR=\"${CMAKE_INSTALL_FULL_INCLUDEDIR}\"" + "VERSION=\"${VERSION}\"" + ) add_executable(capnpc_cpp compiler/capnpc-c++.c++ @@ -173,8 +212,14 @@ if(NOT CAPNP_LITE) install(TARGETS capnp_tool capnpc_cpp capnpc_capnp ${INSTALL_TARGETS_DEFAULT_ARGS}) - # Symlink capnpc -> capnp - install(CODE "execute_process(COMMAND \"${CMAKE_COMMAND}\" -E create_symlink capnp \"\$ENV{DESTDIR}${CMAKE_INSTALL_FULL_BINDIR}/capnpc\")") + if(WIN32) + # On Windows platforms symlinks are not guaranteed to support. Also different version of CMake handle create_symlink in a different way. + # The most portable way in this case just copy the file. + install(CODE "execute_process(COMMAND \"${CMAKE_COMMAND}\" -E copy \"\$ENV{DESTDIR}${CMAKE_INSTALL_FULL_BINDIR}/capnp${CMAKE_EXECUTABLE_SUFFIX}\" \"\$ENV{DESTDIR}${CMAKE_INSTALL_FULL_BINDIR}/capnpc${CMAKE_EXECUTABLE_SUFFIX}\")") + else() + # Symlink capnpc -> capnp + install(CODE "execute_process(COMMAND \"${CMAKE_COMMAND}\" -E create_symlink capnp${CMAKE_EXECUTABLE_SUFFIX} \"\$ENV{DESTDIR}${CMAKE_INSTALL_FULL_BINDIR}/capnpc${CMAKE_EXECUTABLE_SUFFIX}\")") + endif() endif() # NOT CAPNP_LITE # Tests ======================================================================== @@ -184,19 +229,30 @@ if(BUILD_TESTING) test.capnp test-import.capnp test-import2.capnp + compat/json-test.capnp ) - # Add "/capnp" to match the path used to import the files in the test sources - set(CAPNPC_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/test_capnp/capnp") - include_directories("${CMAKE_CURRENT_BINARY_DIR}/test_capnp") # Note: no "/capnp" + set(CAPNPC_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/test_capnp") + include_directories("${CAPNPC_OUTPUT_DIR}") file(MAKE_DIRECTORY "${CAPNPC_OUTPUT_DIR}") + # Tell capnp_generate_cpp to set --src-prefix to our parent directory. This allows us to pass our + # .capnp files relative to this directory, but have their canonical name end up as + # capnp/test.capnp, capnp/test-import.capnp, etc. + get_filename_component(CAPNPC_SRC_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}" DIRECTORY) capnp_generate_cpp(test_capnp_cpp_files test_capnp_h_files ${test_capnp_files}) + # TODO(cleanup): capnp-tests and capnp-heavy-tests both depend on the test.capnp output files. In + # a parallel Makefile-based build (maybe others?), they can race and cause the custom capnp + # command in capnp_generate_cpp() to run twice. To get around this I'm using a custom target to + # force CMake to generate race-free Makefiles. Remove this garbage when we move to a + # target-based capnp_generate() command, as that will make CMake do the right thing by default. + add_custom_target(test_capnp DEPENDS ${test_capnp_cpp_files} ${test_capnp_h_files}) + if(CAPNP_LITE) set(test_libraries capnp kj-test kj) else() - set(test_libraries capnp-json capnp-rpc capnp capnpc kj-async kj-test kj) + set(test_libraries capnp-json capnp-rpc capnp-websocket capnp capnpc kj-http kj-async kj-test kj) endif() add_executable(capnp-tests @@ -218,6 +274,7 @@ if(BUILD_TESTING) ${test_capnp_h_files} ) target_link_libraries(capnp-tests ${test_libraries}) + add_dependencies(capnp-tests test_capnp) add_dependencies(check capnp-tests) add_test(NAME capnp-tests-run COMMAND capnp-tests) @@ -237,9 +294,10 @@ if(BUILD_TESTING) rpc-twoparty-test.c++ ez-rpc-test.c++ compiler/lexer-test.c++ - compiler/md5-test.c++ + compiler/type-id-test.c++ test-util.c++ compat/json-test.c++ + compat/websocket-rpc-test.c++ ${test_capnp_cpp_files} ${test_capnp_h_files} ) @@ -250,6 +308,7 @@ if(BUILD_TESTING) ) endif() + add_dependencies(capnp-heavy-tests test_capnp) add_dependencies(check capnp-heavy-tests) add_test(NAME capnp-heavy-tests-run COMMAND capnp-heavy-tests) @@ -259,3 +318,8 @@ if(BUILD_TESTING) add_test(NAME capnp-evolution-tests-run COMMAND capnp-evolution-tests) endif() # NOT CAPNP_LITE endif() # BUILD_TESTING + +if(DEFINED ENV{LIB_FUZZING_ENGINE}) + add_executable(capnp_llvm_fuzzer_testcase llvm-fuzzer-testcase.c++ test-util.c++ test-util.h ${test_capnp_cpp_files} ${test_capnp_h_files}) + target_link_libraries(capnp_llvm_fuzzer_testcase capnp-rpc capnp kj kj-async capnp-json $ENV{LIB_FUZZING_ENGINE}) +endif() diff --git a/c++/src/capnp/any-test.c++ b/c++/src/capnp/any-test.c++ index dab6ecd56f..c5b76a1d2f 100644 --- a/c++/src/capnp/any-test.c++ +++ b/c++/src/capnp/any-test.c++ @@ -130,7 +130,7 @@ TEST(Any, AnyStruct) { EXPECT_EQ(48, b.getDataSection().size()); EXPECT_EQ(20, b.getPointerSection().size()); -#if !_MSC_VER // TODO(msvc): ICE on the necessary constructor; see any.h. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): ICE on the necessary constructor; see any.h. b = root.getAnyPointerField().getAs(); EXPECT_EQ(48, b.getDataSection().size()); EXPECT_EQ(20, b.getPointerSection().size()); @@ -144,7 +144,7 @@ TEST(Any, AnyStruct) { EXPECT_EQ(48, r.getDataSection().size()); EXPECT_EQ(20, r.getPointerSection().size()); -#if !_MSC_VER // TODO(msvc): ICE on the necessary constructor; see any.h. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): ICE on the necessary constructor; see any.h. r = root.getAnyPointerField().getAs().asReader(); EXPECT_EQ(48, r.getDataSection().size()); EXPECT_EQ(20, r.getPointerSection().size()); @@ -201,7 +201,7 @@ TEST(Any, AnyList) { EXPECT_EQ(48, alb.as>()[0].getDataSection().size()); EXPECT_EQ(20, alb.as>()[0].getPointerSection().size()); -#if !_MSC_VER // TODO(msvc): ICE on the necessary constructor; see any.h. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): ICE on the necessary constructor; see any.h. alb = root.getAnyPointerField().getAs>(); EXPECT_EQ(2, alb.size()); EXPECT_EQ(48, alb.as>()[0].getDataSection().size()); @@ -218,7 +218,7 @@ TEST(Any, AnyList) { EXPECT_EQ(48, alr.as>()[0].getDataSection().size()); EXPECT_EQ(20, alr.as>()[0].getPointerSection().size()); -#if !_MSC_VER // TODO(msvc): ICE on the necessary constructor; see any.h. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): ICE on the necessary constructor; see any.h. alr = root.getAnyPointerField().getAs>().asReader(); EXPECT_EQ(2, alr.size()); EXPECT_EQ(48, alr.as>()[0].getDataSection().size()); diff --git a/c++/src/capnp/any.c++ b/c++/src/capnp/any.c++ index 520dc29bbe..5439bf545f 100644 --- a/c++/src/capnp/any.c++ +++ b/c++/src/capnp/any.c++ @@ -79,7 +79,7 @@ kj::Own AnyPointer::Pipeline::asCap() { #endif // !CAPNP_LITE -Equality AnyStruct::Reader::equals(AnyStruct::Reader right) { +Equality AnyStruct::Reader::equals(AnyStruct::Reader right) const { auto dataL = getDataSection(); size_t dataSizeL = dataL.size(); while(dataSizeL > 0 && dataL[dataSizeL - 1] == 0) { @@ -150,7 +150,7 @@ kj::StringPtr KJ_STRINGIFY(Equality res) { KJ_UNREACHABLE; } -Equality AnyList::Reader::equals(AnyList::Reader right) { +Equality AnyList::Reader::equals(AnyList::Reader right) const { if(size() != right.size()) { return Equality::NOT_EQUAL; } @@ -209,7 +209,7 @@ Equality AnyList::Reader::equals(AnyList::Reader right) { KJ_UNREACHABLE; } -Equality AnyPointer::Reader::equals(AnyPointer::Reader right) { +Equality AnyPointer::Reader::equals(AnyPointer::Reader right) const { if(getPointerType() != right.getPointerType()) { return Equality::NOT_EQUAL; } @@ -227,7 +227,7 @@ Equality AnyPointer::Reader::equals(AnyPointer::Reader right) { KJ_UNREACHABLE; } -bool AnyPointer::Reader::operator==(AnyPointer::Reader right) { +bool AnyPointer::Reader::operator==(AnyPointer::Reader right) const { switch(equals(right)) { case Equality::EQUAL: return true; @@ -240,7 +240,7 @@ bool AnyPointer::Reader::operator==(AnyPointer::Reader right) { KJ_UNREACHABLE; } -bool AnyStruct::Reader::operator==(AnyStruct::Reader right) { +bool AnyStruct::Reader::operator==(AnyStruct::Reader right) const { switch(equals(right)) { case Equality::EQUAL: return true; @@ -253,7 +253,7 @@ bool AnyStruct::Reader::operator==(AnyStruct::Reader right) { KJ_UNREACHABLE; } -bool AnyList::Reader::operator==(AnyList::Reader right) { +bool AnyList::Reader::operator==(AnyList::Reader right) const { switch(equals(right)) { case Equality::EQUAL: return true; diff --git a/c++/src/capnp/any.h b/c++/src/capnp/any.h index 6df9dc8dc2..94b527dc3d 100644 --- a/c++/src/capnp/any.h +++ b/c++/src/capnp/any.h @@ -19,17 +19,16 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_ANY_H_ -#define CAPNP_ANY_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "layout.h" #include "pointer-helpers.h" #include "orphan.h" #include "list.h" +#include // work-around macro conflict with `VOID` +#include + +CAPNP_BEGIN_HEADER namespace capnp { @@ -105,9 +104,9 @@ struct AnyPointer { inline bool isList() const { return getPointerType() == PointerType::LIST; } inline bool isCapability() const { return getPointerType() == PointerType::CAPABILITY; } - Equality equals(AnyPointer::Reader right); - bool operator==(AnyPointer::Reader right); - inline bool operator!=(AnyPointer::Reader right) { + Equality equals(AnyPointer::Reader right) const; + bool operator==(AnyPointer::Reader right) const; + inline bool operator!=(AnyPointer::Reader right) const { return !(*this == right); } @@ -159,13 +158,13 @@ struct AnyPointer { inline bool isList() { return getPointerType() == PointerType::LIST; } inline bool isCapability() { return getPointerType() == PointerType::CAPABILITY; } - inline Equality equals(AnyPointer::Reader right) { + inline Equality equals(AnyPointer::Reader right) const { return asReader().equals(right); } - inline bool operator==(AnyPointer::Reader right) { + inline bool operator==(AnyPointer::Reader right) const { return asReader() == right; } - inline bool operator!=(AnyPointer::Reader right) { + inline bool operator!=(AnyPointer::Reader right) const { return !(*this == right); } @@ -408,6 +407,10 @@ struct List { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -461,10 +464,12 @@ class AnyStruct::Reader { inline Reader(T&& value) : _reader(_::PointerHelpers>::getInternalReader(kj::fwd(value))) {} - kj::ArrayPtr getDataSection() { + inline MessageSize totalSize() const { return _reader.totalSize().asPublic(); } + + kj::ArrayPtr getDataSection() const { return _reader.getDataSectionAsBlob(); } - List::Reader getPointerSection() { + List::Reader getPointerSection() const { return List::Reader(_reader.getPointerSectionAsList()); } @@ -472,9 +477,9 @@ class AnyStruct::Reader { return _reader.canonicalize(); } - Equality equals(AnyStruct::Reader right); - bool operator==(AnyStruct::Reader right); - inline bool operator!=(AnyStruct::Reader right) { + Equality equals(AnyStruct::Reader right) const; + bool operator==(AnyStruct::Reader right) const; + inline bool operator!=(AnyStruct::Reader right) const { return !(*this == right); } @@ -483,6 +488,11 @@ class AnyStruct::Reader { // T must be a struct type. return typename T::Reader(_reader); } + + template + ReaderFor as(StructSchema schema) const; + // T must be DynamicStruct. Defined in dynamic.h. + private: _::StructReader _reader; @@ -498,7 +508,7 @@ class AnyStruct::Builder { inline Builder(decltype(nullptr)) {} inline Builder(_::StructBuilder builder): _builder(builder) {} -#if !_MSC_VER // TODO(msvc): MSVC ICEs on this. Try restoring when compiler improves. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): MSVC ICEs on this. Try restoring when compiler improves. template ) == Kind::STRUCT>> inline Builder(T&& value) : _builder(_::PointerHelpers>::getInternalBuilder(kj::fwd(value))) {} @@ -511,13 +521,13 @@ class AnyStruct::Builder { return List::Builder(_builder.getPointerSectionAsList()); } - inline Equality equals(AnyStruct::Reader right) { + inline Equality equals(AnyStruct::Reader right) const { return asReader().equals(right); } - inline bool operator==(AnyStruct::Reader right) { + inline bool operator==(AnyStruct::Reader right) const { return asReader() == right; } - inline bool operator!=(AnyStruct::Reader right) { + inline bool operator!=(AnyStruct::Reader right) const { return !(*this == right); } @@ -529,6 +539,11 @@ class AnyStruct::Builder { // T must be a struct type. return typename T::Builder(_builder); } + + template + BuilderFor as(StructSchema schema); + // T must be DynamicStruct. Defined in dynamic.h. + private: _::StructBuilder _builder; friend class Orphanage; @@ -573,6 +588,10 @@ class List::Reader { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -621,24 +640,28 @@ class AnyList::Reader { inline Reader(): _reader(ElementSize::VOID) {} inline Reader(_::ListReader reader): _reader(reader) {} -#if !_MSC_VER // TODO(msvc): MSVC ICEs on this. Try restoring when compiler improves. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): MSVC ICEs on this. Try restoring when compiler improves. template ) == Kind::LIST>> inline Reader(T&& value) : _reader(_::PointerHelpers>::getInternalReader(kj::fwd(value))) {} #endif - inline ElementSize getElementSize() { return _reader.getElementSize(); } - inline uint size() { return unbound(_reader.size() / ELEMENTS); } + inline ElementSize getElementSize() const { return _reader.getElementSize(); } + inline uint size() const { return unbound(_reader.size() / ELEMENTS); } - inline kj::ArrayPtr getRawBytes() { return _reader.asRawBytes(); } + inline kj::ArrayPtr getRawBytes() const { return _reader.asRawBytes(); } - Equality equals(AnyList::Reader right); - bool operator==(AnyList::Reader right); - inline bool operator!=(AnyList::Reader right) { + Equality equals(AnyList::Reader right) const; + bool operator==(AnyList::Reader right) const; + inline bool operator!=(AnyList::Reader right) const { return !(*this == right); } - template ReaderFor as() { + inline MessageSize totalSize() const { + return _reader.totalSize().asPublic(); + } + + template ReaderFor as() const { // T must be List. return ReaderFor(_reader); } @@ -657,7 +680,7 @@ class AnyList::Builder { inline Builder(decltype(nullptr)): _builder(ElementSize::VOID) {} inline Builder(_::ListBuilder builder): _builder(builder) {} -#if !_MSC_VER // TODO(msvc): MSVC ICEs on this. Try restoring when compiler improves. +#if !_MSC_VER || defined(__clang__) // TODO(msvc): MSVC ICEs on this. Try restoring when compiler improves. template ) == Kind::LIST>> inline Builder(T&& value) : _builder(_::PointerHelpers>::getInternalBuilder(kj::fwd(value))) {} @@ -666,11 +689,11 @@ class AnyList::Builder { inline ElementSize getElementSize() { return _builder.getElementSize(); } inline uint size() { return unbound(_builder.size() / ELEMENTS); } - Equality equals(AnyList::Reader right); - inline bool operator==(AnyList::Reader right) { + Equality equals(AnyList::Reader right) const; + inline bool operator==(AnyList::Reader right) const{ return asReader() == right; } - inline bool operator!=(AnyList::Reader right) { + inline bool operator!=(AnyList::Reader right) const{ return !(*this == right); } @@ -713,6 +736,27 @@ struct PipelineOp { }; }; +inline uint KJ_HASHCODE(const PipelineOp& op) { + switch (op.type) { + case PipelineOp::NOOP: return kj::hashCode(op.type); + case PipelineOp::GET_POINTER_FIELD: return kj::hashCode(op.type, op.pointerIndex); + } + KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT +} + +inline bool operator==(const PipelineOp& a, const PipelineOp& b) { + if (a.type != b.type) return false; + switch (a.type) { + case PipelineOp::NOOP: return true; + case PipelineOp::GET_POINTER_FIELD: return a.pointerIndex == b.pointerIndex; + } + KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT +} + +inline bool operator!=(const PipelineOp& a, const PipelineOp& b) { + return !(a == b); +} + class PipelineHook { // Represents a currently-running call, and implements pipelined requests on its result. @@ -730,6 +774,9 @@ class PipelineHook { template > static inline kj::Own from(Pipeline&& pipeline); + template > + static inline PipelineHook& from(Pipeline& pipeline); + private: template struct FromImpl; }; @@ -1052,6 +1099,9 @@ struct PipelineHook::FromImpl { static inline kj::Own apply(typename T::Pipeline&& pipeline) { return from(kj::mv(pipeline._typeless)); } + static inline PipelineHook& apply(typename T::Pipeline& pipeline) { + return from(pipeline._typeless); + } }; template <> @@ -1059,6 +1109,9 @@ struct PipelineHook::FromImpl { static inline kj::Own apply(AnyPointer::Pipeline&& pipeline) { return kj::mv(pipeline.hook); } + static inline PipelineHook& apply(AnyPointer::Pipeline& pipeline) { + return *pipeline.hook; + } }; template @@ -1066,8 +1119,13 @@ inline kj::Own PipelineHook::from(Pipeline&& pipeline) { return FromImpl::apply(kj::fwd(pipeline)); } +template +inline PipelineHook& PipelineHook::from(Pipeline& pipeline) { + return FromImpl::apply(pipeline); +} + #endif // !CAPNP_LITE } // namespace capnp -#endif // CAPNP_ANY_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/arena.c++ b/c++/src/capnp/arena.c++ index a681b3d594..77061db1d7 100644 --- a/c++/src/capnp/arena.c++ +++ b/c++/src/capnp/arena.c++ @@ -42,10 +42,10 @@ void ReadLimiter::unread(WordCount64 amount) { // Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that // the limit value was not updated correctly for one or more reads, and therefore unread() could // overflow it even if it is only unreading bytes that were actually read. - uint64_t oldValue = limit; + uint64_t oldValue = readLimit(); uint64_t newValue = oldValue + unbound(amount / WORDS); if (newValue > oldValue) { - limit = newValue; + setLimit(newValue); } } @@ -71,6 +71,22 @@ static SegmentWordCount verifySegmentSize(size_t size) { }); } +static SegmentWordCount verifySegment(kj::ArrayPtr segment) { +#if !CAPNP_ALLOW_UNALIGNED + KJ_REQUIRE(reinterpret_cast(segment.begin()) % sizeof(void*) == 0, + "Detected unaligned data in Cap'n Proto message. Messages must be aligned to the " + "architecture's word size. Yes, even on x86: Unaligned access is undefined behavior " + "under the C/C++ language standard, and compilers can and do assume alignment for the " + "purpose of optimizations. Unaligned access may lead to crashes or subtle corruption. " + "For example, GCC will use SIMD instructions in optimizations, and those instrsuctions " + "require alignment. If you really insist on taking your changes with unaligned data, " + "compile the Cap'n Proto library with -DCAPNP_ALLOW_UNALIGNED to remove this check.") { + break; + } +#endif + return verifySegmentSize(segment.size()); +} + inline ReaderArena::ReaderArena(MessageReader* message, const word* firstSegment, SegmentWordCount firstSegmentSize) : message(message), @@ -78,13 +94,23 @@ inline ReaderArena::ReaderArena(MessageReader* message, const word* firstSegment segment0(this, SegmentId(0), firstSegment, firstSegmentSize, &readLimiter) {} inline ReaderArena::ReaderArena(MessageReader* message, kj::ArrayPtr firstSegment) - : ReaderArena(message, firstSegment.begin(), verifySegmentSize(firstSegment.size())) {} + : ReaderArena(message, firstSegment.begin(), verifySegment(firstSegment)) {} ReaderArena::ReaderArena(MessageReader* message) : ReaderArena(message, message->getSegment(0)) {} ReaderArena::~ReaderArena() noexcept(false) {} +size_t ReaderArena::sizeInWords() { + size_t total = segment0.getArray().size(); + + for (uint i = 1; ; i++) { + SegmentReader* segment = tryGetSegment(SegmentId(i)); + if (segment == nullptr) return total; + total += unboundAs(segment->getSize() / WORDS); + } +} + SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { if (id == SegmentId(0)) { if (segment0.getArray() == nullptr) { @@ -98,11 +124,10 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { SegmentMap* segments = nullptr; KJ_IF_MAYBE(s, *lock) { - auto iter = s->get()->find(id.value); - if (iter != s->get()->end()) { - return iter->second; + KJ_IF_MAYBE(segment, s->find(id.value)) { + return *segment; } - segments = *s; + segments = s; } kj::ArrayPtr newSegment = message->getSegment(id.value); @@ -110,19 +135,17 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { return nullptr; } - SegmentWordCount newSegmentSize = verifySegmentSize(newSegment.size()); + SegmentWordCount newSegmentSize = verifySegment(newSegment); if (*lock == nullptr) { // OK, the segment exists, so allocate the map. - auto s = kj::heap(); - segments = s; - *lock = kj::mv(s); + segments = &lock->emplace(); } auto segment = kj::heap( this, id, newSegment.begin(), newSegmentSize, &readLimiter); SegmentReader* result = segment; - segments->insert(std::make_pair(id.value, mv(segment))); + segments->insert(id.value, kj::mv(segment)); return result; } @@ -141,7 +164,7 @@ BuilderArena::BuilderArena(MessageBuilder* message, kj::ArrayPtr segments) : message(message), segment0(this, SegmentId(0), segments[0].space.begin(), - verifySegmentSize(segments[0].space.size()), + verifySegment(segments[0].space), &this->dummyLimiter, verifySegmentSize(segments[0].wordsUsed)) { if (segments.size() > 1) { kj::Vector> builders(segments.size() - 1); @@ -149,7 +172,7 @@ BuilderArena::BuilderArena(MessageBuilder* message, uint i = 1; for (auto& segment: segments.slice(1, segments.size())) { builders.add(kj::heap( - this, SegmentId(i++), segment.space.begin(), verifySegmentSize(segment.space.size()), + this, SegmentId(i++), segment.space.begin(), verifySegment(segment.space), &this->dummyLimiter, verifySegmentSize(segment.wordsUsed))); } @@ -168,6 +191,24 @@ BuilderArena::BuilderArena(MessageBuilder* message, BuilderArena::~BuilderArena() noexcept(false) {} +size_t BuilderArena::sizeInWords() { + KJ_IF_MAYBE(segmentState, moreSegments) { + size_t total = segment0.currentlyAllocated().size(); + for (auto& builder: segmentState->get()->builders) { + total += builder->currentlyAllocated().size(); + } + return total; + } else { + if (segment0.getArena() == nullptr) { + // We haven't actually allocated any segments yet. + return 0; + } else { + // We have only one segment so far. + return segment0.currentlyAllocated().size(); + } + } +} + SegmentBuilder* BuilderArena::getSegment(SegmentId id) { // This method is allowed to fail if the segment ID is not valid. if (id == SegmentId(0)) { @@ -186,7 +227,7 @@ BuilderArena::AllocateResult BuilderArena::allocate(SegmentWordCount amount) { if (segment0.getArena() == nullptr) { // We're allocating the first segment. kj::ArrayPtr ptr = message->allocateSegment(unbound(amount / WORDS)); - auto actualSize = verifySegmentSize(ptr.size()); + auto actualSize = verifySegment(ptr); // Re-allocate segment0 in-place. This is a bit of a hack, but we have not returned any // pointers to this segment yet, so it should be fine. @@ -314,28 +355,38 @@ void BuilderArena::reportReadLimitReached() { } } -#if !CAPNP_LITE kj::Maybe> BuilderArena::LocalCapTable::extractCap(uint index) { +#if CAPNP_LITE + KJ_UNIMPLEMENTED("no cap tables in lite mode"); +#else if (index < capTable.size()) { return capTable[index].map([](kj::Own& cap) { return cap->addRef(); }); } else { return nullptr; } +#endif } uint BuilderArena::LocalCapTable::injectCap(kj::Own&& cap) { +#if CAPNP_LITE + KJ_UNIMPLEMENTED("no cap tables in lite mode"); +#else uint result = capTable.size(); capTable.add(kj::mv(cap)); return result; +#endif } void BuilderArena::LocalCapTable::dropCap(uint index) { +#if CAPNP_LITE + KJ_UNIMPLEMENTED("no cap tables in lite mode"); +#else KJ_ASSERT(index < capTable.size(), "Invalid capability descriptor in message.") { return; } capTable[index] = nullptr; +#endif } -#endif // !CAPNP_LITE } // namespace _ (private) } // namespace capnp diff --git a/c++/src/capnp/arena.h b/c++/src/capnp/arena.h index e3cacf32ad..aeaff8448d 100644 --- a/c++/src/capnp/arena.h +++ b/c++/src/capnp/arena.h @@ -19,12 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_ARENA_H_ -#define CAPNP_ARENA_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #ifndef CAPNP_PRIVATE #error "This header is only meant to be included by Cap'n Proto's own source code." @@ -38,12 +33,14 @@ #include "common.h" #include "message.h" #include "layout.h" -#include +#include #if !CAPNP_LITE #include "capability.h" #endif // !CAPNP_LITE +CAPNP_BEGIN_HEADER + namespace capnp { #if !CAPNP_LITE @@ -93,12 +90,32 @@ class ReadLimiter { // some data. private: - volatile uint64_t limit; - // Current limit, decremented each time catRead() is called. Volatile because multiple threads - // could be trying to modify it at once. (This is not real thread-safety, but good enough for - // the purpose of this class. See class comment.) + alignas(8) volatile uint64_t limit; + // Current limit, decremented each time catRead() is called. We modify this variable using atomics + // with "relaxed" thread safety to make TSAN happy (on ARM & x86 this is no different from a + // regular read/write of the variable). See the class comment for why this is OK (previously we + // used a regular volatile variable - this is just to make ASAN happy). + // + // alignas(8) is the default on 64-bit systems, but needed on 32-bit to avoid an expensive + // unaligned atomic operation. + + KJ_DISALLOW_COPY_AND_MOVE(ReadLimiter); - KJ_DISALLOW_COPY(ReadLimiter); + KJ_ALWAYS_INLINE(void setLimit(uint64_t newLimit)) { +#if defined(__GNUC__) || defined(__clang__) + __atomic_store_n(&limit, newLimit, __ATOMIC_RELAXED); +#else + limit = newLimit; +#endif + } + + KJ_ALWAYS_INLINE(uint64_t readLimit() const) { +#if defined(__GNUC__) || defined(__clang__) + return __atomic_load_n(&limit, __ATOMIC_RELAXED); +#else + return limit; +#endif + } }; #if !CAPNP_LITE @@ -157,11 +174,11 @@ class SegmentReader { kj::ArrayPtr ptr; // size guaranteed to fit in SEGMENT_WORD_COUNT_BITS bits ReadLimiter* readLimiter; - KJ_DISALLOW_COPY(SegmentReader); + KJ_DISALLOW_COPY_AND_MOVE(SegmentReader); friend class SegmentBuilder; - static void abortCheckObjectFault(); + [[noreturn]] static void abortCheckObjectFault(); // Called in debug mode in cases that would segfault in opt mode. (Should be impossible!) }; @@ -207,9 +224,9 @@ class SegmentBuilder: public SegmentReader { bool readOnly; - void throwNotWritable(); + [[noreturn]] void throwNotWritable(); - KJ_DISALLOW_COPY(SegmentBuilder); + KJ_DISALLOW_COPY_AND_MOVE(SegmentBuilder); }; class Arena { @@ -229,7 +246,9 @@ class ReaderArena final: public Arena { public: explicit ReaderArena(MessageReader* message); ~ReaderArena() noexcept(false); - KJ_DISALLOW_COPY(ReaderArena); + KJ_DISALLOW_COPY_AND_MOVE(ReaderArena); + + size_t sizeInWords(); // implements Arena ------------------------------------------------ SegmentReader* tryGetSegment(SegmentId id) override; @@ -242,8 +261,8 @@ class ReaderArena final: public Arena { // Optimize for single-segment messages so that small messages are handled quickly. SegmentReader segment0; - typedef std::unordered_map> SegmentMap; - kj::MutexGuarded>> moreSegments; + typedef kj::HashMap> SegmentMap; + kj::MutexGuarded> moreSegments; // We need to mutex-guard the segment map because we lazily initialize segments when they are // first requested, but a Reader is allowed to be used concurrently in multiple threads. Luckily // this only applies to large messages. @@ -263,7 +282,9 @@ class BuilderArena final: public Arena { explicit BuilderArena(MessageBuilder* message); BuilderArena(MessageBuilder* message, kj::ArrayPtr segments); ~BuilderArena() noexcept(false); - KJ_DISALLOW_COPY(BuilderArena); + KJ_DISALLOW_COPY_AND_MOVE(BuilderArena); + + size_t sizeInWords(); inline SegmentBuilder* getRootSegment() { return &segment0; } @@ -289,6 +310,10 @@ class BuilderArena final: public Arena { return &localCapTable; } + kj::Own<_::CapTableBuilder> releaseLocalCapTable() { + return kj::heap(kj::mv(localCapTable)); + } + SegmentBuilder* getSegment(SegmentId id); // Get the segment with the given id. Crashes or throws an exception if no such segment exists. @@ -322,13 +347,13 @@ class BuilderArena final: public Arena { MessageBuilder* message; ReadLimiter dummyLimiter; - class LocalCapTable: public CapTableBuilder { -#if !CAPNP_LITE + class LocalCapTable final: public CapTableBuilder { public: kj::Maybe> extractCap(uint index) override; uint injectCap(kj::Own&& cap) override; void dropCap(uint index) override; +#if !CAPNP_LITE private: kj::Vector>> capTable; #endif // ! CAPNP_LITE @@ -361,17 +386,19 @@ inline ReadLimiter::ReadLimiter() inline ReadLimiter::ReadLimiter(WordCount64 limit): limit(unbound(limit / WORDS)) {} -inline void ReadLimiter::reset(WordCount64 limit) { this->limit = unbound(limit / WORDS); } +inline void ReadLimiter::reset(WordCount64 limit) { + setLimit(unbound(limit / WORDS)); +} inline bool ReadLimiter::canRead(WordCount64 amount, Arena* arena) { // Be careful not to store an underflowed value into `limit`, even if multiple threads are // decrementing it. - uint64_t current = limit; + uint64_t current = readLimit(); if (KJ_UNLIKELY(unbound(amount / WORDS) > current)) { arena->reportReadLimitReached(); return false; } else { - limit = current - unbound(amount / WORDS); + setLimit(current - unbound(amount / WORDS)); return true; } } @@ -493,4 +520,4 @@ inline bool SegmentBuilder::tryExtend(word* from, word* to) { } // namespace _ (private) } // namespace capnp -#endif // CAPNP_ARENA_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/blob.h b/c++/src/capnp/blob.h index 07d40759c7..847e2ea368 100644 --- a/c++/src/capnp/blob.h +++ b/c++/src/capnp/blob.h @@ -19,18 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_BLOB_H_ -#define CAPNP_BLOB_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include #include "common.h" #include +CAPNP_BEGIN_HEADER + namespace capnp { struct Data { @@ -102,7 +99,9 @@ class Data::Builder: public kj::ArrayPtr { inline Builder(kj::Array& value): ArrayPtr(value) {} inline Builder(ArrayPtr value): ArrayPtr(value) {} - inline Data::Reader asReader() const { return Data::Reader(*this); } + inline Data::Reader asReader() const { + return Data::Reader(kj::implicitCast&>(*this)); + } inline operator Reader() const { return asReader(); } }; @@ -220,4 +219,4 @@ inline kj::ArrayPtr Text::Builder::slice(size_t start, size_t end) { } // namespace capnp -#endif // CAPNP_BLOB_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/bootstrap-test.ekam-rule b/c++/src/capnp/bootstrap-test.ekam-rule index e70ee02854..ae93b9ace8 100755 --- a/c++/src/capnp/bootstrap-test.ekam-rule +++ b/c++/src/capnp/bootstrap-test.ekam-rule @@ -51,8 +51,9 @@ fi mkdir -p tmp/capnp/bootstrap-test-tmp -INPUTS="capnp/c++.capnp capnp/schema.capnp capnp/compiler/lexer.capnp capnp/compiler/grammar.capnp \ -capnp/rpc.capnp capnp/rpc-twoparty.capnp capnp/persistent.capnp" +INPUTS="capnp/c++.capnp capnp/schema.capnp capnp/stream.capnp capnp/compiler/lexer.capnp \ +capnp/compiler/grammar.capnp capnp/rpc.capnp capnp/rpc-twoparty.capnp capnp/persistent.capnp \ +capnp/compat/json.capnp" SRC_INPUTS="" for file in $INPUTS; do diff --git a/c++/src/capnp/c++.capnp b/c++/src/capnp/c++.capnp index 2bda547179..9eaff6d1d2 100644 --- a/c++/src/capnp/c++.capnp +++ b/c++/src/capnp/c++.capnp @@ -24,3 +24,25 @@ $namespace("capnp::annotations"); annotation namespace(file): Text; annotation name(field, enumerant, struct, enum, interface, method, param, group, union): Text; + +annotation allowCancellation(interface, method, file) :Void; +# Indicates that the server-side implementation of a method is allowed to be canceled when the +# client requests cancellation. Without this annotation, once a method call has been delivered to +# the server-side application code, any requests by the client to cancel it will be ignored, and +# the method will run to completion anyway. This applies even for local in-process calls. +# +# This behavior applies specifically to implementations that inherit from the C++ `Foo::Server` +# interface. The annotation won't affect DynamicCapability::Server implementations; they must set +# the cancellation mode at runtime. +# +# When applied to an interface rather than an individual method, the annotation applies to all +# methods in the interface. When applied to a file, it applies to all methods defined in the file. +# +# It's generally recommended that this annotation be applied to all methods. However, when doing +# so, it is important that the server implementation use cancellation-safe code. See: +# +# https://github.com/capnproto/capnproto/blob/master/kjdoc/tour.md#cancellation +# +# If your code is not cancellation-safe, then allowing cancellation might give a malicious client +# an easy way to induce use-after-free or other bugs in your server, by requesting cancellation +# when not expected. diff --git a/c++/src/capnp/c++.capnp.c++ b/c++/src/capnp/c++.capnp.c++ index 576d733b23..02378a9c88 100644 --- a/c++/src/capnp/c++.capnp.c++ +++ b/c++/src/capnp/c++.capnp.c++ @@ -32,7 +32,7 @@ static const ::capnp::_::AlignedData<21> b_b9c6f99ebf805f2c = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_b9c6f99ebf805f2c = { 0xb9c6f99ebf805f2c, b_b9c6f99ebf805f2c.words, 21, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_b9c6f99ebf805f2c, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_b9c6f99ebf805f2c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<20> b_f264a779fef191ce = { @@ -61,7 +61,38 @@ static const ::capnp::_::AlignedData<20> b_f264a779fef191ce = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_f264a779fef191ce = { 0xf264a779fef191ce, b_f264a779fef191ce.words, 20, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_f264a779fef191ce, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_f264a779fef191ce, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<22> b_ac7096ff8cfc9dce = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 206, 157, 252, 140, 255, 150, 112, 172, + 16, 0, 0, 0, 5, 0, 1, 3, + 129, 78, 48, 184, 123, 125, 248, 189, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 18, 1, 0, 0, + 37, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 32, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 43, + 43, 46, 99, 97, 112, 110, 112, 58, + 97, 108, 108, 111, 119, 67, 97, 110, + 99, 101, 108, 108, 97, 116, 105, 111, + 110, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_ac7096ff8cfc9dce = b_ac7096ff8cfc9dce.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_ac7096ff8cfc9dce = { + 0xac7096ff8cfc9dce, b_ac7096ff8cfc9dce.words, 22, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_ac7096ff8cfc9dce, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas diff --git a/c++/src/capnp/c++.capnp.h b/c++/src/capnp/c++.capnp.h index 6d9817fbde..444e5793bd 100644 --- a/c++/src/capnp/c++.capnp.h +++ b/c++/src/capnp/c++.capnp.h @@ -1,21 +1,26 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: c++.capnp -#ifndef CAPNP_INCLUDED_bdf87d7bb8304e81_ -#define CAPNP_INCLUDED_bdf87d7bb8304e81_ +#pragma once #include +#include -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { CAPNP_DECLARE_SCHEMA(b9c6f99ebf805f2c); CAPNP_DECLARE_SCHEMA(f264a779fef191ce); +CAPNP_DECLARE_SCHEMA(ac7096ff8cfc9dce); } // namespace schemas } // namespace capnp @@ -30,4 +35,5 @@ namespace annotations { } // namespace } // namespace -#endif // CAPNP_INCLUDED_bdf87d7bb8304e81_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/capability-test.c++ b/c++/src/capnp/capability-test.c++ index d15efd16e0..a645abce40 100644 --- a/c++/src/capnp/capability-test.c++ +++ b/c++/src/capnp/capability-test.c++ @@ -21,13 +21,13 @@ #include "schema.capnp.h" -#ifdef CAPNP_CAPABILITY_H_ +#ifdef CAPNP_CAPABILITY_H_INCLUDED #error "schema.capnp should not depend on capability.h, because it contains no interfaces." #endif #include -#ifndef CAPNP_CAPABILITY_H_ +#ifndef CAPNP_CAPABILITY_H_INCLUDED #error "test.capnp did not include capability.h." #endif @@ -80,6 +80,39 @@ TEST(Capability, Basic) { EXPECT_TRUE(barFailed); } +TEST(Capability, CapabilityList) { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + MallocMessageBuilder builder; + auto root = builder.initRoot(); + auto initCapList = root.initCapList(2); + + int callCount0 = 0; + int callCount1 = 0; + initCapList.set(0, kj::heap(callCount0)); + initCapList.set(1, kj::heap(callCount1)); + + auto capList = root.getCapList(); + auto cap0 = capList[0].castAs(); + auto cap1 = capList[1].castAs(); + + EXPECT_EQ(2u, root.getCapList().size()); + + auto request0 = cap0.fooRequest(); + request0.setI(123); + request0.setJ(true); + EXPECT_EQ("foo", request0.send().wait(waitScope).getX()); + + auto request1 = cap1.fooRequest(); + request1.setI(123); + request1.setJ(true); + EXPECT_EQ("foo", request1.send().wait(waitScope).getX()); + + EXPECT_EQ(1, callCount0); + EXPECT_EQ(1, callCount1); +} + TEST(Capability, Inheritance) { kj::EventLoop loop; kj::WaitScope waitScope(loop); @@ -145,6 +178,73 @@ TEST(Capability, Pipelining) { EXPECT_EQ(1, chainedCallCount); } +KJ_TEST("use pipeline after dropping response") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0; + int chainedCallCount = 0; + test::TestPipeline::Client client(kj::heap(callCount)); + + auto request = client.getCapRequest(); + request.setN(234); + request.setInCap(test::TestInterface::Client(kj::heap(chainedCallCount))); + + auto promise = request.send(); + test::TestPipeline::GetCapResults::Pipeline pipeline = kj::mv(promise); + + { + auto response = promise.wait(waitScope); + KJ_EXPECT(response.getS() == "bar"); + } + + auto pipelineRequest = pipeline.getOutBox().getCap().fooRequest(); + pipelineRequest.setI(321); + auto pipelinePromise = pipelineRequest.send(); + + auto pipelineRequest2 = pipeline.getOutBox().getCap().castAs().graultRequest(); + auto pipelinePromise2 = pipelineRequest2.send(); + + auto response = pipelinePromise.wait(waitScope); + EXPECT_EQ("bar", response.getX()); + + auto response2 = pipelinePromise2.wait(waitScope); + checkTestMessage(response2); + + EXPECT_EQ(3, callCount); + EXPECT_EQ(1, chainedCallCount); +} + +KJ_TEST("context.setPipeline") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0; + test::TestPipeline::Client client(kj::heap(callCount)); + + auto promise = client.getCapPipelineOnlyRequest().send(); + + auto pipelineRequest = promise.getOutBox().getCap().fooRequest(); + pipelineRequest.setI(321); + auto pipelinePromise = pipelineRequest.send(); + + auto pipelineRequest2 = promise.getOutBox().getCap().castAs().graultRequest(); + auto pipelinePromise2 = pipelineRequest2.send(); + + EXPECT_EQ(0, callCount); + + auto response = pipelinePromise.wait(waitScope); + EXPECT_EQ("bar", response.getX()); + + auto response2 = pipelinePromise2.wait(waitScope); + checkTestMessage(response2); + + EXPECT_EQ(3, callCount); + + // The original promise never completed. + KJ_EXPECT(!promise.poll(waitScope)); +} + TEST(Capability, TailCall) { kj::EventLoop loop; kj::WaitScope waitScope(loop); @@ -180,7 +280,7 @@ TEST(Capability, TailCall) { } TEST(Capability, AsyncCancelation) { - // Tests allowCancellation(). + // Tests cancellation. kj::EventLoop loop; kj::WaitScope waitScope(loop); @@ -558,7 +658,7 @@ public: // in 4.8.x nor in 4.9.4: // https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=781060 // - // Unfortunatley 4.9.2 is present on many Debian Jessie systems.. + // Unfortunately 4.9.2 is present on many Debian Jessie systems.. // // For the moment, we can get away with skipping the last line as the previous line // will set things up in a way that allows the test to complete successfully. @@ -1031,6 +1131,294 @@ TEST(Capability, TransferCap) { }).wait(waitScope); } +KJ_TEST("Promise> automatically reduces to RemotePromise") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0; + test::TestInterface::Client client(kj::heap(callCount)); + + RemotePromise promise = kj::evalLater([&]() { + auto request = client.fooRequest(); + request.setI(123); + request.setJ(true); + return request.send(); + }); + + EXPECT_EQ(0, callCount); + auto response = promise.wait(waitScope); + EXPECT_EQ("foo", response.getX()); + EXPECT_EQ(1, callCount); +} + +KJ_TEST("Promise> automatically reduces to RemotePromise with pipelining") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0; + int chainedCallCount = 0; + test::TestPipeline::Client client(kj::heap(callCount)); + + auto promise = kj::evalLater([&]() { + auto request = client.getCapRequest(); + request.setN(234); + request.setInCap(test::TestInterface::Client(kj::heap(chainedCallCount))); + return request.send(); + }); + + auto pipelineRequest = promise.getOutBox().getCap().fooRequest(); + pipelineRequest.setI(321); + auto pipelinePromise = pipelineRequest.send(); + + EXPECT_EQ(0, callCount); + EXPECT_EQ(0, chainedCallCount); + + auto response = pipelinePromise.wait(waitScope); + EXPECT_EQ("bar", response.getX()); + + EXPECT_EQ(2, callCount); + EXPECT_EQ(1, chainedCallCount); +} + +KJ_TEST("clone() with caps") { + int dummy = 0; + MallocMessageBuilder builder(2048); + auto root = builder.getRoot().initAs>(3); + root.set(0, kj::heap(dummy)); + root.set(1, kj::heap(dummy)); + root.set(2, kj::heap(dummy)); + + auto copyPtr = clone(root.asReader()); + auto& copy = *copyPtr; + + KJ_ASSERT(copy.size() == 3); + KJ_EXPECT(ClientHook::from(copy[0]).get() == ClientHook::from(root[0]).get()); + KJ_EXPECT(ClientHook::from(copy[1]).get() == ClientHook::from(root[1]).get()); + KJ_EXPECT(ClientHook::from(copy[2]).get() == ClientHook::from(root[2]).get()); + + KJ_EXPECT(ClientHook::from(copy[0]).get() != ClientHook::from(root[1]).get()); + KJ_EXPECT(ClientHook::from(copy[1]).get() != ClientHook::from(root[2]).get()); + KJ_EXPECT(ClientHook::from(copy[2]).get() != ClientHook::from(root[0]).get()); +} + +KJ_TEST("Streaming calls block subsequent calls") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto ownServer = kj::heap(); + auto& server = *ownServer; + test::TestStreaming::Client cap = kj::mv(ownServer); + + kj::Promise promise1 = nullptr, promise2 = nullptr, promise3 = nullptr; + + { + auto req = cap.doStreamIRequest(); + req.setI(123); + promise1 = req.send(); + } + + { + auto req = cap.doStreamJRequest(); + req.setJ(321); + promise2 = req.send(); + } + + { + auto req = cap.doStreamIRequest(); + req.setI(456); + promise3 = req.send(); + } + + auto promise4 = cap.finishStreamRequest().send(); + + KJ_EXPECT(server.iSum == 0); + KJ_EXPECT(server.jSum == 0); + + KJ_EXPECT(!promise1.poll(waitScope)); + KJ_EXPECT(!promise2.poll(waitScope)); + KJ_EXPECT(!promise3.poll(waitScope)); + KJ_EXPECT(!promise4.poll(waitScope)); + + KJ_EXPECT(server.iSum == 123); + KJ_EXPECT(server.jSum == 0); + + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_EXPECT(promise1.poll(waitScope)); + KJ_EXPECT(!promise2.poll(waitScope)); + KJ_EXPECT(!promise3.poll(waitScope)); + KJ_EXPECT(!promise4.poll(waitScope)); + + KJ_EXPECT(server.iSum == 123); + KJ_EXPECT(server.jSum == 321); + + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_EXPECT(promise1.poll(waitScope)); + KJ_EXPECT(promise2.poll(waitScope)); + KJ_EXPECT(!promise3.poll(waitScope)); + KJ_EXPECT(!promise4.poll(waitScope)); + + KJ_EXPECT(server.iSum == 579); + KJ_EXPECT(server.jSum == 321); + + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_EXPECT(promise1.poll(waitScope)); + KJ_EXPECT(promise2.poll(waitScope)); + KJ_EXPECT(promise3.poll(waitScope)); + KJ_EXPECT(promise4.poll(waitScope)); + + auto result = promise4.wait(waitScope); + KJ_EXPECT(result.getTotalI() == 579); + KJ_EXPECT(result.getTotalJ() == 321); +} + +KJ_TEST("Streaming calls can be canceled") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto ownServer = kj::heap(); + auto& server = *ownServer; + test::TestStreaming::Client cap = kj::mv(ownServer); + + kj::Promise promise1 = nullptr, promise2 = nullptr, promise3 = nullptr; + + { + auto req = cap.doStreamIRequest(); + req.setI(123); + promise1 = req.send(); + } + + { + auto req = cap.doStreamJRequest(); + req.setJ(321); + promise2 = req.send(); + } + + { + auto req = cap.doStreamIRequest(); + req.setI(456); + promise3 = req.send(); + } + + auto promise4 = cap.finishStreamRequest().send(); + + // Cancel the doStreamJ() request. + promise2 = nullptr; + + KJ_EXPECT(server.iSum == 0); + KJ_EXPECT(server.jSum == 0); + + KJ_EXPECT(!promise4.poll(waitScope)); + + KJ_EXPECT(server.iSum == 123); + KJ_EXPECT(server.jSum == 0); + + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_EXPECT(!promise4.poll(waitScope)); + + // The call to doStreamJ() was canceled, so the next call to doStreamI() happens immediately. + KJ_EXPECT(server.iSum == 579); + KJ_EXPECT(server.jSum == 0); + + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_EXPECT(promise4.poll(waitScope)); + + auto result = promise4.wait(waitScope); + KJ_EXPECT(result.getTotalI() == 579); + KJ_EXPECT(result.getTotalJ() == 0); +} + +KJ_TEST("Streaming call throwing cascades to following calls") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto ownServer = kj::heap(); + auto& server = *ownServer; + test::TestStreaming::Client cap = kj::mv(ownServer); + + server.jShouldThrow = true; + + kj::Promise promise1 = nullptr, promise2 = nullptr, promise3 = nullptr; + + { + auto req = cap.doStreamIRequest(); + req.setI(123); + promise1 = req.send(); + } + + { + auto req = cap.doStreamJRequest(); + req.setJ(321); + promise2 = req.send(); + } + + { + auto req = cap.doStreamIRequest(); + req.setI(456); + promise3 = req.send(); + } + + auto promise4 = cap.finishStreamRequest().send(); + + KJ_EXPECT(server.iSum == 0); + KJ_EXPECT(server.jSum == 0); + + KJ_EXPECT(!promise1.poll(waitScope)); + KJ_EXPECT(!promise2.poll(waitScope)); + KJ_EXPECT(!promise3.poll(waitScope)); + KJ_EXPECT(!promise4.poll(waitScope)); + + KJ_EXPECT(server.iSum == 123); + KJ_EXPECT(server.jSum == 0); + + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_EXPECT(promise1.poll(waitScope)); + KJ_EXPECT(promise2.poll(waitScope)); + KJ_EXPECT(promise3.poll(waitScope)); + KJ_EXPECT(promise4.poll(waitScope)); + + KJ_EXPECT(server.iSum == 123); + KJ_EXPECT(server.jSum == 321); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("throw requested", promise2.wait(waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("throw requested", promise3.wait(waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("throw requested", promise4.ignoreResult().wait(waitScope)); +} + +KJ_TEST("RevocableServer") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + class ServerImpl: public test::TestMembrane::Server { + public: + kj::Promise waitForever(WaitForeverContext context) override { + return kj::NEVER_DONE; + } + }; + + ServerImpl server; + + RevocableServer revocable(server); + + auto promise = revocable.getClient().waitForeverRequest().send(); + KJ_EXPECT(!promise.poll(waitScope)); + + revocable.revoke(); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "capability was revoked", + promise.ignoreResult().wait(waitScope)); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "capability was revoked", + revocable.getClient().waitForeverRequest().send().ignoreResult().wait(waitScope)); +} + } // namespace } // namespace _ } // namespace capnp diff --git a/c++/src/capnp/capability.c++ b/c++/src/capnp/capability.c++ index 61028fd070..9462c6c218 100644 --- a/c++/src/capnp/capability.c++ +++ b/c++/src/capnp/capability.c++ @@ -61,10 +61,6 @@ ClientHook::ClientHook() { setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory); } -void* ClientHook::getLocalServer(_::CapabilityServerSetBase& capServerSet) { - return nullptr; -} - // ======================================================================================= Capability::Client::Client(decltype(nullptr)) @@ -73,15 +69,38 @@ Capability::Client::Client(decltype(nullptr)) Capability::Client::Client(kj::Exception&& exception) : hook(newBrokenCap(kj::mv(exception))) {} -kj::Promise Capability::Server::internalUnimplemented( +kj::Promise> Capability::Client::getFd() { + auto fd = hook->getFd(); + if (fd != nullptr) { + return fd; + } else KJ_IF_MAYBE(promise, hook->whenMoreResolved()) { + return promise->attach(hook->addRef()).then([](kj::Own newHook) { + return Client(kj::mv(newHook)).getFd(); + }); + } else { + return kj::Maybe(nullptr); + } +} + +kj::Maybe> Capability::Server::shortenPath() { + return nullptr; +} + +Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented( const char* actualInterfaceName, uint64_t requestedTypeId) { - return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", - actualInterfaceName, requestedTypeId); + return { + KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", + actualInterfaceName, requestedTypeId), + false, true + }; } -kj::Promise Capability::Server::internalUnimplemented( +Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented( const char* interfaceName, uint64_t typeId, uint16_t methodId) { - return KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId); + return { + KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId), + false, true + }; } kj::Promise Capability::Server::internalUnimplemented( @@ -102,6 +121,10 @@ kj::Promise ClientHook::whenResolved() { } } +kj::Promise Capability::Client::whenResolved() { + return hook->whenResolved().attach(hook->addRef()); +} + // ======================================================================================= static inline uint firstSegmentSize(kj::Maybe sizeHint) { @@ -112,7 +135,7 @@ static inline uint firstSegmentSize(kj::Maybe sizeHint) { } } -class LocalResponse final: public ResponseHook, public kj::Refcounted { +class LocalResponse final: public ResponseHook { public: LocalResponse(kj::Maybe sizeHint) : message(firstSegmentSize(sizeHint)) {} @@ -120,12 +143,12 @@ public: MallocMessageBuilder message; }; -class LocalCallContext final: public CallContextHook, public kj::Refcounted { +class LocalCallContext final: public CallContextHook, public ResponseHook, public kj::Refcounted { public: LocalCallContext(kj::Own&& request, kj::Own clientRef, - kj::Own> cancelAllowedFulfiller) - : request(kj::mv(request)), clientRef(kj::mv(clientRef)), - cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {} + ClientHook::CallHints hints, bool isStreaming) + : request(kj::mv(request)), clientRef(kj::mv(clientRef)), hints(hints), + isStreaming(isStreaming) {} AnyPointer::Reader getParams() override { KJ_IF_MAYBE(r, request) { @@ -139,12 +162,17 @@ public: } AnyPointer::Builder getResults(kj::Maybe sizeHint) override { if (response == nullptr) { - auto localResponse = kj::refcounted(sizeHint); + auto localResponse = kj::heap(sizeHint); responseBuilder = localResponse->message.getRoot(); response = Response(responseBuilder.asReader(), kj::mv(localResponse)); } return responseBuilder; } + void setPipeline(kj::Own&& pipeline) override { + KJ_IF_MAYBE(f, tailCallPipelineFulfiller) { + f->get()->fulfill(AnyPointer::Pipeline(kj::mv(pipeline))); + } + } kj::Promise tailCall(kj::Own&& request) override { auto result = directTailCall(kj::mv(request)); KJ_IF_MAYBE(f, tailCallPipelineFulfiller) { @@ -155,22 +183,31 @@ public: ClientHook::VoidPromiseAndPipeline directTailCall(kj::Own&& request) override { KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct."); - auto promise = request->send(); + if (hints.onlyPromisePipeline) { + return { + kj::NEVER_DONE, + PipelineHook::from(request->sendForPipeline()) + }; + } - auto voidPromise = promise.then([this](Response&& tailResponse) { - response = kj::mv(tailResponse); - }); + if (isStreaming) { + auto promise = request->sendStreaming(); + return { kj::mv(promise), getDisabledPipeline() }; + } else { + auto promise = request->send(); - return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) }; + auto voidPromise = promise.then([this](Response&& tailResponse) { + response = kj::mv(tailResponse); + }); + + return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) }; + } } kj::Promise onTailCall() override { auto paf = kj::newPromiseAndFulfiller(); tailCallPipelineFulfiller = kj::mv(paf.fulfiller); return kj::mv(paf.promise); } - void allowCancellation() override { - cancelAllowedFulfiller->fulfill(); - } kj::Own addRef() override { return kj::addRef(*this); } @@ -180,50 +217,41 @@ public: AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null kj::Own clientRef; kj::Maybe>> tailCallPipelineFulfiller; - kj::Own> cancelAllowedFulfiller; + ClientHook::CallHints hints; + bool isStreaming; }; class LocalRequest final: public RequestHook { public: inline LocalRequest(uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint, kj::Own client) + kj::Maybe sizeHint, ClientHook::CallHints hints, + kj::Own client) : message(kj::heap(firstSegmentSize(sizeHint))), - interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {} + interfaceId(interfaceId), methodId(methodId), hints(hints), client(kj::mv(client)) {} RemotePromise send() override { - KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); + bool isStreaming = false; + return sendImpl(isStreaming); + } - // For the lambda capture. - uint64_t interfaceId = this->interfaceId; - uint16_t methodId = this->methodId; + kj::Promise sendStreaming() override { + // We don't do any special handling of streaming in RequestHook for local requests, because + // there is no latency to compensate for between the client and server in this case. However, + // we record whether the call was streaming, so that it can be preserved as a streaming call + // if the local capability later resolves to a remote capability. + bool isStreaming = true; + return sendImpl(isStreaming).ignoreResult(); + } - auto cancelPaf = kj::newPromiseAndFulfiller(); + AnyPointer::Pipeline sendForPipeline() override { + KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); + hints.onlyPromisePipeline = true; + bool isStreaming = false; auto context = kj::refcounted( - kj::mv(message), client->addRef(), kj::mv(cancelPaf.fulfiller)); - auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context)); - - // We have to make sure the call is not canceled unless permitted. We need to fork the promise - // so that if the client drops their copy, the promise isn't necessarily canceled. - auto forked = promiseAndPipeline.promise.fork(); - - // We daemonize one branch, but only after joining it with the promise that fires if - // cancellation is allowed. - forked.addBranch() - .attach(kj::addRef(*context)) - .exclusiveJoin(kj::mv(cancelPaf.promise)) - .detach([](kj::Exception&&) {}); // ignore exceptions - - // Now the other branch returns the response from the context. - auto promise = forked.addBranch().then(kj::mvCapture(context, - [](kj::Own&& context) { - context->getResults(MessageSize { 0, 0 }); // force response allocation - return kj::mv(KJ_ASSERT_NONNULL(context->response)); - })); - - // We return the other branch. - return RemotePromise( - kj::mv(promise), AnyPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline))); + kj::mv(message), client->addRef(), hints, isStreaming); + auto vpap = client->call(interfaceId, methodId, kj::addRef(*context), hints); + return AnyPointer::Pipeline(kj::mv(vpap.pipeline)); } const void* getBrand() override { @@ -235,7 +263,40 @@ public: private: uint64_t interfaceId; uint16_t methodId; + ClientHook::CallHints hints; kj::Own client; + + RemotePromise sendImpl(bool isStreaming) { + KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); + + auto context = kj::refcounted(kj::mv(message), client->addRef(), hints, isStreaming); + auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context), hints); + + // Now the other branch returns the response from the context. + auto promise = promiseAndPipeline.promise.then([context=kj::mv(context)]() mutable { + // force response allocation + auto reader = context->getResults(MessageSize { 0, 0 }).asReader(); + + if (context->isShared()) { + // We can't just move away context->response as `context` itself is still referenced by + // something -- probably a Pipeline object. As a bit of a hack, LocalCallContext itself + // implements ResponseHook so that we can just return a ref on it. + // + // TODO(cleanup): Maybe ResponseHook should be refcounted? Note that context->response + // might not necessarily contain a LocalResponse if it was resolved by a tail call, so + // we'd have to add refcounting to all ResponseHook implementations. + context->releaseParams(); // The call is done so params can definitely be dropped. + context->clientRef = nullptr; // Definitely not using the client cap anymore either. + return Response(reader, kj::mv(context)); + } else { + return kj::mv(KJ_ASSERT_NONNULL(context->response)); + } + }); + + // We return the other branch. + return RemotePromise( + kj::mv(promise), AnyPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline))); + } }; // ======================================================================================= @@ -279,6 +340,49 @@ private: kj::Promise selfResolutionOp; // Represents the operation which will set `redirect` when possible. + + kj::HashMap, kj::Own> clientMap; + // If the same pipelined cap is requested twice, we have to return the same object. This is + // necessary because each ClientHook we create is a QueuedClient which queues up calls. If we + // return a new one each time, there will be several queues, and ordering of calls will be lost + // between the queues. + // + // One case where this is particularly problematic is with promises resolved over RPC. Consider + // this case: + // + // * Alice holds a promise capability P pointing towards Bob. + // * Bob makes a call Q on an object hosted by Alice. + // * Without waiting for Q to complete, Bob obtains a pipelined-promise capability for Q's + // eventual result, P2. + // * Alice invokes a method M on P. The call is sent to Bob. + // * Bob resolves Alice's original promise P to P2. + // * Alice receives a Resolve message from Bob resolving P to Q's eventual result. + // * As a result, Alice calls getPipelinedCap() on the QueuedPipeline for Q's result, which + // returns a QueuedClient for that result, which we'll call QR1. + // * Alice also sends a Disembargo to Bob. + // * Alice calls a method M2 on P. This call is blocked locally waiting for the disembargo to + // complete. + // * Bob receives Alice's first method call, M. Since it's addressed to P, which later resolved + // to Q's result, Bob reflects the call back to Alice. + // * Alice receives the reflected call, which is addressed to Q's result. + // * Alice calls getPipelinedCap() on the QueuedPipeline for Q's result, which returns a + // QueuedClient for that result, which we'll call QR2. + // * Alice enqueues the call M on QR2. + // * Bob receives Alice's Disembargo message, and reflects it back. + // * Alices receives the Disembrago. + // * Alice unblocks the method cgall M2, which had been blocked on the embargo. + // * The call M2 is then equeued onto QR1. + // * Finally, the call Q completes. + // * This causes QR1 and QR2 to resolve to their final destinations. But if QR1 and QR2 are + // separate objects, then one of them must resolve first. QR1 was created first, so naturally + // it resolves first, followed by QR2. + // * Because QR1 resolves first, method call M2 is delivered first. + // * QR2 resolves second, so method call M1 is delivered next. + // * THIS IS THE WRONG ORDER! + // + // In order to avoid this problem, it's necessary for QR1 and QR2 to be the same object, so that + // they share the same call queue. In this case, M2 is correctly enqueued onto QR2 *after* M1 was + // enqueued on QR1, and so the method calls are delivered in the correct order. }; class QueuedClient final: public ClientHook, public kj::Refcounted { @@ -297,65 +401,47 @@ public: promiseForClientResolution(promise.addBranch().fork()) {} Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { auto hook = kj::heap( - interfaceId, methodId, sizeHint, kj::addRef(*this)); + interfaceId, methodId, sizeHint, hints, kj::addRef(*this)); auto root = hook->message->getRoot(); return Request(root, kj::mv(hook)); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - // This is a bit complicated. We need to initiate this call later on. When we initiate the - // call, we'll get a void promise for its completion and a pipeline object. Right now, we have - // to produce a similar void promise and pipeline that will eventually be chained to those. - // The problem is, these are two independent objects, but they both depend on the result of - // one future call. - // - // So, we need to set up a continuation that will initiate the call later, then we need to - // fork the promise for that continuation in order to send the completion promise and the - // pipeline to their respective places. - // - // TODO(perf): Too much reference counting? Can we do better? Maybe a way to fork - // Promise> into Tuple, Promise>? - - struct CallResultHolder: public kj::Refcounted { - // Essentially acts as a refcounted \VoidPromiseAndPipeline, so that we can create a promise - // for it and fork that promise. - - VoidPromiseAndPipeline content; - // One branch of the fork will use content.promise, the other branch will use - // content.pipeline. Neither branch will touch the other's piece. - - inline CallResultHolder(VoidPromiseAndPipeline&& content): content(kj::mv(content)) {} - - kj::Own addRef() { return kj::addRef(*this); } - }; - - // Create a promise for the call initiation. - kj::ForkedPromise> callResultPromise = - promiseForCallForwarding.addBranch().then(kj::mvCapture(context, - [=](kj::Own&& context, kj::Own&& client){ - return kj::refcounted( - client->call(interfaceId, methodId, kj::mv(context))); - })).fork(); - - // Create a promise that extracts the pipeline from the call initiation, and construct our - // QueuedPipeline to chain to it. - auto pipelinePromise = callResultPromise.addBranch().then( - [](kj::Own&& callResult){ - return kj::mv(callResult->content.pipeline); - }); - auto pipeline = kj::refcounted(kj::mv(pipelinePromise)); + kj::Own&& context, CallHints hints) override { + if (hints.noPromisePipelining) { + // Optimize for no pipelining. + auto promise = promiseForCallForwarding.addBranch() + .then([=,context=kj::mv(context)](kj::Own&& client) mutable { + return client->call(interfaceId, methodId, kj::mv(context), hints).promise; + }); + return VoidPromiseAndPipeline { kj::mv(promise), getDisabledPipeline() }; + } else if (hints.onlyPromisePipeline) { + auto pipelinePromise = promiseForCallForwarding.addBranch() + .then([=,context=kj::mv(context)](kj::Own&& client) mutable { + return client->call(interfaceId, methodId, kj::mv(context), hints).pipeline; + }); + return VoidPromiseAndPipeline { + kj::NEVER_DONE, + kj::refcounted(kj::mv(pipelinePromise)) + }; + } else { + auto split = promiseForCallForwarding.addBranch() + .then([=,context=kj::mv(context)](kj::Own&& client) mutable { + auto vpap = client->call(interfaceId, methodId, kj::mv(context), hints); + return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline)); + }).split(); - // Create a promise that simply chains to the void promise produced by the call initiation. - auto completionPromise = callResultPromise.addBranch().then( - [](kj::Own&& callResult){ - return kj::mv(callResult->content.promise); - }); + kj::Promise completionPromise = kj::mv(kj::get<0>(split)); + kj::Promise> pipelinePromise = kj::mv(kj::get<1>(split)); - // OK, now we can actually return our thing. - return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) }; + auto pipeline = kj::refcounted(kj::mv(pipelinePromise)); + + // OK, now we can actually return our thing. + return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) }; + } } kj::Maybe getResolved() override { @@ -378,6 +464,14 @@ public: return nullptr; } + kj::Maybe getFd() override { + KJ_IF_MAYBE(r, redirect) { + return r->get()->getFd(); + } else { + return nullptr; + } + } + private: typedef kj::ForkedPromise> ClientHookPromiseFork; @@ -411,12 +505,15 @@ kj::Own QueuedPipeline::getPipelinedCap(kj::Array&& ops) KJ_IF_MAYBE(r, redirect) { return r->get()->getPipelinedCap(kj::mv(ops)); } else { - auto clientPromise = promise.addBranch().then(kj::mvCapture(ops, - [](kj::Array&& ops, kj::Own pipeline) { - return pipeline->getPipelinedCap(kj::mv(ops)); - })); - - return kj::refcounted(kj::mv(clientPromise)); + return clientMap.findOrCreate(ops.asPtr(), [&]() { + auto clientPromise = promise.addBranch() + .then([ops = KJ_MAP(op, ops) { return op; }](kj::Own pipeline) { + return pipeline->getPipelinedCap(kj::mv(ops)); + }); + return kj::HashMap, kj::Own>::Entry { + kj::mv(ops), kj::refcounted(kj::mv(clientPromise)) + }; + })->addRef(); } } @@ -443,30 +540,64 @@ private: class LocalClient final: public ClientHook, public kj::Refcounted { public: - LocalClient(kj::Own&& serverParam) - : server(kj::mv(serverParam)) { - server->thisHook = this; + LocalClient(kj::Own&& serverParam, bool revocable = false) { + auto& serverRef = *server.emplace(kj::mv(serverParam)); + serverRef.thisHook = this; + if (revocable) revoker.emplace(); + startResolveTask(serverRef); } LocalClient(kj::Own&& serverParam, - _::CapabilityServerSetBase& capServerSet, void* ptr) - : server(kj::mv(serverParam)), capServerSet(&capServerSet), ptr(ptr) { - server->thisHook = this; + _::CapabilityServerSetBase& capServerSet, void* ptr, + bool revocable = false) + : capServerSet(&capServerSet), ptr(ptr) { + auto& serverRef = *server.emplace(kj::mv(serverParam)); + serverRef.thisHook = this; + if (revocable) revoker.emplace(); + startResolveTask(serverRef); } ~LocalClient() noexcept(false) { - server->thisHook = nullptr; + KJ_IF_MAYBE(s, server) { + s->get()->thisHook = nullptr; + } + } + + void revoke(kj::Exception&& e) { + KJ_IF_MAYBE(s, server) { + KJ_ASSERT_NONNULL(revoker).cancel(e); + brokenException = kj::mv(e); + s->get()->thisHook = nullptr; + server = nullptr; + } } Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + KJ_IF_MAYBE(r, resolved) { + // We resolved to a shortened path. New calls MUST go directly to the replacement capability + // so that their ordering is consistent with callers who call getResolved() to get direct + // access to the new capability. In particular it's important that we don't place these calls + // in our streaming queue. + return r->get()->newCall(interfaceId, methodId, sizeHint, hints); + } + auto hook = kj::heap( - interfaceId, methodId, sizeHint, kj::addRef(*this)); + interfaceId, methodId, sizeHint, hints, kj::addRef(*this)); auto root = hook->message->getRoot(); return Request(root, kj::mv(hook)); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, CallHints hints) override { + KJ_IF_MAYBE(r, resolved) { + // We resolved to a shortened path. New calls MUST go directly to the replacement capability + // so that their ordering is consistent with callers who call getResolved() to get direct + // access to the new capability. In particular it's important that we don't place these calls + // in our streaming queue. + return r->get()->call(interfaceId, methodId, kj::mv(context), hints); + } + auto contextPtr = context.get(); // We don't want to actually dispatch the call synchronously, because we don't want the callee @@ -478,66 +609,314 @@ public: // Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't // complete before 'whenMoreResolved()' promises resolve. auto promise = kj::evalLater([this,interfaceId,methodId,contextPtr]() { - return server->dispatchCall(interfaceId, methodId, - CallContext(*contextPtr)); + if (blocked) { + return kj::newAdaptedPromise, BlockedCall>( + *this, interfaceId, methodId, *contextPtr); + } else { + return callInternal(interfaceId, methodId, *contextPtr); + } }).attach(kj::addRef(*this)); - // We have to fork this promise for the pipeline to receive a copy of the answer. - auto forked = promise.fork(); + if (hints.noPromisePipelining) { + // No need to set up pipelining.. + + // Make sure we release the params on return, since we would on the normal pipelining path. + // TODO(perf): Experiment with whether this is actually useful. It seems likely the params + // will be released soon anyway, so maybe this is a waste? + promise = promise.then([context=kj::mv(context)]() mutable { + context->releaseParams(); + }); + + // When we do apply pipelining, the use of `.fork()` has the side effect of eagerly + // evaluating the promise. To match the behavior here, we use `.eagerlyEvaluate()`. + // TODO(perf): Maybe we don't need to match behavior? It did break some tests but arguably + // those tests are weird and not what a real program would do... + promise = promise.eagerlyEvaluate(nullptr); + return VoidPromiseAndPipeline { kj::mv(promise), getDisabledPipeline() }; + } + + kj::Promise completionPromise = nullptr; + kj::Promise pipelineBranch = nullptr; + + if (hints.onlyPromisePipeline) { + pipelineBranch = kj::mv(promise); + completionPromise = kj::NEVER_DONE; + } else { + // We have to fork this promise for the pipeline to receive a copy of the answer. + auto forked = promise.fork(); + pipelineBranch = forked.addBranch(); + completionPromise = forked.addBranch().attach(context->addRef()); + } - auto pipelinePromise = forked.addBranch().then(kj::mvCapture(context->addRef(), - [=](kj::Own&& context) -> kj::Own { + auto pipelinePromise = pipelineBranch + .then([=,context=context->addRef()]() mutable -> kj::Own { context->releaseParams(); return kj::refcounted(kj::mv(context)); - })); + }); - auto tailPipelinePromise = context->onTailCall().then([](AnyPointer::Pipeline&& pipeline) { + auto tailPipelinePromise = context->onTailCall() + .then([context = context->addRef()](AnyPointer::Pipeline&& pipeline) { return kj::mv(pipeline.hook); }); pipelinePromise = pipelinePromise.exclusiveJoin(kj::mv(tailPipelinePromise)); - auto completionPromise = forked.addBranch().attach(kj::mv(context)); - return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::refcounted(kj::mv(pipelinePromise)) }; } kj::Maybe getResolved() override { - return nullptr; + return resolved.map([](kj::Own& hook) -> ClientHook& { return *hook; }); } kj::Maybe>> whenMoreResolved() override { - return nullptr; + KJ_IF_MAYBE(r, resolved) { + return kj::Promise>(r->get()->addRef()); + } else KJ_IF_MAYBE(t, resolveTask) { + return t->addBranch().then([this]() { + return KJ_ASSERT_NONNULL(resolved)->addRef(); + }); + } else { + return nullptr; + } } kj::Own addRef() override { return kj::addRef(*this); } + static const uint BRAND; + // Value is irrelevant; used for pointer. + const void* getBrand() override { - // We have no need to detect local objects. - return nullptr; + return &BRAND; } - void* getLocalServer(_::CapabilityServerSetBase& capServerSet) override { + kj::Maybe> getLocalServer(_::CapabilityServerSetBase& capServerSet) { + // If this is a local capability created through `capServerSet`, return the underlying Server. + // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should + // use) always returns nullptr. + if (this->capServerSet == &capServerSet) { - return ptr; + if (blocked) { + // If streaming calls are in-flight, it could be the case that they were originally sent + // over RPC and reflected back, before the capability had resolved to a local object. In + // that case, the client may already perceive these calls as "done" because the RPC + // implementation caused the client promise to resolve early. However, the capability is + // now local, and the app is trying to break through the LocalClient wrapper and access + // the server directly, bypassing the stream queue. Since the app thinks that all + // previous calls already completed, it may then try to queue a new call directly on the + // server, jumping the queue. + // + // We can solve this by delaying getLocalServer() until all current streaming calls have + // finished. Note that if a new streaming call is started *after* this point, we need not + // worry about that, because in this case it is presumably a local call and the caller + // won't be informed of completion until the call actually does complete. Thus the caller + // is well-aware that this call is still in-flight. + // + // However, the app still cannot assume that there aren't multiple clients, perhaps even + // a malicious client that tries to send stream requests that overlap with the app's + // direct use of the server... so it's up to the app to check for and guard against + // concurrent calls after using getLocalServer(). + return kj::newAdaptedPromise, BlockedCall>(*this) + .then([this]() { return ptr; }); + } else { + return kj::Promise(ptr); + } + } else { + return nullptr; + } + } + + kj::Maybe getFd() override { + KJ_IF_MAYBE(s, server) { + return s->get()->getFd(); } else { return nullptr; } } private: - kj::Own server; + kj::Maybe> server; _::CapabilityServerSetBase* capServerSet = nullptr; void* ptr = nullptr; + + kj::Maybe> resolveTask; + kj::Maybe> resolved; + + kj::Maybe revoker; + // If non-null, all promises must be wrapped in this revoker. + + void startResolveTask(Capability::Server& serverRef) { + resolveTask = serverRef.shortenPath().map([this](kj::Promise promise) { + KJ_IF_MAYBE(r, revoker) { + promise = r->wrap(kj::mv(promise)); + } + + return promise.then([this](Capability::Client&& cap) { + auto hook = ClientHook::from(kj::mv(cap)); + + if (blocked) { + // This is a streaming interface and we have some calls queued up as a result. We cannot + // resolve directly to the new shorter path because this may allow new calls to hop + // the queue -- we need to embargo new calls until the queue clears out. + auto promise = kj::newAdaptedPromise, BlockedCall>(*this) + .then([hook = kj::mv(hook)]() mutable { return kj::mv(hook); }); + hook = newLocalPromiseClient(kj::mv(promise)); + } + + resolved = kj::mv(hook); + }).fork(); + }); + } + + class BlockedCall { + public: + BlockedCall(kj::PromiseFulfiller>& fulfiller, LocalClient& client, + uint64_t interfaceId, uint16_t methodId, CallContextHook& context) + : fulfiller(fulfiller), client(client), + interfaceId(interfaceId), methodId(methodId), context(context), + prev(client.blockedCallsEnd) { + *prev = *this; + client.blockedCallsEnd = &next; + } + + BlockedCall(kj::PromiseFulfiller>& fulfiller, LocalClient& client) + : fulfiller(fulfiller), client(client), prev(client.blockedCallsEnd) { + *prev = *this; + client.blockedCallsEnd = &next; + } + + ~BlockedCall() noexcept(false) { + unlink(); + } + + void unblock() { + unlink(); + KJ_IF_MAYBE(c, context) { + fulfiller.fulfill(kj::evalNow([&]() { + return client.callInternal(interfaceId, methodId, *c); + })); + } else { + // This is just a barrier. + fulfiller.fulfill(kj::READY_NOW); + } + } + + private: + kj::PromiseFulfiller>& fulfiller; + LocalClient& client; + uint64_t interfaceId; + uint16_t methodId; + kj::Maybe context; + + kj::Maybe next; + kj::Maybe* prev; + + void unlink() { + if (prev != nullptr) { + *prev = next; + KJ_IF_MAYBE(n, next) { + n->prev = prev; + } else { + client.blockedCallsEnd = prev; + } + prev = nullptr; + } + } + }; + + class BlockingScope { + public: + BlockingScope(LocalClient& client): client(client) { client.blocked = true; } + BlockingScope(): client(nullptr) {} + BlockingScope(BlockingScope&& other): client(other.client) { other.client = nullptr; } + KJ_DISALLOW_COPY(BlockingScope); + + ~BlockingScope() noexcept(false) { + KJ_IF_MAYBE(c, client) { + c->unblock(); + } + } + + private: + kj::Maybe client; + }; + + bool blocked = false; + kj::Maybe brokenException; + kj::Maybe blockedCalls; + kj::Maybe* blockedCallsEnd = &blockedCalls; + + void unblock() { + blocked = false; + while (!blocked) { + KJ_IF_MAYBE(t, blockedCalls) { + t->unblock(); + } else { + break; + } + } + } + + kj::Promise callInternal(uint64_t interfaceId, uint16_t methodId, + CallContextHook& context) { + KJ_ASSERT(!blocked); + + KJ_IF_MAYBE(e, brokenException) { + // Previous streaming call threw, so everything fails from now on. + return kj::cp(*e); + } + + // `server` can't be null here since `brokenException` is null. + auto result = KJ_ASSERT_NONNULL(server)->dispatchCall(interfaceId, methodId, + CallContext(context)); + + KJ_IF_MAYBE(r, revoker) { + result.promise = r->wrap(kj::mv(result.promise)); + } + + if (!result.allowCancellation) { + // Make sure this call cannot be canceled by forking the promise and detaching one branch. + auto fork = result.promise.attach(kj::addRef(*this), context.addRef()).fork(); + result.promise = fork.addBranch(); + fork.addBranch().detach([](kj::Exception&&) { + // Exception from canceled call is silently discarded. The caller should have waited for + // it if they cared. + }); + } + + if (result.isStreaming) { + return result.promise + .catch_([this](kj::Exception&& e) { + brokenException = kj::cp(e); + kj::throwRecoverableException(kj::mv(e)); + }).attach(BlockingScope(*this)); + } else { + return kj::mv(result.promise); + } + } }; +const uint LocalClient::BRAND = 0; + kj::Own Capability::Client::makeLocalClient(kj::Own&& server) { return kj::refcounted(kj::mv(server)); } +kj::Own Capability::Client::makeRevocableLocalClient(Capability::Server& server) { + auto result = kj::refcounted( + kj::Own(&server, kj::NullDisposer::instance), true /* revocable */); + return result; +} +void Capability::Client::revokeLocalClient(ClientHook& hook) { + revokeLocalClient(hook, KJ_EXCEPTION(FAILED, + "capability was revoked (RevocableServer was destroyed)")); +} +void Capability::Client::revokeLocalClient(ClientHook& hook, kj::Exception&& e) { + kj::downcast(hook).revoke(kj::mv(e)); +} + kj::Own newLocalPromiseClient(kj::Promise>&& promise) { return kj::refcounted(kj::mv(promise)); } @@ -548,6 +927,36 @@ kj::Own newLocalPromisePipeline(kj::Promise> // ======================================================================================= +namespace _ { // private + +class PipelineBuilderHook final: public PipelineHook, public kj::Refcounted { +public: + PipelineBuilderHook(uint firstSegmentWords) + : message(firstSegmentWords), + root(message.getRoot()) {} + + kj::Own addRef() override { + return kj::addRef(*this); + } + + kj::Own getPipelinedCap(kj::ArrayPtr ops) override { + return root.asReader().getPipelinedCap(ops); + } + + MallocMessageBuilder message; + AnyPointer::Builder root; +}; + +PipelineBuilderPair newPipelineBuilder(uint firstSegmentWords) { + auto hook = kj::refcounted(firstSegmentWords); + auto root = hook->root; + return { root, kj::mv(hook) }; +} + +} // namespace _ (private) + +// ======================================================================================= + namespace { class BrokenPipeline final: public PipelineHook, public kj::Refcounted { @@ -574,6 +983,14 @@ public: AnyPointer::Pipeline(kj::refcounted(exception))); } + kj::Promise sendStreaming() override { + return kj::cp(exception); + } + + AnyPointer::Pipeline sendForPipeline() override { + return AnyPointer::Pipeline(kj::refcounted(exception)); + } + const void* getBrand() override { return nullptr; } @@ -584,19 +1001,20 @@ public: class BrokenClient final: public ClientHook, public kj::Refcounted { public: - BrokenClient(const kj::Exception& exception, bool resolved, const void* brand = nullptr) + BrokenClient(const kj::Exception& exception, bool resolved, const void* brand) : exception(exception), resolved(resolved), brand(brand) {} - BrokenClient(const kj::StringPtr description, bool resolved, const void* brand = nullptr) + BrokenClient(const kj::StringPtr description, bool resolved, const void* brand) : exception(kj::Exception::Type::FAILED, "", 0, kj::str(description)), resolved(resolved), brand(brand) {} Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { return newBrokenRequest(kj::cp(exception), sizeHint); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, CallHints hints) override { return VoidPromiseAndPipeline { kj::cp(exception), kj::refcounted(exception) }; } @@ -620,6 +1038,10 @@ public: return brand; } + kj::Maybe getFd() override { + return nullptr; + } + private: kj::Exception exception; bool resolved; @@ -627,7 +1049,7 @@ private: }; kj::Own BrokenPipeline::getPipelinedCap(kj::ArrayPtr ops) { - return kj::refcounted(exception, false); + return kj::refcounted(exception, false, &ClientHook::BROKEN_CAPABILITY_BRAND); } kj::Own newNullCap() { @@ -639,11 +1061,11 @@ kj::Own newNullCap() { } // namespace kj::Own newBrokenCap(kj::StringPtr reason) { - return kj::refcounted(reason, false); + return kj::refcounted(reason, false, &ClientHook::BROKEN_CAPABILITY_BRAND); } kj::Own newBrokenCap(kj::Exception&& reason) { - return kj::refcounted(kj::mv(reason), false); + return kj::refcounted(kj::mv(reason), false, &ClientHook::BROKEN_CAPABILITY_BRAND); } kj::Own newBrokenPipeline(kj::Exception&& reason) { @@ -657,6 +1079,27 @@ Request newBrokenRequest( return Request(root, kj::mv(hook)); } +kj::Own getDisabledPipeline() { + class DisabledPipelineHook final: public PipelineHook { + public: + kj::Own addRef() override { + return kj::Own(this, kj::NullDisposer::instance); + } + + kj::Own getPipelinedCap(kj::ArrayPtr ops) override { + return newBrokenCap(KJ_EXCEPTION(FAILED, + "caller specified noPromisePipelining hint, but then tried to pipeline")); + } + + kj::Own getPipelinedCap(kj::Array&& ops) override { + return newBrokenCap(KJ_EXCEPTION(FAILED, + "caller specified noPromisePipelining hint, but then tried to pipeline")); + } + }; + static DisabledPipelineHook instance; + return instance.addRef(); +} + // ======================================================================================= ReaderCapabilityTable::ReaderCapabilityTable( @@ -712,19 +1155,35 @@ kj::Promise CapabilityServerSetBase::getLocalServerInternal(Capability::C ClientHook* hook = client.hook.get(); // Get the most-resolved-so-far version of the hook. - KJ_IF_MAYBE(h, hook->getResolved()) { - hook = h; - }; + for (;;) { + KJ_IF_MAYBE(h, hook->getResolved()) { + hook = h; + } else { + break; + } + } + + // Try to unwrap that. + if (hook->getBrand() == &LocalClient::BRAND) { + KJ_IF_MAYBE(promise, kj::downcast(*hook).getLocalServer(*this)) { + // This is definitely a member of our set and will resolve to non-null. We just have to wait + // for any existing streaming calls to complete. + return kj::mv(*promise); + } + } + // OK, the capability isn't part of this set. KJ_IF_MAYBE(p, hook->whenMoreResolved()) { - // This hook is an unresolved promise. We need to wait for it. + // This hook is an unresolved promise. It might resolve eventually to a local server, so wait + // for it. return p->attach(hook->addRef()) .then([this](kj::Own&& resolved) { Capability::Client client(kj::mv(resolved)); return getLocalServerInternal(client); }); } else { - return hook->getLocalServer(*this); + // Cap is settled, so it definitely will never resolve to a member of this set. + return kj::implicitCast(nullptr); } } diff --git a/c++/src/capnp/capability.h b/c++/src/capnp/capability.h index 56a5e6f6de..1e71840ac5 100644 --- a/c++/src/capnp/capability.h +++ b/c++/src/capnp/capability.h @@ -19,12 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_CAPABILITY_H_ -#define CAPNP_CAPABILITY_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #if CAPNP_LITE #error "RPC APIs, including this header, are not available in lite mode." @@ -36,6 +31,8 @@ #include "any.h" #include "pointer-helpers.h" +CAPNP_BEGIN_HEADER + namespace capnp { template @@ -61,12 +58,21 @@ class RemotePromise: public kj::Promise>, public T::Pipeline { KJ_DISALLOW_COPY(RemotePromise); RemotePromise(RemotePromise&& other) = default; RemotePromise& operator=(RemotePromise&& other) = default; + + kj::Promise> dropPipeline() { + // Convenience method to convert this into a plain promise. + return kj::mv(*this); + } + + static RemotePromise reducePromise(kj::Promise&& promise); + // Hook for KJ so that Promise> automatically reduces to RemotePromise. }; class LocalClient; namespace _ { // private extern const RawSchema NULL_INTERFACE_SCHEMA; // defined in schema.c++ class CapabilityServerSetBase; +struct PipelineBuilderPair; } // namespace _ (private) struct Capability { @@ -95,6 +101,8 @@ class RequestHook; class ResponseHook; class PipelineHook; class ClientHook; +template +class RevocableServer; template class Request: public Params::Builder { @@ -113,6 +121,45 @@ class Request: public Params::Builder { RemotePromise send() KJ_WARN_UNUSED_RESULT; // Send the call and return a promise for the results. + typename Results::Pipeline sendForPipeline(); + // Send the call in pipeline-only mode. The returned object can be used to make pipelined calls, + // but there is no way to wait for the completion of the original call. This allows some + // bookkeeping to be skipped under the hood, saving some time. + // + // Generally, this method should only be used when the caller will immediately make one or more + // pipelined calls on the result, and then throw away the pipeline and all pipelined + // capabilities. Other uses may run into caveats, such as: + // - Normally, calling `whenResolved()` on a pipelined capability would wait for the original RPC + // to complete (and possibly other things, if that RPC itself returned a promise capability), + // but when using `sendPipelineOnly()`, `whenResolved()` may complete immediately, or never, or + // at an arbitrary time. Do not rely on it. + // - Normal path shortening may not work with these capabilities. For exmaple, if the caller + // forwards a pipelined capability back to the callee's vat, calls made by the callee to that + // capability may continue to proxy through the caller. Conversely, if the callee ends up + // returning a capability that points back to the caller's vat, calls on the pipelined + // capability may continue to proxy through the callee. + +private: + kj::Own hook; + + friend class Capability::Client; + friend struct DynamicCapability; + template + friend class CallContext; + friend class RequestHook; +}; + +template +class StreamingRequest: public Params::Builder { + // Like `Request` but for streaming requests. + +public: + inline StreamingRequest(typename Params::Builder builder, kj::Own&& hook) + : Params::Builder(builder), hook(kj::mv(hook)) {} + inline StreamingRequest(decltype(nullptr)): Params::Builder(nullptr) {} + + kj::Promise send() KJ_WARN_UNUSED_RESULT; + private: kj::Own hook; @@ -198,12 +245,41 @@ class Capability::Client { // where no calls are being made. There is no reason to wait for this before making calls; if // the capability does not resolve, the call results will propagate the error. + struct CallHints { + bool noPromisePipelining = false; + // Hints that the pipeline part of the VoidPromiseAndPipeline won't be used, so it can be + // a bogus object. + + bool onlyPromisePipeline = false; + // Hints that the promise part of the VoidPromiseAndPipeline won't be used, so it can be a + // bogus promise. + // + // This hint is primarily intended to be passed to `ClientHook::call()`. When using + // `ClientHook::newCall()`, you would instead indicate the hint by calling the `ResponseHook`'s + // `sendForPipeline()` method. The effect of setting `onlyPromisePipeline = true` when invoking + // `ClientHook::newCall()` is unspecified; it might cause the returned `Request` to support + // only pipelining even when `send()` is called, or it might not. + }; + Request typelessRequest( uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint); + kj::Maybe sizeHint, CallHints hints); // Make a request without knowing the types of the params or results. You specify the type ID // and method number manually. + kj::Promise> getFd(); + // If the capability's server implemented Capability::Server::getFd() returning non-null, and all + // RPC links between the client and server support FD passing, returns a file descriptor pointing + // to the same underlying file description as the server did. Returns null if the server provided + // no FD or if FD passing was unavailable at some intervening link. + // + // This returns a Promise to handle the case of an unresolved promise capability, e.g. a + // pipelined capability. The promise resolves no later than when the capability settles, i.e. + // the same time `whenResolved()` would complete. + // + // The file descriptor will remain open at least as long as the Capability::Client remains alive. + // If you need it to last longer, you will need to `dup()` it. + // TODO(someday): method(s) for Join protected: @@ -211,12 +287,18 @@ class Capability::Client { template Request newCall(uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint); + kj::Maybe sizeHint, CallHints hints); + template + StreamingRequest newStreamingCall(uint64_t interfaceId, uint16_t methodId, + kj::Maybe sizeHint, CallHints hints); private: kj::Own hook; static kj::Own makeLocalClient(kj::Own&& server); + static kj::Own makeRevocableLocalClient(Capability::Server& server); + static void revokeLocalClient(ClientHook& hook); + static void revokeLocalClient(ClientHook& hook, kj::Exception&& reason); template friend struct _::PointerHelpers; @@ -228,6 +310,8 @@ class Capability::Client { friend struct List; friend class _::CapabilityServerSetBase; friend class ClientHook; + template + friend class RevocableServer; }; // ======================================================================================= @@ -273,6 +357,62 @@ class CallContext: public kj::DisallowConstCopy { // should not be included in the size. So, if you are simply going to copy some existing message // directly into the results, just call `.totalSize()` and pass that in. + void setPipeline(typename Results::Pipeline&& pipeline); + void setPipeline(typename Results::Pipeline& pipeline); + // Tells the system where the capabilities in the response will eventually resolve to. This + // allows requests that are promise-pipelined on this call's results to continue their journey + // to the final destination before this call itself has completed. + // + // This is particularly useful when forwarding RPC calls to other remote servers, but where a + // tail call can't be used. For example, imagine Alice calls `foo()` on Bob. In `foo()`'s + // implementation, Bob calls `bar()` on Charlie. `bar()` returns a capability to Bob, and then + // `foo()` returns the same capability on to Alice. Now imagine Alice is actually using promise + // pipelining in a chain like `foo().getCap().baz()`. The `baz()` call will travel to Bob as a + // pipelined call without waiting for `foo()` to return first. But once it gets to Bob, the + // message has to patiently wait until `foo()` has completed there, before it can then be + // forwarded on to Charlie. It would be better if immediately upon Bob calling `bar()` on + // Charlie, then Alice's call to `baz()` could be forwarded to Charlie as a pipelined call, + // without waiting for `bar()` to return. This would avoid a network round trip of latency + // between Bob and Charlie. + // + // To solve this problem, Bob takes the pipeline object from the `bar()` call, transforms it into + // an appropriate pipeline for a `foo()` call, and passes that to `setPipeline()`. This allows + // Alice's pipelined `baz()` call to flow through immediately. The code looks like: + // + // kj::Promise foo(FooContext context) { + // auto barPromise = charlie.barRequest().send(); + // + // // Set up the final pipeline using pipelined capabilities from `barPromise`. + // capnp::PipelineBuilder pipeline; + // pipeline.setResultCap(barPromise.getSomeCap()); + // context.setPipeline(pipeline.build()); + // + // // Now actually wait for the results and process them. + // return barPromise + // .then([context](capnp::Response response) mutable { + // auto results = context.initResults(); + // + // // Make sure to set up the capabilities exactly as we did in the pipeline. + // results.setResultCap(response.getSomeCap()); + // + // // ... do other stuff with the real response ... + // }); + // } + // + // Of course, if `foo()` and `bar()` return exactly the same type, and Bob doesn't intend + // to do anything with `bar()`'s response except pass it through, then `tailCall()` is a better + // choice here. `setPipeline()` is useful when some transformation is needed on the response, + // or the middleman needs to inspect the response for some reason. + // + // Note: This method has an overload that takes an lvalue reference for convenience. This + // overload increments the refcount on the underlying PipelineHook -- it does not keep the + // reference. + // + // Note: Capabilities returned by the replacement pipeline MUST either be exactly the same + // capabilities as in the final response, or eventually resolve to exactly the same + // capabilities, where "exactly the same" means the underlying `ClientHook` object is exactly + // the same object by identity. Resolving to some "equivalent" capability is not good enough. + template kj::Promise tailCall(Request&& tailRequest); // Resolve the call by making a tail call. `tailRequest` is a request that has been filled in @@ -286,33 +426,51 @@ class CallContext: public kj::DisallowConstCopy { // In general, this should be the last thing a method implementation calls, and the promise // returned from `tailCall()` should then be returned by the method implementation. - void allowCancellation(); - // Indicate that it is OK for the RPC system to discard its Promise for this call's result if - // the caller cancels the call, thereby transitively canceling any asynchronous operations the - // call implementation was performing. This is not done by default because it could represent a - // security risk: applications must be carefully written to ensure that they do not end up in - // a bad state if an operation is canceled at an arbitrary point. However, for long-running - // method calls that hold significant resources, prompt cancellation is often useful. - // - // Keep in mind that asynchronous cancellation cannot occur while the method is synchronously - // executing on a local thread. The method must perform an asynchronous operation or call - // `EventLoop::current().evalLater()` to yield control. - // - // Note: You might think that we should offer `onCancel()` and/or `isCanceled()` methods that - // provide notification when the caller cancels the request without forcefully killing off the - // promise chain. Unfortunately, this composes poorly with promise forking: the canceled - // path may be just one branch of a fork of the result promise. The other branches still want - // the call to continue. Promise forking is used within the Cap'n Proto implementation -- in - // particular each pipelined call forks the result promise. So, if a caller made a pipelined - // call and then dropped the original object, the call should not be canceled, but it would be - // excessively complicated for the framework to avoid notififying of cancellation as long as - // pipelined calls still exist. + void allowCancellation() + KJ_UNAVAILABLE( + "As of Cap'n Proto 1.0, allowCancellation must be applied statically using an " + "annotation in the schema. See annotations defined in /capnp/c++.capnp. For " + "DynamicCapability::Server, use the constructor option (the annotation does not apply " + "to DynamicCapability). This change was made to gain a significant performance boost -- " + "dynamically allowing cancellation required excessive bookkeeping."); private: CallContextHook* hook; friend class Capability::Server; friend struct DynamicCapability; + friend class CallContextHook; +}; + +template +class StreamingCallContext: public kj::DisallowConstCopy { + // Like CallContext but for streaming calls. + +public: + explicit StreamingCallContext(CallContextHook& hook); + + typename Params::Reader getParams(); + void releaseParams(); + + // Note: tailCall() is not supported because: + // - It would significantly complicate the implementation of streaming. + // - It wouldn't be particularly useful since streaming calls don't return anything, and they + // already compensate for latency. + + void allowCancellation() + KJ_UNAVAILABLE( + "As of Cap'n Proto 1.0, allowCancellation must be applied statically using an " + "annotation in the schema. See annotations defined in /capnp/c++.capnp. For " + "DynamicCapability::Server, use the constructor option (the annotation does not apply " + "to DynamicCapability). This change was made to gain a significant performance boost -- " + "dynamically allowing cancellation required excessive bookkeeping."); + +private: + CallContextHook* hook; + + friend class Capability::Server; + friend struct DynamicCapability; + friend class CallContextHook; }; class Capability::Server { @@ -323,11 +481,50 @@ class Capability::Server { public: typedef Capability Serves; - virtual kj::Promise dispatchCall(uint64_t interfaceId, uint16_t methodId, - CallContext context) = 0; + struct DispatchCallResult { + kj::Promise promise; + // Promise for completion of the call. + + bool isStreaming; + // If true, this method was declared as `-> stream;`. No other calls should be permitted until + // this call finishes, and if this call throws an exception, all future calls will throw the + // same exception. + + bool allowCancellation = false; + // If true, the call can be canceled normally. If false, the immediate caller is responsible + // for ensuring that cancellation is prevented and that `context` remains valid until the + // call completes normally. + // + // See the `allowCancellation` annotation defined in `c++.capnp`. + }; + + virtual DispatchCallResult dispatchCall(uint64_t interfaceId, uint16_t methodId, + CallContext context) = 0; // Call the given method. `params` is the input struct, and should be released as soon as it - // is no longer needed. `context` may be used to allocate the output struct and deal with - // cancellation. + // is no longer needed. `context` may be used to allocate the output struct and other call + // logistics. + + virtual kj::Maybe getFd() { return nullptr; } + // If this capability is backed by a file descriptor that is safe to directly expose to clients, + // returns that FD. When FD passing has been enabled in the RPC layer, this FD may be sent to + // other processes along with the capability. + + virtual kj::Maybe> shortenPath(); + // If this returns non-null, then it is a promise which, when resolved, points to a new + // capability to which future calls can be sent. Use this in cases where an object implementation + // might discover a more-optimized path some time after it starts. + // + // Implementing this (and returning non-null) will cause the capability to be advertised as a + // promise at the RPC protocol level. Once the promise returned by shortenPath() resolves, the + // remote client will receive a `Resolve` message updating it to point at the new destination. + // + // `shortenPath()` can also be used as a hack to shut up the client. If shortenPath() returns + // a promise that resolves to an exception, then the client will be notified that the capability + // is now broken. Assuming the client is using a correct RPC implemnetation, this should cause + // all further calls initiated by the client to this capability to immediately fail client-side, + // sparing the server's bandwidth. + // + // The default implementation always returns nullptr. // TODO(someday): Method which can optionally be overridden to implement Join when the object is // a proxy. @@ -341,16 +538,20 @@ class Capability::Server { // the server's constructor.) // - The capability client pointing at this object has been destroyed. (This is always the case // in the server's destructor.) + // - The capability client pointing at this object has been revoked using RevocableServer. // - Multiple capability clients have been created around the same server (possible if the server // is refcounted, which is not recommended since the client itself provides refcounting). template CallContext internalGetTypedContext( CallContext typeless); - kj::Promise internalUnimplemented(const char* actualInterfaceName, - uint64_t requestedTypeId); - kj::Promise internalUnimplemented(const char* interfaceName, - uint64_t typeId, uint16_t methodId); + template + StreamingCallContext internalGetTypedStreamingContext( + CallContext typeless); + DispatchCallResult internalUnimplemented(const char* actualInterfaceName, + uint64_t requestedTypeId); + DispatchCallResult internalUnimplemented(const char* interfaceName, + uint64_t typeId, uint16_t methodId); kj::Promise internalUnimplemented(const char* interfaceName, const char* methodName, uint64_t typeId, uint16_t methodId); @@ -359,6 +560,67 @@ class Capability::Server { friend class LocalClient; }; +template +class RevocableServer { + // Allows you to create a capability client pointing to a capability server without taking + // ownership of the server. When `RevocableServer` is destroyed, all clients created through it + // will become broken. All outstanding RPCs via those clients will be canceled and all future + // RPCs will immediately throw. Hence, once the `RevocableServer` is destroyed, it is safe + // to destroy the server object it referenced. + // + // This is particularly useful when you want to create a capability server that points to an + // object that you do not own, and thus cannot keep alive beyond some defined lifetime. Since + // you cannot force the client to respect lifetime rules, you should use a RevocableServer to + // revoke access before the lifetime ends. + // + // The RevocableServer object can be moved (as long as the server outlives it). + +public: + RevocableServer(typename T::Server& server); + RevocableServer(RevocableServer&&) = default; + RevocableServer& operator=(RevocableServer&&) = default; + ~RevocableServer() noexcept(false); + KJ_DISALLOW_COPY(RevocableServer); + + typename T::Client getClient(); + + void revoke(); + void revoke(kj::Exception&& reason); + // Revokes the capability immediately, rather than waiting for the destructor. This can also + // be used to specify a custom exception to use when revoking. + +private: + kj::Own hook; +}; + +// ======================================================================================= + +template +class PipelineBuilder: public T::Builder { + // Convenience class to build a Pipeline object for use with CallContext::setPipeline(). + // + // Building a pipeline object is like building an RPC result message, except that you only need + // to fill in the capabilities, since the purpose is only to allow pipelined RPC requests to + // flow through. + // + // See the docs for `CallContext::setPipeline()` for an example. + +public: + PipelineBuilder(uint firstSegmentWords = 64); + // Construct a builder, allocating the given number of words for the first segment of the backing + // message. Since `PipelineBuilder` is typically used with small RPC messages, the default size + // here is considerably smaller than with MallocMessageBuilder. + + typename T::Pipeline build(); + // Constructs a `Pipeline` object backed by the current content of this builder. Calling this + // consumes the `PipelineBuilder`; no further methods can be invoked. + +private: + kj::Own hook; + + PipelineBuilder(_::PipelineBuilderPair pair); +}; + // ======================================================================================= class ReaderCapabilityTable: private _::CapTableReader { @@ -377,7 +639,7 @@ class ReaderCapabilityTable: private _::CapTableReader { public: explicit ReaderCapabilityTable(kj::Array>> table); - KJ_DISALLOW_COPY(ReaderCapabilityTable); + KJ_DISALLOW_COPY_AND_MOVE(ReaderCapabilityTable); template T imbue(T reader); @@ -398,7 +660,7 @@ class BuilderCapabilityTable: private _::CapTableBuilder { public: BuilderCapabilityTable(); - KJ_DISALLOW_COPY(BuilderCapabilityTable); + KJ_DISALLOW_COPY_AND_MOVE(BuilderCapabilityTable); inline kj::ArrayPtr>> getTable() { return table; } @@ -443,7 +705,7 @@ class CapabilityServerSet: private _::CapabilityServerSetBase { public: CapabilityServerSet() = default; - KJ_DISALLOW_COPY(CapabilityServerSet); + KJ_DISALLOW_COPY_AND_MOVE(CapabilityServerSet); typename T::Client add(kj::Own&& server); // Create a new capability Client for the given Server and also add this server to the set. @@ -467,6 +729,12 @@ class RequestHook { virtual RemotePromise send() = 0; // Send the call and return a promise for the result. + virtual kj::Promise sendStreaming() = 0; + // Send a streaming call. + + virtual AnyPointer::Pipeline sendForPipeline() = 0; + // Send a call for pipelining purposes only. + virtual const void* getBrand() = 0; // Returns a void* that identifies who made this request. This can be used by an RPC adapter to // discover when tail call is going to be sent over its own connection and therefore can be @@ -500,8 +768,11 @@ class ClientHook { public: ClientHook(); + using CallHints = Capability::Client::CallHints; + virtual Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) = 0; + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) = 0; // Start a new call, allowing the client to allocate request/response objects as it sees fit. // This version is used when calls are made from application code in the local process. @@ -511,17 +782,13 @@ class ClientHook { }; virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) = 0; + kj::Own&& context, CallHints hints) = 0; // Call the object, but the caller controls allocation of the request/response objects. If the // callee insists on allocating these objects itself, it must make a copy. This version is used // when calls come in over the network via an RPC system. Note that even if the returned // `Promise` is discarded, the call may continue executing if any pipelined calls are // waiting for it. // - // Since the caller of this method chooses the CallContext implementation, it is the caller's - // responsibility to ensure that the returned promise is not canceled unless allowed via - // the context's `allowCancellation()`. - // // The call must not begin synchronously; the callee must arrange for the call to begin in a // later turn of the event loop. Otherwise, application code may call back and affect the // callee's state in an unexpected way. @@ -531,12 +798,22 @@ class ClientHook { // of the capability. The caller may permanently replace this client with the resolved one if // desired. Returns null if the client isn't a promise or hasn't resolved yet -- use // `whenMoreResolved()` to distinguish between them. + // + // Once a particular ClientHook's `getResolved()` returns non-null, it must permanently return + // exactly the same resolution. This is why `getResolved()` returns a reference -- it is assumed + // this object must have a strong reference to the resolution which it intends to keep + // permanently, therefore the returned reference will live at least as long as this `ClientHook`. + // This "only one resolution" policy is necessary for the RPC system to implement embargoes + // properly. virtual kj::Maybe>> whenMoreResolved() = 0; // If this client is a settled reference (not a promise), return nullptr. Otherwise, return a // promise that eventually resolves to a new client that is closer to being the final, settled // client (i.e. the value eventually returned by `getResolved()`). Calling this repeatedly // should eventually produce a settled client. + // + // Once the promise resolves, `getResolved()` must return exactly the same `ClientHook` as the + // one this Promise resolved to. kj::Promise whenResolved(); // Repeatedly calls whenMoreResolved() until it returns nullptr. @@ -550,20 +827,29 @@ class ClientHook { // therefore it can transfer the capability without proxying. static const uint NULL_CAPABILITY_BRAND; - // Value is irrelevant; used for pointer. + static const uint BROKEN_CAPABILITY_BRAND; + // Values are irrelevant; used for pointers. inline bool isNull() { return getBrand() == &NULL_CAPABILITY_BRAND; } // Returns true if the capability was created as a result of assigning a Client to null or by // reading a null pointer out of a Cap'n Proto message. - virtual void* getLocalServer(_::CapabilityServerSetBase& capServerSet); - // If this is a local capability created through `capServerSet`, return the underlying Server. - // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should - // use) always returns nullptr. + inline bool isError() { return getBrand() == &BROKEN_CAPABILITY_BRAND; } + // Returns true if the capability was created by newBrokenCap(). + + virtual kj::Maybe getFd() = 0; + // Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns + // non-null, then Capability::Client::getFd() waits for resolution and tries again. static kj::Own from(Capability::Client client) { return kj::mv(client.hook); } }; +class RevocableClientHook: public ClientHook { +public: + virtual void revoke() = 0; + virtual void revoke(kj::Exception&& reason) = 0; +}; + class CallContextHook { // Hook interface implemented by RPC system to manage a call on the server side. See // CallContext. @@ -573,7 +859,8 @@ class CallContextHook { virtual void releaseParams() = 0; virtual AnyPointer::Builder getResults(kj::Maybe sizeHint) = 0; virtual kj::Promise tailCall(kj::Own&& request) = 0; - virtual void allowCancellation() = 0; + + virtual void setPipeline(kj::Own&& pipeline) = 0; virtual kj::Promise onTailCall() = 0; // If `tailCall()` is called, resolves to the PipelineHook from the tail call. An @@ -585,6 +872,11 @@ class CallContextHook { // promise fulfiller for onTailCall() with the returned pipeline. virtual kj::Own addRef() = 0; + + template + static CallContextHook& from(CallContext& context) { return *context.hook; } + template + static CallContextHook& from(StreamingCallContext& context) { return *context.hook; } }; kj::Own newLocalPromiseClient(kj::Promise>&& promise); @@ -607,6 +899,11 @@ Request newBrokenRequest( kj::Exception&& reason, kj::Maybe sizeHint); // Helper function that creates a Request object that simply throws exceptions when sent. +kj::Own getDisabledPipeline(); +// Gets a PipelineHook appropriate to use when CallHints::noPromisePipelining is true. This will +// throw from all calls. This does not actually allocate the object; a static global object is +// returned with a null disposer. + // ======================================================================================= // Extend PointerHelpers for interfaces @@ -661,6 +958,10 @@ struct List { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -734,6 +1035,23 @@ struct List { // ======================================================================================= // Inline implementation details +template +RemotePromise RemotePromise::reducePromise(kj::Promise&& promise) { + kj::Tuple>, kj::Promise>> splitPromise = + promise.then([](RemotePromise&& inner) { + // `inner` is multiply-inherited, and we want to move away each superclass separately. + // Let's create two references to make clear what we're doing (though this is not strictly + // necessary). + kj::Promise>& innerPromise = inner; + typename T::Pipeline& innerPipeline = inner; + return kj::tuple(kj::mv(innerPromise), PipelineHook::from(kj::mv(innerPipeline))); + }).split(); + + return RemotePromise(kj::mv(kj::get<0>(splitPromise)), + typename T::Pipeline(AnyPointer::Pipeline( + newLocalPromisePipeline(kj::mv(kj::get<1>(splitPromise)))))); +} + template RemotePromise Request::send() { auto typelessPromise = hook->send(); @@ -754,6 +1072,20 @@ RemotePromise Request::send() { return RemotePromise(kj::mv(typedPromise), kj::mv(typedPipeline)); } +template +typename Results::Pipeline Request::sendForPipeline() { + auto typelessPipeline = hook->sendForPipeline(); + hook = nullptr; // prevent reuse + return typename Results::Pipeline(kj::mv(typelessPipeline)); +} + +template +kj::Promise StreamingRequest::send() { + auto promise = hook->sendStreaming(); + hook = nullptr; // prevent reuse + return promise; +} + inline Capability::Client::Client(kj::Own&& hook): hook(kj::mv(hook)) {} template inline Capability::Client::Client(kj::Own&& server) @@ -770,31 +1102,44 @@ template inline typename T::Client Capability::Client::castAs() { return typename T::Client(hook->addRef()); } -inline kj::Promise Capability::Client::whenResolved() { - return hook->whenResolved(); -} inline Request Capability::Client::typelessRequest( uint64_t interfaceId, uint16_t methodId, - kj::Maybe sizeHint) { - return newCall(interfaceId, methodId, sizeHint); + kj::Maybe sizeHint, CallHints hints) { + return newCall(interfaceId, methodId, sizeHint, hints); } template inline Request Capability::Client::newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) { - auto typeless = hook->newCall(interfaceId, methodId, sizeHint); + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, CallHints hints) { + auto typeless = hook->newCall(interfaceId, methodId, sizeHint, hints); return Request(typeless.template getAs(), kj::mv(typeless.hook)); } +template +inline StreamingRequest Capability::Client::newStreamingCall( + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, CallHints hints) { + auto typeless = hook->newCall(interfaceId, methodId, sizeHint, hints); + return StreamingRequest(typeless.template getAs(), kj::mv(typeless.hook)); +} template inline CallContext::CallContext(CallContextHook& hook): hook(&hook) {} +template +inline StreamingCallContext::StreamingCallContext(CallContextHook& hook): hook(&hook) {} template inline typename Params::Reader CallContext::getParams() { return hook->getParams().template getAs(); } +template +inline typename Params::Reader StreamingCallContext::getParams() { + return hook->getParams().template getAs(); +} template inline void CallContext::releaseParams() { hook->releaseParams(); } +template +inline void StreamingCallContext::releaseParams() { + hook->releaseParams(); +} template inline typename Results::Builder CallContext::getResults( kj::Maybe sizeHint) { @@ -821,15 +1166,19 @@ inline Orphanage CallContext::getResultsOrphanage( return Orphanage::getForMessageContaining(hook->getResults(sizeHint)); } template +void CallContext::setPipeline(typename Results::Pipeline&& pipeline) { + hook->setPipeline(PipelineHook::from(kj::mv(pipeline))); +} +template +void CallContext::setPipeline(typename Results::Pipeline& pipeline) { + hook->setPipeline(PipelineHook::from(pipeline).addRef()); +} +template template inline kj::Promise CallContext::tailCall( Request&& tailRequest) { return hook->tailCall(kj::mv(tailRequest.hook)); } -template -inline void CallContext::allowCancellation() { - hook->allowCancellation(); -} template CallContext Capability::Server::internalGetTypedContext( @@ -837,10 +1186,70 @@ CallContext Capability::Server::internalGetTypedContext( return CallContext(*typeless.hook); } +template +StreamingCallContext Capability::Server::internalGetTypedStreamingContext( + CallContext typeless) { + return StreamingCallContext(*typeless.hook); +} + Capability::Client Capability::Server::thisCap() { return Client(thisHook->addRef()); } +template +RevocableServer::RevocableServer(typename T::Server& server) + : hook(Capability::Client::makeRevocableLocalClient(server)) {} +template +RevocableServer::~RevocableServer() noexcept(false) { + // Check if moved away. + if (hook.get() != nullptr) { + Capability::Client::revokeLocalClient(*hook); + } +} + +template +typename T::Client RevocableServer::getClient() { + return typename T::Client(hook->addRef()); +} + +template +void RevocableServer::revoke() { + Capability::Client::revokeLocalClient(*hook); +} +template +void RevocableServer::revoke(kj::Exception&& exception) { + Capability::Client::revokeLocalClient(*hook, kj::mv(exception)); +} + +namespace _ { // private + +struct PipelineBuilderPair { + AnyPointer::Builder root; + kj::Own hook; +}; + +PipelineBuilderPair newPipelineBuilder(uint firstSegmentWords); + +} // namespace _ (private) + +template +PipelineBuilder::PipelineBuilder(uint firstSegmentWords) + : PipelineBuilder(_::newPipelineBuilder(firstSegmentWords)) {} + +template +PipelineBuilder::PipelineBuilder(_::PipelineBuilderPair pair) + : T::Builder(pair.root.initAs()), + hook(kj::mv(pair.hook)) {} + +template +typename T::Pipeline PipelineBuilder::build() { + // Prevent subsequent accidental modification. A good compiler should be able to optimize this + // assignment away assuming the PipelineBuilder is not accessed again after this point. + static_cast(*this) = nullptr; + + return typename T::Pipeline(AnyPointer::Pipeline(kj::mv(hook))); +} + template T ReaderCapabilityTable::imbue(T reader) { return T(_::PointerHelpers>::getInternalReader(reader).imbue(this)); @@ -879,6 +1288,8 @@ struct Orphanage::GetInnerReader { } }; +#define CAPNP_CAPABILITY_H_INCLUDED // for testing includes in unit test + } // namespace capnp -#endif // CAPNP_CAPABILITY_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/capnpc.ekam-rule b/c++/src/capnp/capnpc.ekam-rule index 30bf26dd01..dcecb48077 100755 --- a/c++/src/capnp/capnpc.ekam-rule +++ b/c++/src/capnp/capnpc.ekam-rule @@ -33,6 +33,7 @@ INPUT=$1 case "$INPUT" in *capnp/c++.capnp | \ *capnp/schema.capnp | \ + *capnp/stream.capnp | \ *capnp/rpc.capnp | \ *capnp/rpc-twoparty.capnp | \ *capnp/persistent.capnp | \ diff --git a/c++/src/capnp/cc_capnp_library.bzl b/c++/src/capnp/cc_capnp_library.bzl new file mode 100644 index 0000000000..9e4acd35b9 --- /dev/null +++ b/c++/src/capnp/cc_capnp_library.bzl @@ -0,0 +1,128 @@ +"""Bazel rule to compile .capnp files into c++.""" + +capnp_provider = provider("Capnproto Provider", fields = { + "includes": "includes for this target (transitive)", + "inputs": "src + data for the target", + "src_prefix": "src_prefix of the target", +}) + +def _workspace_path(label, path): + if label.workspace_root == "": + return path + return label.workspace_root + "/" + path + +def _capnp_gen_impl(ctx): + label = ctx.label + src_prefix = _workspace_path(label, ctx.attr.src_prefix) if ctx.attr.src_prefix != "" else "" + includes = [] + + inputs = ctx.files.srcs + ctx.files.data + for dep_target in ctx.attr.deps: + includes += dep_target[capnp_provider].includes + inputs += dep_target[capnp_provider].inputs + + if src_prefix != "": + includes.append(src_prefix) + + system_include = ctx.files._capnp_system[0].dirname.removesuffix("/capnp") + + gen_dir = ctx.var["GENDIR"] + out_dir = gen_dir + if src_prefix != "": + out_dir = out_dir + "/" + src_prefix + + cc_out = "-o%s:%s" % (ctx.executable._capnpc_cxx.path, out_dir) + args = ctx.actions.args() + args.add_all(["compile", "--verbose", cc_out]) + args.add_all(["-I" + inc for inc in includes]) + args.add_all(["-I", system_include]) + + if src_prefix == "": + # guess src_prefix for generated files + for src in ctx.files.srcs: + if src.path.startswith(gen_dir): + src_prefix = gen_dir + break + + if src_prefix != "": + args.add_all(["--src-prefix", src_prefix]) + + args.add_all([s for s in ctx.files.srcs]) + + ctx.actions.run( + inputs = inputs + ctx.files._capnpc_cxx + ctx.files._capnpc_capnp + ctx.files._capnp_system, + outputs = ctx.outputs.outs, + executable = ctx.executable._capnpc, + arguments = [args], + mnemonic = "GenCapnp", + ) + + return [ + capnp_provider( + includes = includes, + inputs = inputs, + src_prefix = src_prefix, + ), + ] + +_capnp_gen = rule( + attrs = { + "srcs": attr.label_list(allow_files = True), + "deps": attr.label_list(providers = [capnp_provider]), + "data": attr.label_list(allow_files = True), + "outs": attr.output_list(), + "src_prefix": attr.string(), + "_capnpc": attr.label(executable = True, allow_single_file = True, cfg = "exec", default = "@capnp-cpp//src/capnp:capnp_tool"), + "_capnpc_cxx": attr.label(executable = True, allow_single_file = True, cfg = "exec", default = "@capnp-cpp//src/capnp:capnpc-c++"), + "_capnpc_capnp": attr.label(executable = True, allow_single_file = True, cfg = "exec", default = "@capnp-cpp//src/capnp:capnpc-capnp"), + "_capnp_system": attr.label(default = "@capnp-cpp//src/capnp:capnp_system_library"), + }, + output_to_genfiles = True, + implementation = _capnp_gen_impl, +) + +def cc_capnp_library( + name, + srcs = [], + data = [], + deps = [], + src_prefix = "", + visibility = None, + target_compatible_with = None, + **kwargs): + """Bazel rule to create a C++ capnproto library from capnp source files + + Args: + name: library name + srcs: list of files to compile + data: additional files to provide to the compiler - data files and includes that need not to + be compiled + deps: other cc_capnp_library rules to depend on + src_prefix: src_prefix for capnp compiler to the source root + visibility: rule visibility + target_compatible_with: target compatibility + **kwargs: rest of the arguments to cc_library rule + """ + + hdrs = [s + ".h" for s in srcs] + srcs_cpp = [s + ".c++" for s in srcs] + + _capnp_gen( + name = name + "_gen", + srcs = srcs, + deps = [s + "_gen" for s in deps], + data = data, + outs = hdrs + srcs_cpp, + src_prefix = src_prefix, + visibility = visibility, + target_compatible_with = target_compatible_with, + ) + native.cc_library( + name = name, + srcs = srcs_cpp, + hdrs = hdrs, + deps = deps + ["@capnp-cpp//src/capnp:capnp_runtime"], + visibility = visibility, + target_compatible_with = target_compatible_with, + **kwargs + ) diff --git a/c++/src/capnp/common.h b/c++/src/capnp/common.h index 3fc7a42112..573c9445a1 100644 --- a/c++/src/capnp/common.h +++ b/c++/src/capnp/common.h @@ -23,26 +23,32 @@ // time, but should then be optimized down to basic primitives (usually, integers) by the // compiler. -#ifndef CAPNP_COMMON_H_ -#define CAPNP_COMMON_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include #include +#include // work-around macro conflict with `VOID` #if CAPNP_DEBUG_TYPES #include #endif +#if !defined(CAPNP_HEADER_WARNINGS) || !CAPNP_HEADER_WARNINGS +#define CAPNP_BEGIN_HEADER KJ_BEGIN_SYSTEM_HEADER +#define CAPNP_END_HEADER KJ_END_SYSTEM_HEADER +#else +#define CAPNP_BEGIN_HEADER +#define CAPNP_END_HEADER +#endif + +CAPNP_BEGIN_HEADER + namespace capnp { -#define CAPNP_VERSION_MAJOR 0 -#define CAPNP_VERSION_MINOR 6 -#define CAPNP_VERSION_MICRO 1 +#define CAPNP_VERSION_MAJOR 1 +#define CAPNP_VERSION_MINOR 1 +#define CAPNP_VERSION_MICRO 0 #define CAPNP_VERSION \ (CAPNP_VERSION_MAJOR * 1000000 + CAPNP_VERSION_MINOR * 1000 + CAPNP_VERSION_MICRO) @@ -51,6 +57,12 @@ namespace capnp { #define CAPNP_LITE 0 #endif +#if CAPNP_TESTING_CAPNP // defined in Cap'n Proto's own unit tests; others should not define this +#define CAPNP_DEPRECATED(reason) +#else +#define CAPNP_DEPRECATED KJ_DEPRECATED +#endif + typedef unsigned int uint; struct Void { @@ -161,7 +173,7 @@ inline constexpr Kind kind() { return k; } -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #define CAPNP_KIND(T) ::capnp::_::Kind_::kind // Avoid constexpr methods in MSVC (it remains buggy in many situations). @@ -187,7 +199,7 @@ inline constexpr Style style() { template struct List; -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) template struct List {}; @@ -302,7 +314,7 @@ namespace _ { // private template struct PointerHelpers; -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) template struct PointerHelpers {}; @@ -316,9 +328,13 @@ struct PointerHelpers {}; } // namespace _ (private) struct MessageSize { - // Size of a message. Every struct type has a method `.totalSize()` that returns this. + // Size of a message. Every struct and list type has a method `.totalSize()` that returns this. uint64_t wordCount; uint capCount; + + inline constexpr MessageSize operator+(const MessageSize& other) const { + return { wordCount + other.wordCount, capCount + other.capCount }; + } }; // ======================================================================================= @@ -326,16 +342,28 @@ struct MessageSize { using kj::byte; -class word { uint64_t content KJ_UNUSED_MEMBER; KJ_DISALLOW_COPY(word); public: word() = default; }; -// word is an opaque type with size of 64 bits. This type is useful only to make pointer -// arithmetic clearer. Since the contents are private, the only way to access them is to first -// reinterpret_cast to some other pointer type. -// -// Copying is disallowed because you should always use memcpy(). Otherwise, you may run afoul of -// aliasing rules. -// -// A pointer of type word* should always be word-aligned even if won't actually be dereferenced as -// that type. +class word { + // word is an opaque type with size of 64 bits. This type is useful only to make pointer + // arithmetic clearer. Since the contents are private, the only way to access them is to first + // reinterpret_cast to some other pointer type. + // + // Copying is disallowed because you should always use memcpy(). Otherwise, you may run afoul of + // aliasing rules. + // + // A pointer of type word* should always be word-aligned even if won't actually be dereferenced + // as that type. +public: + word() = default; +private: + uint64_t content KJ_UNUSED_MEMBER; +#if __GNUC__ < 8 || __clang__ + // GCC 8's -Wclass-memaccess complains whenever we try to memcpy() a `word` if we've disallowed + // the copy constructor. We don't want to disable the warning because it's a useful warning and + // we'd have to disable it for all applications that include this header. Instead we allow `word` + // to be copyable on GCC. + KJ_DISALLOW_COPY_AND_MOVE(word); +#endif +}; static_assert(sizeof(byte) == 1, "uint8_t is not one byte?"); static_assert(sizeof(word) == 8, "uint64_t is not 8 bytes?"); @@ -720,4 +748,4 @@ inline constexpr kj::ArrayPtr arrayPtr(U* ptr, T size) { } // namespace capnp -#endif // CAPNP_COMMON_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compat/BUILD.bazel b/c++/src/capnp/compat/BUILD.bazel new file mode 100644 index 0000000000..fcaecfa3fa --- /dev/null +++ b/c++/src/capnp/compat/BUILD.bazel @@ -0,0 +1,98 @@ +load("@capnp-cpp//src/capnp:cc_capnp_library.bzl", "cc_capnp_library") + +exports_files([ + "json.capnp", +]) + +# because git contains generated artifacts (which are used to bootstrap the compiler) +# we can't have cc_capnp_library for json.capnp. Expose it as cc library and a file. +cc_library( + name = "json", + srcs = [ + "json.c++", + "json.capnp.c++", + ], + hdrs = [ + "json.capnp.h", + "json.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/capnp", + ], +) + +cc_capnp_library( + name = "http-over-capnp_capnp", + srcs = [ + "byte-stream.capnp", + "http-over-capnp.capnp", + ], + include_prefix = "capnp/compat", + src_prefix = "src", + visibility = ["//visibility:public"], +) + +cc_library( + name = "http-over-capnp", + srcs = [ + "byte-stream.c++", + "http-over-capnp.c++", + ], + hdrs = [ + "byte-stream.h", + "http-over-capnp.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + ":http-over-capnp_capnp", + "//src/kj/compat:kj-http", + ], +) + +cc_library( + name = "websocket-rpc", + srcs = [ + "websocket-rpc.c++", + ], + hdrs = [ + "websocket-rpc.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/capnp", + "//src/kj/compat:kj-http", + ], +) + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [ + ":websocket-rpc", + ":http-over-capnp", + "//src/capnp:capnp-test" + ], +) for f in [ + "byte-stream-test.c++", + "http-over-capnp-test.c++", + "websocket-rpc-test.c++", +]] + +cc_library( + name = "http-over-capnp-test-as-header", + hdrs = ["http-over-capnp-test.c++"], +) + +cc_test( + name = "http-over-capnp-old-test", + srcs = ["http-over-capnp-old-test.c++"], + deps = [ + ":http-over-capnp-test-as-header", + ":http-over-capnp", + "//src/capnp:capnp-test" + ], +) diff --git a/c++/src/capnp/compat/byte-stream-test.c++ b/c++/src/capnp/compat/byte-stream-test.c++ new file mode 100644 index 0000000000..49fede1fec --- /dev/null +++ b/c++/src/capnp/compat/byte-stream-test.c++ @@ -0,0 +1,715 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "byte-stream.h" +#include +#include +#include + +namespace capnp { +namespace { + +kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { + if (expected.size() == 0) return kj::READY_NOW; + + auto buffer = kj::heapArray(expected.size()); + + auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { + if (amount == 0) { + KJ_FAIL_ASSERT("expected data never sent", expected); + } + + auto actual = buffer.slice(0, amount); + if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) { + KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual); + } + + return expectRead(in, expected.slice(amount)); + }); +} + +kj::String makeString(size_t size) { + auto bytes = kj::heapArray(size); + for (char& c: bytes) { + c = 'a' + rand() % 26; + } + bytes[bytes.size() - 1] = 0; + return kj::String(kj::mv(bytes)); +}; + +KJ_TEST("KJ -> ByteStream -> KJ without shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory1; + ByteStreamFactory factory2; + + auto pipe = kj::newOneWayPipe(); + + auto wrapped = factory1.capnpToKj(factory2.kjToCapnp(kj::mv(pipe.out))); + + { + auto promise = wrapped->write("foo", 3); + KJ_EXPECT(!promise.poll(waitScope)); + expectRead(*pipe.in, "foo").wait(waitScope); + promise.wait(waitScope); + } + + { + // Write more than 1 << 16 bytes at once to exercise write splitting. + auto str = makeString(1 << 17); + auto promise = wrapped->write(str.begin(), str.size()); + KJ_EXPECT(!promise.poll(waitScope)); + expectRead(*pipe.in, str).wait(waitScope); + promise.wait(waitScope); + } + + { + // Write more than 1 << 16 bytes via an array to exercise write splitting. + auto str = makeString(1 << 18); + auto pieces = kj::heapArrayBuilder>(4); + + // Two 2^15 pieces will be combined. + pieces.add(kj::arrayPtr(reinterpret_cast(str.begin()), 1 << 15)); + pieces.add(kj::arrayPtr(reinterpret_cast(str.begin() + (1 << 15)), 1 << 15)); + + // One 2^16 piece will be written alone. + pieces.add(kj::arrayPtr(reinterpret_cast( + str.begin() + (1 << 16)), 1 << 16)); + + // One 2^17 piece will be split. + pieces.add(kj::arrayPtr(reinterpret_cast( + str.begin() + (1 << 17)), str.size() - (1 << 17))); + + auto promise = wrapped->write(pieces); + KJ_EXPECT(!promise.poll(waitScope)); + expectRead(*pipe.in, str).wait(waitScope); + promise.wait(waitScope); + } + + wrapped = nullptr; + KJ_EXPECT(pipe.in->readAllText().wait(waitScope) == ""); +} + +class ExactPointerWriter: public kj::AsyncOutputStream { +public: + kj::ArrayPtr receivedBuffer; + + void fulfill() { + KJ_ASSERT_NONNULL(fulfiller)->fulfill(); + fulfiller = nullptr; + receivedBuffer = nullptr; + } + + kj::Promise write(const void* buffer, size_t size) override { + KJ_ASSERT(fulfiller == nullptr); + receivedBuffer = kj::arrayPtr(reinterpret_cast(buffer), size); + auto paf = kj::newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + KJ_UNIMPLEMENTED("not implemented for test"); + } + kj::Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + void expectBuffer(kj::StringPtr expected) { + KJ_EXPECT(receivedBuffer == expected.asArray(), receivedBuffer, expected); + } + +private: + kj::Maybe>> fulfiller; +}; + +KJ_TEST("KJ -> ByteStream -> KJ with shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory; + + auto pipe = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto pumpPromise = pipe.in->pumpTo(exactPointerWriter); + + auto wrapped = factory.capnpToKj(factory.kjToCapnp(kj::mv(pipe.out))); + + { + char buffer[4] = "foo"; + auto promise = wrapped->write(buffer, 3); + KJ_EXPECT(!promise.poll(waitScope)); + + // This first write won't have been path-shortened because we didn't know about the shorter + // path yet when it started. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() != buffer); + KJ_EXPECT(kj::str(exactPointerWriter.receivedBuffer) == "foo"); + exactPointerWriter.fulfill(); + promise.wait(waitScope); + } + + { + char buffer[4] = "foo"; + auto promise = wrapped->write(buffer, 3); + KJ_EXPECT(!promise.poll(waitScope)); + + // The second write was path-shortened so passes through the exact buffer! + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + promise.wait(waitScope); + } + + wrapped = nullptr; + KJ_EXPECT(pipe.in->readAllText().wait(waitScope) == ""); +} + +KJ_TEST("KJ -> ByteStream -> KJ -> ByteStream -> KJ with shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory; + + auto pipe = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto pumpPromise = pipe.in->pumpTo(exactPointerWriter); + + auto wrapped = factory.capnpToKj(factory.kjToCapnp( + factory.capnpToKj(factory.kjToCapnp(kj::mv(pipe.out))))); + + { + char buffer[4] = "foo"; + auto promise = wrapped->write(buffer, 3); + KJ_EXPECT(!promise.poll(waitScope)); + + // This first write won't have been path-shortened because we didn't know about the shorter + // path yet when it started. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() != buffer); + KJ_EXPECT(kj::str(exactPointerWriter.receivedBuffer) == "foo"); + exactPointerWriter.fulfill(); + promise.wait(waitScope); + } + + { + char buffer[4] = "bar"; + auto promise = wrapped->write(buffer, 3); + KJ_EXPECT(!promise.poll(waitScope)); + + // The second write was path-shortened so passes through the exact buffer! + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + promise.wait(waitScope); + } + + wrapped = nullptr; + KJ_EXPECT(pumpPromise.wait(waitScope) == 6); +} + +KJ_TEST("KJ -> ByteStream -> KJ pipe -> ByteStream -> KJ with shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory; + + auto backPipe = kj::newOneWayPipe(); + auto middlePipe = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto backPumpPromise = backPipe.in->pumpTo(exactPointerWriter); + + auto backWrapped = factory.capnpToKj(factory.kjToCapnp(kj::mv(backPipe.out))); + auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped, 3); + + auto wrapped = factory.capnpToKj(factory.kjToCapnp(kj::mv(middlePipe.out))); + + // Poll whenWriteDisconnected(), mainly as a way to let all the path-shortening settle. + auto disconnectPromise = wrapped->whenWriteDisconnected(); + KJ_EXPECT(!disconnectPromise.poll(waitScope)); + + char buffer[7] = "foobar"; + auto writePromise = wrapped->write(buffer, 6); + KJ_EXPECT(!writePromise.poll(waitScope)); + + // The first three bytes will tunnel all the way down to the destination. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + + KJ_EXPECT(midPumpPormise.wait(waitScope) == 3); + + ExactPointerWriter exactPointerWriter2; + midPumpPormise = middlePipe.in->pumpTo(exactPointerWriter2, 6); + KJ_EXPECT(!writePromise.poll(waitScope)); + + // The second half of the "foobar" write will have taken a slow path, because the write was + // restarted in the middle of the stream re-resolving itself. + KJ_EXPECT(kj::str(exactPointerWriter2.receivedBuffer) == "bar"); + exactPointerWriter2.fulfill(); + + // Now that write is done. + writePromise.wait(waitScope); + KJ_EXPECT(!midPumpPormise.poll(waitScope)); + + // If we write again, it'll hit the fast path. + char buffer2[4] = "baz"; + writePromise = wrapped->write(buffer2, 3); + KJ_EXPECT(!writePromise.poll(waitScope)); + KJ_EXPECT(exactPointerWriter2.receivedBuffer.begin() == buffer2); + KJ_EXPECT(exactPointerWriter2.receivedBuffer.size() == 3); + exactPointerWriter2.fulfill(); + + KJ_EXPECT(midPumpPormise.wait(waitScope) == 6); + writePromise.wait(waitScope); +} + +KJ_TEST("KJ -> ByteStream RPC -> KJ pipe -> ByteStream RPC -> KJ with shortening") { + // For this test, we're going to verify that if we have ByteStreams over RPC in both directions + // and we pump a ByteStream to another ByteStream at one end of the connection, it gets shortened + // all the way to the other end! + + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory clientFactory; + ByteStreamFactory serverFactory; + + auto backPipe = kj::newOneWayPipe(); + auto middlePipe = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto backPumpPromise = backPipe.in->pumpTo(exactPointerWriter); + + auto rpcConnection = kj::newTwoWayPipe(); + capnp::TwoPartyClient client(*rpcConnection.ends[0], + clientFactory.kjToCapnp(kj::mv(backPipe.out)), + rpc::twoparty::Side::CLIENT); + capnp::TwoPartyClient server(*rpcConnection.ends[1], + serverFactory.kjToCapnp(kj::mv(middlePipe.out)), + rpc::twoparty::Side::SERVER); + + auto backWrapped = serverFactory.capnpToKj(server.bootstrap().castAs()); + auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped, 3); + + auto wrapped = clientFactory.capnpToKj(client.bootstrap().castAs()); + + // Poll whenWriteDisconnected(), mainly as a way to let all the path-shortening settle. + auto disconnectPromise = wrapped->whenWriteDisconnected(); + KJ_EXPECT(!disconnectPromise.poll(waitScope)); + + char buffer[7] = "foobar"; + auto writePromise = wrapped->write(buffer, 6); + + // The server side did a 3-byte pump. Path-shortening magic kicks in, and the first three bytes + // of the write on the client side go *directly* to the endpoint without a copy! + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + + KJ_EXPECT(midPumpPormise.wait(waitScope) == 3); + + ExactPointerWriter exactPointerWriter2; + midPumpPormise = middlePipe.in->pumpTo(exactPointerWriter2, 6); + midPumpPormise.poll(waitScope); + + // The second half of the "foobar" write will have taken a slow path, because the write was + // restarted in the middle of the stream re-resolving itself. + KJ_EXPECT(kj::str(exactPointerWriter2.receivedBuffer) == "bar"); + exactPointerWriter2.fulfill(); + + // Now that write is done. + writePromise.wait(waitScope); + KJ_EXPECT(!midPumpPormise.poll(waitScope)); + + // If we write again, it'll finish the server-side pump (but won't be a zero-copy write since + // it has to go over RPC). + char buffer2[4] = "baz"; + writePromise = wrapped->write(buffer2, 3); + KJ_EXPECT(!midPumpPormise.poll(waitScope)); + KJ_EXPECT(kj::str(exactPointerWriter2.receivedBuffer) == "baz"); + exactPointerWriter2.fulfill(); + + KJ_EXPECT(midPumpPormise.wait(waitScope) == 6); + writePromise.wait(waitScope); +} + +KJ_TEST("KJ -> ByteStream RPC -> KJ pipe -> ByteStream RPC -> KJ with concurrent shortening") { + // This is similar to the previous test, but we start writing before the path-shortening has + // settled. This should result in some writes optimistically bouncing back and forth before + // the stream settles in. + + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory clientFactory; + ByteStreamFactory serverFactory; + + auto backPipe = kj::newOneWayPipe(); + auto middlePipe = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto backPumpPromise = backPipe.in->pumpTo(exactPointerWriter); + + auto rpcConnection = kj::newTwoWayPipe(); + capnp::TwoPartyClient client(*rpcConnection.ends[0], + clientFactory.kjToCapnp(kj::mv(backPipe.out)), + rpc::twoparty::Side::CLIENT); + capnp::TwoPartyClient server(*rpcConnection.ends[1], + serverFactory.kjToCapnp(kj::mv(middlePipe.out)), + rpc::twoparty::Side::SERVER); + + auto backWrapped = serverFactory.capnpToKj(server.bootstrap().castAs()); + auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped); + + auto wrapped = clientFactory.capnpToKj(client.bootstrap().castAs()); + + char buffer[7] = "foobar"; + auto writePromise = wrapped->write(buffer, 6); + + // The write went to RPC so it's not immediately received. + KJ_EXPECT(exactPointerWriter.receivedBuffer == nullptr); + + // Write should be received after we turn the event loop. + waitScope.poll(); + KJ_EXPECT(exactPointerWriter.receivedBuffer != nullptr); + + // Note that the promise that write() returned above has already resolved, because it hit RPC + // and went into the streaming window. + KJ_ASSERT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); + + // Let's start a second write. Even though the first write technically isn't done yet, it's + // legal for us to start a second one because the first write's returned promise optimistically + // resolved for streaming window reasons. This ends up being a very tricky case for our code! + char buffer2[7] = "bazqux"; + auto writePromise2 = wrapped->write(buffer2, 6); + + // Now check the first write was correct, and close it out. + KJ_EXPECT(kj::str(exactPointerWriter.receivedBuffer) == "foobar"); + exactPointerWriter.fulfill(); + + // Turn event loop again. Now the second write arrives. + waitScope.poll(); + KJ_EXPECT(kj::str(exactPointerWriter.receivedBuffer) == "bazqux"); + exactPointerWriter.fulfill(); + writePromise2.wait(waitScope); + + // If we do another write now, it should be zero-copy, because everything has settled. + char buffer3[6] = "corge"; + auto writePromise3 = wrapped->write(buffer3, 5); + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer3); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 5); + KJ_EXPECT(!writePromise3.poll(waitScope)); + exactPointerWriter.fulfill(); + writePromise3.wait(waitScope); +} + +KJ_TEST("KJ -> KJ pipe -> ByteStream RPC -> KJ pipe -> ByteStream RPC -> KJ with concurrent shortening") { + // Same as previous test, except we add a KJ pipe at the beginning and pump it into the top of + // the pipe, which invokes tryPumpFrom() on the KjToCapnpStreamAdapter. + + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory clientFactory; + ByteStreamFactory serverFactory; + + auto backPipe = kj::newOneWayPipe(); + auto middlePipe = kj::newOneWayPipe(); + auto frontPipe = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto backPumpPromise = backPipe.in->pumpTo(exactPointerWriter); + + auto rpcConnection = kj::newTwoWayPipe(); + capnp::TwoPartyClient client(*rpcConnection.ends[0], + clientFactory.kjToCapnp(kj::mv(backPipe.out)), + rpc::twoparty::Side::CLIENT); + capnp::TwoPartyClient server(*rpcConnection.ends[1], + serverFactory.kjToCapnp(kj::mv(middlePipe.out)), + rpc::twoparty::Side::SERVER); + + auto backWrapped = serverFactory.capnpToKj(server.bootstrap().castAs()); + auto midPumpPormise = middlePipe.in->pumpTo(*backWrapped); + + auto wrapped = clientFactory.capnpToKj(client.bootstrap().castAs()); + auto frontPumpPromise = frontPipe.in->pumpTo(*wrapped); + + char buffer[7] = "foobar"; + auto writePromise = frontPipe.out->write(buffer, 6); + + // The write went to RPC so it's not immediately received. + KJ_EXPECT(exactPointerWriter.receivedBuffer == nullptr); + + // Write should be received after we turn the event loop. + waitScope.poll(); + KJ_EXPECT(exactPointerWriter.receivedBuffer != nullptr); + + // Note that the promise that write() returned above has already resolved, because it hit RPC + // and went into the streaming window. + KJ_ASSERT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); + + // Let's start a second write. Even though the first write technically isn't done yet, it's + // legal for us to start a second one because the first write's returned promise optimistically + // resolved for streaming window reasons. This ends up being a very tricky case for our code! + char buffer2[7] = "bazqux"; + auto writePromise2 = frontPipe.out->write(buffer2, 6); + + // Now check the first write was correct, and close it out. + KJ_EXPECT(kj::str(exactPointerWriter.receivedBuffer) == "foobar"); + exactPointerWriter.fulfill(); + + // Turn event loop again. Now the second write arrives. + waitScope.poll(); + KJ_EXPECT(kj::str(exactPointerWriter.receivedBuffer) == "bazqux"); + exactPointerWriter.fulfill(); + writePromise2.wait(waitScope); + + // If we do another write now, it should be zero-copy, because everything has settled. + char buffer3[6] = "corge"; + auto writePromise3 = frontPipe.out->write(buffer3, 5); + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer3); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 5); + KJ_EXPECT(!writePromise3.poll(waitScope)); + exactPointerWriter.fulfill(); + writePromise3.wait(waitScope); +} + +KJ_TEST("Two Substreams on one destination") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory; + + auto backPipe = kj::newOneWayPipe(); + auto middlePipe1 = kj::newOneWayPipe(); + auto middlePipe2 = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto backPumpPromise = backPipe.in->pumpTo(exactPointerWriter); + + auto backWrapped = factory.capnpToKj(factory.kjToCapnp(kj::mv(backPipe.out))); + + auto wrapped1 = factory.capnpToKj(factory.kjToCapnp(kj::mv(middlePipe1.out))); + auto wrapped2 = factory.capnpToKj(factory.kjToCapnp(kj::mv(middlePipe2.out))); + + // Declare these buffers out here so that they can't possibly end up with the same address. + char buffer1[4] = "foo"; + char buffer2[4] = "bar"; + + { + auto wrapped = kj::mv(wrapped1); + + // First pump 3 bytes from the first stream. + auto midPumpPormise = middlePipe1.in->pumpTo(*backWrapped, 3); + + // Poll whenWriteDisconnected(), mainly as a way to let all the path-shortening settle. + auto disconnectPromise = wrapped->whenWriteDisconnected(); + KJ_EXPECT(!disconnectPromise.poll(waitScope)); + + auto writePromise = wrapped->write(buffer1, 3); + KJ_EXPECT(!writePromise.poll(waitScope)); + + // The first write will tunnel all the way down to the destination. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer1); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + + writePromise.wait(waitScope); + KJ_EXPECT(midPumpPormise.wait(waitScope) == 3); + } + + { + auto wrapped = kj::mv(wrapped2); + + // Now pump another 3 bytes from the second stream. + auto midPumpPormise = middlePipe2.in->pumpTo(*backWrapped, 3); + + // Poll whenWriteDisconnected(), mainly as a way to let all the path-shortening settle. + auto disconnectPromise = wrapped->whenWriteDisconnected(); + KJ_EXPECT(!disconnectPromise.poll(waitScope)); + + auto writePromise = wrapped->write(buffer2, 3); + KJ_EXPECT(!writePromise.poll(waitScope)); + + // The second write will also tunnel all the way down to the destination. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer2); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + + writePromise.wait(waitScope); + KJ_EXPECT(midPumpPormise.wait(waitScope) == 3); + } +} + +KJ_TEST("Two Substreams on one destination no limits (pump to EOF)") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory; + + auto backPipe = kj::newOneWayPipe(); + auto middlePipe1 = kj::newOneWayPipe(); + auto middlePipe2 = kj::newOneWayPipe(); + + ExactPointerWriter exactPointerWriter; + auto backPumpPromise = backPipe.in->pumpTo(exactPointerWriter); + + auto backWrapped = factory.capnpToKj(factory.kjToCapnp(kj::mv(backPipe.out))); + + auto wrapped1 = factory.capnpToKj(factory.kjToCapnp(kj::mv(middlePipe1.out))); + auto wrapped2 = factory.capnpToKj(factory.kjToCapnp(kj::mv(middlePipe2.out))); + + // Declare these buffers out here so that they can't possibly end up with the same address. + char buffer1[4] = "foo"; + char buffer2[4] = "bar"; + + { + auto wrapped = kj::mv(wrapped1); + + // First pump from the first stream until EOF. + auto midPumpPormise = middlePipe1.in->pumpTo(*backWrapped); + + // Poll whenWriteDisconnected(), mainly as a way to let all the path-shortening settle. + auto disconnectPromise = wrapped->whenWriteDisconnected(); + KJ_EXPECT(!disconnectPromise.poll(waitScope)); + + auto writePromise = wrapped->write(buffer1, 3); + KJ_EXPECT(!writePromise.poll(waitScope)); + + // The first write will tunnel all the way down to the destination. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer1); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + + writePromise.wait(waitScope); + { auto drop = kj::mv(wrapped); } + KJ_EXPECT(midPumpPormise.wait(waitScope) == 3); + } + + { + auto wrapped = kj::mv(wrapped2); + + // Now pump from the second stream until EOF. + auto midPumpPormise = middlePipe2.in->pumpTo(*backWrapped); + + // Poll whenWriteDisconnected(), mainly as a way to let all the path-shortening settle. + auto disconnectPromise = wrapped->whenWriteDisconnected(); + KJ_EXPECT(!disconnectPromise.poll(waitScope)); + + auto writePromise = wrapped->write(buffer2, 3); + KJ_EXPECT(!writePromise.poll(waitScope)); + + // The second write will also tunnel all the way down to the destination. + KJ_EXPECT(exactPointerWriter.receivedBuffer.begin() == buffer2); + KJ_EXPECT(exactPointerWriter.receivedBuffer.size() == 3); + exactPointerWriter.fulfill(); + + writePromise.wait(waitScope); + { auto drop = kj::mv(wrapped); } + KJ_EXPECT(midPumpPormise.wait(waitScope) == 3); + } +} + +KJ_TEST("KJ -> ByteStream RPC -> KJ promise stream -> ByteStream -> KJ") { + // Test what happens if we queue up several requests on a ByteStream and then it resolves to + // a shorter path. + + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory factory; + ExactPointerWriter exactPointerWriter; + + auto paf = kj::newPromiseAndFulfiller>(); + auto backCap = factory.kjToCapnp(kj::newPromisedStream(kj::mv(paf.promise))); + + auto rpcPipe = kj::newTwoWayPipe(); + capnp::TwoPartyClient client(*rpcPipe.ends[0]); + capnp::TwoPartyClient server(*rpcPipe.ends[1], kj::mv(backCap), rpc::twoparty::Side::SERVER); + auto front = factory.capnpToKj(client.bootstrap().castAs()); + + // These will all queue up in the RPC layer. + front->write("foo", 3).wait(waitScope); + front->write("bar", 3).wait(waitScope); + front->write("baz", 3).wait(waitScope); + front->write("qux", 3).wait(waitScope); + + // Make sure those writes manage to get all the way through the RPC system and queue up in the + // LocalClient wrapping the CapnpToKjStreamAdapter at the other end. + waitScope.poll(); + + // Fulfill the promise. + paf.fulfiller->fulfill(factory.capnpToKj(factory.kjToCapnp(kj::attachRef(exactPointerWriter)))); + waitScope.poll(); + + // Now: + // - "foo" should have made it all the way down to the final output stream. + // - "bar", "baz", and "qux" are queued on the CapnpToKjStreamAdapter immediately wrapping the + // KJ promise stream. + // - But that stream adapter has discovered that there's another capnp stream downstream and has + // resolved itself to the later stream. + // - A new call at this time should NOT be allowed to hop the queue. + + exactPointerWriter.expectBuffer("foo"); + + front->write("corge", 5).wait(waitScope); + waitScope.poll(); + + exactPointerWriter.fulfill(); + + waitScope.poll(); + exactPointerWriter.expectBuffer("bar"); + exactPointerWriter.fulfill(); + + waitScope.poll(); + exactPointerWriter.expectBuffer("baz"); + exactPointerWriter.fulfill(); + + waitScope.poll(); + exactPointerWriter.expectBuffer("qux"); + exactPointerWriter.fulfill(); + + waitScope.poll(); + exactPointerWriter.expectBuffer("corge"); + exactPointerWriter.fulfill(); + + // There may still be some detach()ed promises holding on to some capabilities that transitively + // hold a fake Own pointing at exactPointerWriter, which is actually on the + // stack. We created a fake Own pointing to a stack variable by using + // kj::attachRef(exactPointerWriter), above; it does not actually own the object it points to. + // We need to make sure those Owns are dropped before exactPoniterWriter is destroyed, otherwise + // ASAN will flag some invalid reads (of exactPointerWriter's vtable, in particular). + waitScope.cancelAllDetached(); +} + +// TODO: +// - Parallel writes (requires streaming) +// - Write to KJ -> capnp -> RPC -> capnp -> KJ loopback without shortening, verify we can write +// several things to buffer (requires streaming). +// - Again, but with shortening which only occurs after some promise resolve. + +} // namespace +} // namespace capnp diff --git a/c++/src/capnp/compat/byte-stream.c++ b/c++/src/capnp/compat/byte-stream.c++ new file mode 100644 index 0000000000..abb1606628 --- /dev/null +++ b/c++/src/capnp/compat/byte-stream.c++ @@ -0,0 +1,1162 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "byte-stream.h" +#include +#include + +namespace capnp { + +const uint MAX_BYTES_PER_WRITE = 1 << 16; + +class ByteStreamFactory::StreamServerBase: public capnp::ByteStream::Server { +public: + virtual void returnStream(uint64_t written) = 0; + // Called after the StreamServerBase's internal kj::AsyncOutputStream has been borrowed, to + // indicate that the borrower is done. + // + // A stream becomes borrowed either when getShortestPath() returns a BorrowedStream, or when + // a SubstreamImpl is constructed wrapping an existing stream. + + struct BorrowedStream { + // Represents permission to use the StreamServerBase's inner AsyncOutputStream directly, up + // to some limit of bytes written. + + StreamServerBase& lender; + kj::AsyncOutputStream& stream; + uint64_t limit; + }; + + typedef kj::OneOf, capnp::ByteStream::Client*, BorrowedStream> ShortestPath; + + virtual ShortestPath getShortestPath() = 0; + // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client + // actually points back to a StreamServerBase in the same process created by the same + // ByteStreamFactory. Returns the best shortened path to use, or a promise that resolves when the + // shortest path is known. + + virtual void directEnd() = 0; + // Called by KjToCapnpStreamAdapter's destructor when it has determined that its inner + // ByteStream::Client actually points back to a StreamServerBase in the same process created by + // the same ByteStreamFactory. Since destruction of a KJ stream signals EOF, we need to propagate + // that by destroying our underlying stream. + // TODO(cleanup): When KJ streams evolve an end() method, this can go away. + + virtual kj::Promise directExplicitEnd() = 0; + // Like directEnd() but used in cases where an explicit end() call actually was made. +}; + +class ByteStreamFactory::SubstreamImpl final: public StreamServerBase { +public: + SubstreamImpl(ByteStreamFactory& factory, + StreamServerBase& parent, + capnp::ByteStream::Client ownParent, + kj::AsyncOutputStream& stream, + capnp::ByteStream::SubstreamCallback::Client callback, + uint64_t limit, + kj::Maybe> tlsStarter, + kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller()) + : factory(factory), + state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback), kj::mv(tlsStarter)}), + limit(limit), + resolveFulfiller(kj::mv(paf.fulfiller)), + resolvePromise(paf.promise.fork()) {} + + // --------------------------------------------------------------------------- + // implements StreamServerBase + + void returnStream(uint64_t written) override { + completed += written; + KJ_ASSERT(completed <= limit); + auto borrowed = kj::mv(state.get()); + state = kj::mv(borrowed.originalState); + + if (completed == limit) { + limitReached(); + } + } + + ShortestPath getShortestPath() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + return &redirected.replacement; + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()"); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("can't call other methods while substream is active"); + } + KJ_CASE_ONEOF(streaming, Streaming) { + auto& stream = streaming.stream; + auto oldState = kj::mv(streaming); + state = Borrowed { kj::mv(oldState) }; + return BorrowedStream { *this, stream, limit - completed }; + } + } + KJ_UNREACHABLE; + } + + void directEnd() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + // Ugh I guess we need to send a real end() request here. + redirected.replacement.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); + } + KJ_CASE_ONEOF(e, Ended) { + // whatever + } + KJ_CASE_ONEOF(b, Borrowed) { + // ... whatever. + } + KJ_CASE_ONEOF(streaming, Streaming) { + auto req = streaming.callback.endedRequest(MessageSize {4, 0}); + req.setByteCount(completed); + req.send().detach([](kj::Exception&&){}); + streaming.parent.returnStream(completed); + state = Ended(); + } + } + } + + kj::Promise directExplicitEnd() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + // Ugh I guess we need to send a real end() request here. + return redirected.replacement.endRequest(MessageSize {2, 0}).send().ignoreResult(); + } + KJ_CASE_ONEOF(e, Ended) { + // whatever + return kj::READY_NOW; + } + KJ_CASE_ONEOF(b, Borrowed) { + // ... whatever. + return kj::READY_NOW; + } + KJ_CASE_ONEOF(streaming, Streaming) { + auto req = streaming.callback.endedRequest(MessageSize {4, 0}); + req.setByteCount(completed); + auto promise = req.send().ignoreResult(); + streaming.parent.returnStream(completed); + state = Ended(); + return promise; + } + } + KJ_UNREACHABLE; + } + + // --------------------------------------------------------------------------- + // implements ByteStream::Server RPC interface + + kj::Maybe> shortenPath() override { + return resolvePromise.addBranch() + .then([this]() -> Capability::Client { + return state.get().replacement; + }); + } + + kj::Promise write(WriteContext context) override { + auto params = context.getParams(); + auto data = params.getBytes(); + + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + auto req = redirected.replacement.writeRequest(params.totalSize()); + req.setBytes(data); + return req.send(); + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()"); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); + } + KJ_CASE_ONEOF(streaming, Streaming) { + if (completed + data.size() < limit) { + completed += data.size(); + return streaming.stream.write(data.begin(), data.size()); + } else { + // This write passes the limit. + uint64_t remainder = limit - completed; + auto leftover = data.slice(remainder, data.size()); + return streaming.stream.write(data.begin(), remainder) + .then([this, leftover]() -> kj::Promise { + completed = limit; + limitReached(); + + if (leftover.size() > 0) { + // Need to forward the leftover bytes to the next stream. + auto req = state.get().replacement.writeRequest( + MessageSize { 4 + leftover.size() / sizeof(capnp::word), 0 }); + req.setBytes(leftover); + return req.send(); + } else { + return kj::READY_NOW; + } + }); + } + } + } + KJ_UNREACHABLE; + } + + kj::Promise end(EndContext context) override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + return context.tailCall(redirected.replacement.endRequest(MessageSize {2,0})); + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()"); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); + } + KJ_CASE_ONEOF(streaming, Streaming) { + // Revoke the TLS starter when stream is ended. This will ensure any startTls calls + // cannot be falsely invoked after the stream is destroyed. + auto drop = kj::mv(streaming.tlsStarter); + + auto req = streaming.callback.endedRequest(MessageSize {4, 0}); + req.setByteCount(completed); + auto result = req.send().ignoreResult(); + streaming.parent.returnStream(completed); + state = Ended(); + return result; + } + } + KJ_UNREACHABLE; + } + + kj::Promise startTls(StartTlsContext context) override { + KJ_UNIMPLEMENTED("A substream does not support TLS initiation"); + } + + kj::Promise getSubstream(GetSubstreamContext context) override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(redirected, Redirected) { + auto params = context.getParams(); + auto req = redirected.replacement.getSubstreamRequest(params.totalSize()); + req.setCallback(params.getCallback()); + req.setLimit(params.getLimit()); + return context.tailCall(kj::mv(req)); + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()"); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); + } + KJ_CASE_ONEOF(streaming, Streaming) { + auto params = context.getParams(); + auto callback = params.getCallback(); + auto limit = params.getLimit(); + context.releaseParams(); + auto results = context.getResults(MessageSize { 2, 1 }); + results.setSubstream(factory.streamSet.add(kj::heap( + factory, *this, thisCap(), streaming.stream, kj::mv(callback), kj::mv(limit), + kj::mv(streaming.tlsStarter)))); + state = Borrowed { kj::mv(streaming) }; + return kj::READY_NOW; + } + } + KJ_UNREACHABLE; + } + +private: + ByteStreamFactory& factory; + + struct Streaming { + StreamServerBase& parent; + capnp::ByteStream::Client ownParent; + kj::AsyncOutputStream& stream; + capnp::ByteStream::SubstreamCallback::Client callback; + kj::Maybe> tlsStarter; + }; + struct Borrowed { + Streaming originalState; + }; + struct Redirected { + capnp::ByteStream::Client replacement; + }; + struct Ended {}; + + kj::OneOf state; + + uint64_t limit; + uint64_t completed = 0; + + kj::Own> resolveFulfiller; + kj::ForkedPromise resolvePromise; + + void limitReached() { + auto& streaming = state.get(); + auto next = streaming.callback.reachedLimitRequest(capnp::MessageSize {2,0}) + .send().getNext(); + + // Set the next stream as our replacement. + streaming.parent.returnStream(limit); + state = Redirected { kj::mv(next) }; + resolveFulfiller->fulfill(); + } +}; + +// ======================================================================================= + +class ByteStreamFactory::CapnpToKjStreamAdapter final: public StreamServerBase { + // Implements Cap'n Proto ByteStream as a wrapper around a KJ stream. + + class SubstreamCallbackImpl; + +public: + class PathProber; + + CapnpToKjStreamAdapter(ByteStreamFactory& factory, + kj::Own inner) + : factory(factory), + state(kj::heap(*this, kj::mv(inner))) { + state.get>()->startProbing(); + } + + CapnpToKjStreamAdapter(ByteStreamFactory& factory, + kj::Own inner, + kj::Maybe> starter) + : factory(factory), + tlsStarter(kj::mv(starter)), + state(kj::heap(*this, kj::mv(inner))) { + state.get>()->startProbing(); + } + + CapnpToKjStreamAdapter(ByteStreamFactory& factory, + kj::Own pathProber) + : factory(factory), + state(kj::mv(pathProber)) { + state.get>()->setNewParent(*this); + } + + // --------------------------------------------------------------------------- + // implements StreamServerBase + + void returnStream(uint64_t written) override { + auto stream = kj::mv(state.get().stream); + state = kj::mv(stream); + } + + ShortestPath getShortestPath() override { + // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client + // actually points back to a CapnpToKjStreamAdapter in the same process. Returns the best + // shortened path to use, or a promise that resolves when the shortest path is known. + + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return prober->whenReady(); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + auto& streamRef = *kjStream; + state = Borrowed { kj::mv(kjStream) }; + return StreamServerBase::BorrowedStream { *this, streamRef, kj::maxValue }; + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + return &capnpStream; + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } + return kj::Promise(kj::READY_NOW); + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already ended") { break; } + return kj::Promise(kj::READY_NOW); + } + } + KJ_UNREACHABLE; + } + + void directEnd() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + state = Ended(); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + state = Ended(); + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + // Ugh I guess we need to send a real end() request here. + capnpStream.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); + } + KJ_CASE_ONEOF(b, Borrowed) { + // Fine, ignore. + } + KJ_CASE_ONEOF(e, Ended) { + // Fine, ignore. + } + } + } + + kj::Promise directExplicitEnd() override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + state = Ended(); + return kj::READY_NOW; + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + state = Ended(); + return kj::READY_NOW; + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + // Ugh I guess we need to send a real end() request here. + return capnpStream.endRequest(MessageSize {2, 0}).send().ignoreResult(); + } + KJ_CASE_ONEOF(b, Borrowed) { + // Fine, ignore. + return kj::READY_NOW; + } + KJ_CASE_ONEOF(e, Ended) { + // Fine, ignore. + return kj::READY_NOW; + } + } + KJ_UNREACHABLE; + } + + // --------------------------------------------------------------------------- + // PathProber + + class PathProber final: public kj::AsyncInputStream { + public: + PathProber(CapnpToKjStreamAdapter& parent, kj::Own inner, + kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller()) + : parent(parent), inner(kj::mv(inner)), + readyPromise(paf.promise.fork()), + readyFulfiller(kj::mv(paf.fulfiller)), + task(nullptr) {} + + void startProbing() { + task = probeForShorterPath(); + } + + void setNewParent(CapnpToKjStreamAdapter& newParent) { + KJ_ASSERT(parent == nullptr); + parent = newParent; + auto paf = kj::newPromiseAndFulfiller(); + readyPromise = paf.promise.fork(); + readyFulfiller = kj::mv(paf.fulfiller); + } + + kj::Promise whenReady() { + return readyPromise.addBranch(); + } + + kj::Promise pumpToShorterPath(capnp::ByteStream::Client target, uint64_t limit) { + // If our probe succeeds in finding a KjToCapnpStreamAdapter somewhere down the stack, that + // will call this method to provide the shortened path. + + KJ_IF_MAYBE(currentParent, parent) { + parent = nullptr; + + auto self = kj::mv(currentParent->state.get>()); + currentParent->state = Ended(); // temporary, we'll set this properly below + KJ_ASSERT(self.get() == this); + + // Open a substream on the target stream. + auto req = target.getSubstreamRequest(); + req.setLimit(limit); + auto paf = kj::newPromiseAndFulfiller(); + req.setCallback(kj::heap(currentParent->factory, + kj::mv(self), kj::mv(paf.fulfiller), limit)); + + // Now we hook up the incoming stream adapter to point directly to this substream, yay. + currentParent->state = req.send().getSubstream(); + + // Let the original CapnpToKjStreamAdapter know that it's safe to handle incoming requests. + readyFulfiller->fulfill(); + + // It's now up to the SubstreamCallbackImpl to signal when the pump is done. + return kj::mv(paf.promise); + } else { + // We already completed a path-shortening. Probably SubstreamCallbackImpl::ended() was + // eventually called, meaning the substream was ended without redirecting back to us. So, + // we're at EOF. + return kj::constPromise(); + } + } + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // If this is called, it means the tryPumpFrom() in probeForShorterPath() eventually invoked + // code that tries to read manually from the source. We don't know what this code is doing + // exactly, but we do know for sure that the endpoint is not a KjToCapnpStreamAdapter, so + // we can't optimize. Instead, we pretend that we immediately hit EOF, ending the pump. This + // works because pumps do not propagate EOF -- the destination can still receive further + // writes and pumps. Basically our probing pump becomes a no-op, and then we revert to having + // each write() RPC directly call write() on the inner stream. + return kj::constPromise(); + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + // Call the stream's `tryPumpFrom()` as a way to discover where the data will eventually go, + // in hopes that we find we can shorten the path. + KJ_IF_MAYBE(promise, output.tryPumpFrom(*this, amount)) { + // tryPumpFrom() returned non-null. Either it called `tryRead()` or `pumpTo()` (see + // below), or it plans to do so in the future. + return kj::mv(*promise); + } else { + // There is no shorter path. As with tryRead(), we pretend we get immediate EOF. + return kj::constPromise(); + } + } + + private: + kj::Maybe parent; + kj::Own inner; + kj::ForkedPromise readyPromise; + kj::Own> readyFulfiller; + kj::Promise task; + + friend class SubstreamCallbackImpl; + + kj::Promise probeForShorterPath() { + return kj::evalNow([&]() -> kj::Promise { + return pumpTo(*inner, kj::maxValue); + }).then([this](uint64_t actual) { + KJ_IF_MAYBE(currentParent, parent) { + KJ_IF_MAYBE(prober, currentParent->state.tryGet>()) { + // Either we didn't find any shorter path at all during probing and faked an EOF + // to get out of the probe (see comments in tryRead(), or we DID find a shorter path, + // completed a pumpTo() using a substream, and that substream redirected back to us, + // and THEN we couldn't find any further shorter paths for subsequent pumps. + + // HACK: If we overwrite the Probing state now, we'll delete ourselves and delete + // this task promise, which is an error... let the event loop do it later by + // detaching. + task.attach(kj::mv(*prober)).detach([](kj::Exception&&){}); + parent = nullptr; + + // OK, now we can change the parent state and signal it to proceed. + currentParent->state = kj::mv(inner); + readyFulfiller->fulfill(); + } + } + }).eagerlyEvaluate([this](kj::Exception&& exception) mutable { + // Something threw, so propagate the exception to break the parent. + readyFulfiller->reject(kj::mv(exception)); + }); + } + }; + +protected: + // --------------------------------------------------------------------------- + // implements ByteStream::Server RPC interface + + kj::Maybe> shortenPath() override { + return shortenPathImpl(); + } + kj::Promise shortenPathImpl() { + // Called by RPC implementation to find out if a shorter path presents itself. + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return prober->whenReady().then([this]() { + KJ_ASSERT(!state.is>()); + return shortenPathImpl(); + }); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + // No shortening possible. Pretend we never resolve so that calls continue to be routed + // to us forever. + return kj::NEVER_DONE; + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + return Capability::Client(capnpStream); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } + return kj::NEVER_DONE; + } + KJ_CASE_ONEOF(e, Ended) { + // No shortening possible. Pretend we never resolve so that calls continue to be routed + // to us forever. + return kj::NEVER_DONE; + } + } + KJ_UNREACHABLE; + } + + kj::Promise write(WriteContext context) override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return prober->whenReady().then([this, context]() mutable { + KJ_ASSERT(!state.is>()); + return write(context); + }); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + auto data = context.getParams().getBytes(); + return kjStream->write(data.begin(), data.size()); + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + auto params = context.getParams(); + auto req = capnpStream.writeRequest(params.totalSize()); + req.setBytes(params.getBytes()); + return req.send(); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } + return kj::READY_NOW; + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()") { break; } + return kj::READY_NOW; + } + } + KJ_UNREACHABLE; + } + + kj::Promise end(EndContext context) override { + // Revoke the TLS starter when stream is ended. This will ensure any startTls calls + // cannot be falsely invoked after the stream is destroyed. + auto drop = kj::mv(tlsStarter); + + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return prober->whenReady().then([this, context]() mutable { + KJ_ASSERT(!state.is>()); + return end(context); + }); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + // TODO(someday): When KJ adds a proper .end() call, use it here. For now, we must + // drop the stream to close it. + state = Ended(); + return kj::READY_NOW; + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + auto params = context.getParams(); + auto req = capnpStream.endRequest(params.totalSize()); + return context.tailCall(kj::mv(req)); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } + return kj::READY_NOW; + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()") { break; } + return kj::READY_NOW; + } + } + KJ_UNREACHABLE; + } + + kj::Promise startTls(StartTlsContext context) override { + auto params = context.getParams(); + KJ_IF_MAYBE(s, tlsStarter) { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return KJ_ASSERT_NONNULL(*s->get())(params.getExpectedServerHostname()); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + return KJ_ASSERT_NONNULL(*s->get())(params.getExpectedServerHostname()); + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + return KJ_ASSERT_NONNULL(*s->get())(params.getExpectedServerHostname()); + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("cannot call startTls on a bytestream that was ended"); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("can't call startTls while stream is borrowed"); + } + } + } + KJ_UNREACHABLE; + } + + kj::Promise getSubstream(GetSubstreamContext context) override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(prober, kj::Own) { + return prober->whenReady().then([this, context]() mutable { + KJ_ASSERT(!state.is>()); + return getSubstream(context); + }); + } + KJ_CASE_ONEOF(kjStream, kj::Own) { + auto params = context.getParams(); + auto callback = params.getCallback(); + uint64_t limit = params.getLimit(); + context.releaseParams(); + + auto results = context.initResults(MessageSize {2, 1}); + results.setSubstream(factory.streamSet.add(kj::heap( + factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit), + kj::mv(tlsStarter)))); + state = Borrowed { kj::mv(kjStream) }; + return kj::READY_NOW; + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { + auto params = context.getParams(); + auto req = capnpStream.getSubstreamRequest(params.totalSize()); + req.setCallback(params.getCallback()); + req.setLimit(params.getLimit()); + return context.tailCall(kj::mv(req)); + } + KJ_CASE_ONEOF(b, Borrowed) { + KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } + return kj::READY_NOW; + } + KJ_CASE_ONEOF(e, Ended) { + KJ_FAIL_REQUIRE("already called end()") { break; } + return kj::READY_NOW; + } + } + KJ_UNREACHABLE; + } + +private: + ByteStreamFactory& factory; + kj::Maybe> tlsStarter; + + struct Borrowed { kj::Own stream; }; + struct Ended {}; + + kj::OneOf, kj::Own, + capnp::ByteStream::Client, Borrowed, Ended> state; + + class SubstreamCallbackImpl final: public capnp::ByteStream::SubstreamCallback::Server { + public: + SubstreamCallbackImpl(ByteStreamFactory& factory, + kj::Own pathProber, + kj::Own> originalPumpfulfiller, + uint64_t originalPumpLimit) + : factory(factory), + pathProber(kj::mv(pathProber)), + originalPumpfulfiller(kj::mv(originalPumpfulfiller)), + originalPumpLimit(originalPumpLimit) {} + + ~SubstreamCallbackImpl() noexcept(false) { + if (!done) { + originalPumpfulfiller->reject(KJ_EXCEPTION(DISCONNECTED, + "stream disconnected because SubstreamCallbackImpl was never called back")); + } + } + + kj::Promise ended(EndedContext context) override { + KJ_REQUIRE(!done); + uint64_t actual = context.getParams().getByteCount(); + KJ_REQUIRE(actual <= originalPumpLimit); + + done = true; + + // EOF before pump completed. Signal a short pump. + originalPumpfulfiller->fulfill(context.getParams().getByteCount()); + + // Give the original pump task a chance to finish up. + return pathProber->task.attach(kj::mv(pathProber)); + } + + kj::Promise reachedLimit(ReachedLimitContext context) override { + KJ_REQUIRE(!done); + done = true; + + // Allow the shortened stream to redirect back to our original underlying stream. + auto results = context.getResults(capnp::MessageSize { 4, 1 }); + results.setNext(factory.streamSet.add( + kj::heap(factory, kj::mv(pathProber)))); + + // The full pump completed. Note that it's important that we fulfill this after the + // PathProber has been attached to the new CapnpToKjStreamAdapter, which will have happened + // in CapnpToKjStreamAdapter's constructor, which calls pathProber->setNewParent(). + originalPumpfulfiller->fulfill(kj::cp(originalPumpLimit)); + + return kj::READY_NOW; + } + + private: + ByteStreamFactory& factory; + kj::Own pathProber; + kj::Own> originalPumpfulfiller; + uint64_t originalPumpLimit; + bool done = false; + }; +}; + +// ======================================================================================= + +class ByteStreamFactory::KjToCapnpStreamAdapter final: public ExplicitEndOutputStream { +public: + KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam, + bool explicitEnd) + : factory(factory), + inner(kj::mv(innerParam)), + findShorterPathTask(findShorterPath(inner).fork()), + explicitEnd(explicitEnd) {} + + ~KjToCapnpStreamAdapter() noexcept(false) { + if (!explicitEnd) { + // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We + // use a detached promise for now, which is probably OK since capabilities are refcounted and + // asynchronously destroyed anyway. + // TODO(cleanup): Fix this when KJ streads add an explicit end() method. + KJ_IF_MAYBE(o, optimized) { + o->directEnd(); + } else { + inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); + } + } + } + + kj::Promise end() override { + KJ_REQUIRE(explicitEnd, "not expecting explicit end"); + + KJ_IF_MAYBE(o, optimized) { + return o->directExplicitEnd(); + } else { + return inner.endRequest(MessageSize {2, 0}).send().ignoreResult(); + } + } + + kj::Promise write(const void* buffer, size_t size) override { + KJ_SWITCH_ONEOF(getShortestPath()) { + KJ_CASE_ONEOF(promise, kj::Promise) { + return promise.then([this,buffer,size]() { + return write(buffer, size); + }); + } + KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { + auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE); + if (size <= limit) { + auto promise = kjStream.stream.write(buffer, size); + return promise.then([kjStream,size]() mutable { + kjStream.lender.returnStream(size); + }); + } else { + auto promise = kjStream.stream.write(buffer, limit); + return promise.then([this,kjStream,buffer,size,limit]() mutable { + kjStream.lender.returnStream(limit); + return write(reinterpret_cast(buffer) + limit, + size - limit); + }); + } + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { + if (size <= MAX_BYTES_PER_WRITE) { + auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 }); + req.setBytes(kj::arrayPtr(reinterpret_cast(buffer), size)); + return req.send(); + } else { + auto req = capnpStream->writeRequest( + MessageSize { 8 + MAX_BYTES_PER_WRITE / sizeof(word), 0 }); + req.setBytes(kj::arrayPtr(reinterpret_cast(buffer), MAX_BYTES_PER_WRITE)); + return req.send().then([this,buffer,size]() mutable { + return write(reinterpret_cast(buffer) + MAX_BYTES_PER_WRITE, + size - MAX_BYTES_PER_WRITE); + }); + } + } + } + KJ_UNREACHABLE; + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + KJ_SWITCH_ONEOF(getShortestPath()) { + KJ_CASE_ONEOF(promise, kj::Promise) { + return promise.then([this,pieces]() { + return write(pieces); + }); + } + KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { + size_t size = 0; + for (auto& piece: pieces) { size += piece.size(); } + auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE); + if (size <= limit) { + auto promise = kjStream.stream.write(pieces); + return promise.then([kjStream,size]() mutable { + kjStream.lender.returnStream(size); + }); + } else { + // ughhhhhhhhhh, we need to split the pieces. + return splitAndWrite(pieces, kjStream.limit, + [kjStream,limit](kj::ArrayPtr> pieces) mutable { + return kjStream.stream.write(pieces).then([kjStream,limit]() mutable { + kjStream.lender.returnStream(limit); + }); + }); + } + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { + auto writePieces = [capnpStream](kj::ArrayPtr> pieces) { + size_t size = 0; + for (auto& piece: pieces) size += piece.size(); + auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 }); + auto out = req.initBytes(size); + byte* ptr = out.begin(); + for (auto& piece: pieces) { + memcpy(ptr, piece.begin(), piece.size()); + ptr += piece.size(); + } + KJ_ASSERT(ptr == out.end()); + return req.send(); + }; + + size_t size = 0; + for (auto& piece: pieces) size += piece.size(); + if (size <= MAX_BYTES_PER_WRITE) { + return writePieces(pieces); + } else { + // ughhhhhhhhhh, we need to split the pieces. + return splitAndWrite(pieces, MAX_BYTES_PER_WRITE, writePieces); + } + } + } + KJ_UNREACHABLE; + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + KJ_IF_MAYBE(rpc, kj::dynamicDowncastIfAvailable(input)) { + // Oh interesting, it turns we're hosting an incoming ByteStream which is pumping to this + // outgoing ByteStream. We can let the Cap'n Proto RPC layer know that it can shorten the + // path from one to the other. + return rpc->pumpToShorterPath(inner, amount); + } else { + return pumpLoop(input, 0, amount); + } + } + + kj::Promise whenWriteDisconnected() override { + return findShorterPathTask.addBranch(); + } + +private: + ByteStreamFactory& factory; + capnp::ByteStream::Client inner; + kj::Maybe optimized; + + kj::ForkedPromise findShorterPathTask; + // This serves two purposes: + // 1. Waits for the capability to resolve (if it is a promise), and then shortens the path if + // possible. + // 2. Implements whenWriteDisconnected(). + + bool explicitEnd; + // Did the creator promise to explicitly call end()? + + kj::Promise findShorterPath(capnp::ByteStream::Client& capnpClient) { + // If the capnp stream turns out to resolve back to this process, shorten the path. + // Also, implement whenWriteDisconnected() based on this. + return factory.streamSet.getLocalServer(capnpClient) + .then([this](kj::Maybe server) -> kj::Promise { + KJ_IF_MAYBE(s, server) { + // Yay, we discovered that the ByteStream actually points back to a local KJ stream. + // We can use this to shorten the path by skipping the RPC machinery. + return findShorterPath(kj::downcast(*s)); + } else { + // The capability is fully-resolved. This suggests that the remote implementation is + // NOT a CapnpToKjStreamAdapter at all, because CapnpToKjStreamAdapter is designed to + // always look like a promise. It's some other implementation that doesn't present + // itself as a promise. We have no way to detect when it is disconnected. + return kj::NEVER_DONE; + } + }, [](kj::Exception&& e) -> kj::Promise { + // getLocalServer() thrown when the capability is a promise cap that rejects. We can + // use this to implement whenWriteDisconnected(). + // + // (Note that because this exception handler is passed to the .then(), it does NOT catch + // eoxceptions thrown by the success handler immediately above it. This handler will ONLY + // catch exceptions from getLocalServer() itself.) + return kj::READY_NOW; + }); + } + + kj::Promise findShorterPath(StreamServerBase& capnpServer) { + // We found a shorter path back to this process. Record it. + optimized = capnpServer; + + KJ_SWITCH_ONEOF(capnpServer.getShortestPath()) { + KJ_CASE_ONEOF(promise, kj::Promise) { + return promise.then([this,&capnpServer]() { + return findShorterPath(capnpServer); + }); + } + KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { + // The ByteStream::Server wraps a regular KJ stream that does not wrap another capnp + // stream. + if (kjStream.limit < (uint64_t)kj::maxValue / 2) { + // But it isn't wrapping that stream forever. Eventually it plans to redirect back to + // some other stream. So, let's wait for that, and possibly shorten again. + kjStream.lender.returnStream(0); + return KJ_ASSERT_NONNULL(capnpServer.shortenPath()) + .then([this, &capnpServer](auto&&) { + return findShorterPath(capnpServer); + }); + } else { + // This KJ stream is (effectively) the permanent endpoint. We can't get any shorter + // from here. All we want to do now is watch for disconnect. + auto promise = kjStream.stream.whenWriteDisconnected(); + kjStream.lender.returnStream(0); + return promise; + } + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { + return findShorterPath(*capnpStream); + } + } + KJ_UNREACHABLE; + } + + StreamServerBase::ShortestPath getShortestPath() { + KJ_IF_MAYBE(o, optimized) { + return o->getShortestPath(); + } else { + return &inner; + } + } + + kj::Promise pumpLoop(kj::AsyncInputStream& input, + uint64_t completed, uint64_t remaining) { + if (remaining == 0) return completed; + + KJ_SWITCH_ONEOF(getShortestPath()) { + KJ_CASE_ONEOF(promise, kj::Promise) { + return promise.then([this,&input,completed,remaining]() { + return pumpLoop(input,completed,remaining); + }); + } + KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { + // Oh hell yes, this capability actually points back to a stream in our own thread. We can + // stop sending RPCs and just pump directly. + + if (remaining <= kjStream.limit) { + return input.pumpTo(kjStream.stream, remaining) + .then([kjStream,completed](uint64_t actual) { + kjStream.lender.returnStream(actual); + return actual + completed; + }); + } else { + auto promise = input.pumpTo(kjStream.stream, kjStream.limit); + return promise.then([this,&input,completed,remaining,kjStream] + (uint64_t actual) mutable -> kj::Promise { + kjStream.lender.returnStream(actual); + if (actual < kjStream.limit) { + // EOF reached. + return completed + actual; + } else { + return pumpLoop(input, completed + actual, remaining - actual); + } + }); + } + } + KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { + // Pumping from some other kind of stream. Optimize the pump by reading from the input + // directly into outgoing RPC messages. + size_t size = kj::min(remaining, 8192); + auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 }); + + auto orphanage = Orphanage::getForMessageContaining( + capnp::ByteStream::WriteParams::Builder(req)); + + auto buffer = orphanage.newOrphan(size); + + struct WriteRequestAndBuffer { + // The order of construction/destruction of lambda captures is unspecified, but we care + // about ordering between these two things that we want to capture, so... we need a + // struct. + StreamingRequest request; + Orphan buffer; // points into `request`... + }; + + WriteRequestAndBuffer wrab = { kj::mv(req), kj::mv(buffer) }; + + return input.tryRead(wrab.buffer.get().begin(), 1, size) + .then([this, &input, completed, remaining, size, wrab = kj::mv(wrab)] + (size_t actual) mutable -> kj::Promise { + if (actual == 0) { + return completed; + } if (actual < size) { + wrab.buffer.truncate(actual); + } + + wrab.request.adoptBytes(kj::mv(wrab.buffer)); + return wrab.request.send() + .then([this, &input, completed, remaining, actual]() { + return pumpLoop(input, completed + actual, remaining - actual); + }); + }); + } + } + KJ_UNREACHABLE; + } + + template + kj::Promise splitAndWrite(kj::ArrayPtr> pieces, + size_t limit, WritePieces&& writeFirstPieces) { + size_t splitByte = limit; + size_t splitPiece = 0; + while (pieces[splitPiece].size() <= splitByte) { + splitByte -= pieces[splitPiece].size(); + ++splitPiece; + } + + if (splitByte == 0) { + // Oh thank god, the split is between two pieces. + auto rest = pieces.slice(splitPiece, pieces.size()); + return writeFirstPieces(pieces.slice(0, splitPiece)) + .then([this,rest]() mutable { + return write(rest); + }); + } else { + // FUUUUUUUU---- we need to split one of the pieces in two. + auto left = kj::heapArray>(splitPiece + 1); + auto right = kj::heapArray>(pieces.size() - splitPiece); + for (auto i: kj::zeroTo(splitPiece)) { + left[i] = pieces[i]; + } + for (auto i: kj::zeroTo(right.size())) { + right[i] = pieces[splitPiece + i]; + } + left.back() = pieces[splitPiece].slice(0, splitByte); + right.front() = pieces[splitPiece].slice(splitByte, pieces[splitPiece].size()); + + return writeFirstPieces(left).attach(kj::mv(left)) + .then([this,right=kj::mv(right)]() mutable { + return write(right).attach(kj::mv(right)); + }); + } + } +}; + +// ======================================================================================= + +capnp::ByteStream::Client ByteStreamFactory::kjToCapnp(kj::Own kjStream) { + return streamSet.add(kj::heap(*this, kj::mv(kjStream))); +} + +capnp::ByteStream::Client ByteStreamFactory::kjToCapnp( + kj::Own kjStream, kj::Maybe> tlsStarter) { + return streamSet.add( + kj::heap(*this, kj::mv(kjStream), kj::mv(tlsStarter))); +} + +kj::Own ByteStreamFactory::capnpToKj(capnp::ByteStream::Client capnpStream) { + return kj::heap(*this, kj::mv(capnpStream), false); +} + +kj::Own ByteStreamFactory::capnpToKjExplicitEnd( + capnp::ByteStream::Client capnpStream) { + return kj::heap(*this, kj::mv(capnpStream), true); +} + +} // namespace capnp diff --git a/c++/src/capnp/compat/byte-stream.capnp b/c++/src/capnp/compat/byte-stream.capnp new file mode 100644 index 0000000000..3298c06f48 --- /dev/null +++ b/c++/src/capnp/compat/byte-stream.capnp @@ -0,0 +1,45 @@ +@0x8f5d14e1c273738d; + +using Cxx = import "/capnp/c++.capnp"; +$Cxx.namespace("capnp"); +$Cxx.allowCancellation; + +interface ByteStream { + write @0 (bytes :Data) -> stream; + # Write a chunk. + + end @1 (); + # Signals clean EOF. (If the ByteStream is dropped without calling this, then the stream was + # prematurely canceled and so the body should not be considered complete.) + + getSubstream @2 (callback :SubstreamCallback, + limit :UInt64 = 0xffffffffffffffff) -> (substream :ByteStream); + # This method is used to implement path shortening optimization. It is designed in particular + # with KJ streams' pumpTo() in mind. + # + # getSubstream() returns a new stream object that can be used to write to the same destination + # as this stream. The substream will operate until it has received `limit` bytes, or its `end()` + # method has been called, whichever occurs first. At that time, it invokes one of the methods of + # `callback` based on the termination condition. + # + # While a substream is active, it is an error to call write() on the original stream. Doing so + # may throw an exception or may arbitrarily interleave bytes with the substream's writes. + + startTls @3 (expectedServerHostname :Text) -> stream; + # Client calls this method when it wants to initiate TLS. This ByteStream is not terminated, + # the caller should reuse it. + + interface SubstreamCallback { + ended @0 (byteCount :UInt64); + # `end()` was called on the substream after writing `byteCount` bytes. The `end()` call was + # NOT forwarded to the underlying stream, which remains open. + + reachedLimit @1 () -> (next :ByteStream); + # The number of bytes specified by the `limit` parameter of `getSubstream()` was reached. + # The substream will "resolve itself" to `next`, so that all future calls to the substream + # are forwarded to `next`. + # + # If the `write()` call which reached the limit included bytes past the limit, then the first + # `write()` call to `next` will be for those leftover bytes. + } +} diff --git a/c++/src/capnp/compat/byte-stream.h b/c++/src/capnp/compat/byte-stream.h new file mode 100644 index 0000000000..b34aa3f521 --- /dev/null +++ b/c++/src/capnp/compat/byte-stream.h @@ -0,0 +1,75 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once +// Bridges from KJ streams to Cap'n Proto ByteStream RPC protocol. + +#include +#include +#include + +CAPNP_BEGIN_HEADER + +namespace capnp { + +class ExplicitEndOutputStream: public kj::AsyncOutputStream { + // HACK: KJ's AsyncOutputStream has a known serious design flaw in that EOF is signaled by + // destroying the stream object rather than by calling an explicit `end()` method. This causes + // some serious problems when signaling EOF requires doing additional I/O, such as when + // wrapping a capnp ByteStream where `end()` is an RPC call. + // + // When it really must, the ByteStream implementation will honor the KJ convention by starting + // the RPC in its destructor and detach()ing the promise. But, this has lots of negative side + // effects, especially in the case where the stream is really meant to be aborted abruptly. + // + // In lieu of an actual deep refactoring of KJ, ByteStreamFactory allows its caller to + // explicily specify when it is able to promise that it will make an explicit `end()` call. + // capnpToKjExplicitEnd() returns an ExplicitEndOutputStream, which expect to receive an + // `end()` call on clean EOF, and treats destruction without `end()` as an abort. This is used + // in particular within http-over-capnp to improve behavior somewhat. +public: + virtual kj::Promise end() = 0; +}; + +class ByteStreamFactory { + // In order to allow path-shortening through KJ, a common factory must be used for converting + // between RPC ByteStreams and KJ streams. + +public: + capnp::ByteStream::Client kjToCapnp(kj::Own kjStream); + capnp::ByteStream::Client kjToCapnp( + kj::Own kjStream, kj::Maybe> tlsStarter); + kj::Own capnpToKj(capnp::ByteStream::Client capnpStream); + + kj::Own capnpToKjExplicitEnd(capnp::ByteStream::Client capnpStream); + +private: + CapabilityServerSet streamSet; + + class StreamServerBase; + class SubstreamImpl; + class CapnpToKjStreamAdapter; + class KjToCapnpStreamAdapter; +}; + +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compat/http-over-capnp-old-test.c++ b/c++/src/capnp/compat/http-over-capnp-old-test.c++ new file mode 100644 index 0000000000..9a5aea9b13 --- /dev/null +++ b/c++/src/capnp/compat/http-over-capnp-old-test.c++ @@ -0,0 +1,2 @@ +#define TEST_PEER_OPTIMIZATION_LEVEL HttpOverCapnpFactory::LEVEL_1 +#include "http-over-capnp-test.c++" diff --git a/c++/src/capnp/compat/http-over-capnp-perf-test.c++ b/c++/src/capnp/compat/http-over-capnp-perf-test.c++ new file mode 100644 index 0000000000..20ba63840b --- /dev/null +++ b/c++/src/capnp/compat/http-over-capnp-perf-test.c++ @@ -0,0 +1,446 @@ +// Copyright (c) 2022 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "http-over-capnp.h" +#include +#include +#include +#include +#include +#if KJ_BENCHMARK_MALLOC +#include +#endif + +#if KJ_HAS_COROUTINE + +namespace capnp { +namespace { + +// ======================================================================================= +// Metrics-gathering +// +// TODO(cleanup): Generalize for other benchmarks? + +static size_t globalMallocCount = 0; +static size_t globalMallocBytes = 0; + +#if KJ_BENCHMARK_MALLOC +// If KJ_BENCHMARK_MALLOC is defined then we are instructed to override malloc() in order to +// measure total allocations. We are careful only to define this when the build is set up in a +// way that this won't cause build failures (e.g., we must not be statically linking a malloc +// implementation). + +extern "C" { + +void* malloc(size_t size) { + typedef void* Malloc(size_t); + static Malloc* realMalloc = reinterpret_cast(dlsym(RTLD_NEXT, "malloc")); + + ++globalMallocCount; + globalMallocBytes += size; + return realMalloc(size); +} + +} // extern "C" + +#endif // KJ_BENCHMARK_MALLOC + +class Metrics { +public: + Metrics() + : startMallocCount(globalMallocCount), startMallocBytes(globalMallocBytes), + upBandwidth(0), downBandwidth(0), + clientReadCount(0), clientWriteCount(0), + serverReadCount(0), serverWriteCount(0) {} + ~Metrics() noexcept(false) { + #if KJ_BENCHMARK_MALLOC + size_t mallocCount = globalMallocCount - startMallocCount; + size_t mallocBytes = globalMallocBytes - startMallocBytes; + KJ_LOG(WARNING, mallocCount, mallocBytes); + #endif + + if (hadStreamPair) { + KJ_LOG(WARNING, upBandwidth, downBandwidth, + clientReadCount, clientWriteCount, serverReadCount, serverWriteCount); + } + } + + enum Side { CLIENT, SERVER }; + + class StreamWrapper final: public kj::AsyncIoStream { + // Wrap a stream and count metrics. + + public: + StreamWrapper(Metrics& metrics, kj::AsyncIoStream& inner, Side side) + : metrics(metrics), inner(inner), side(side) {} + + ~StreamWrapper() noexcept(false) { + switch (side) { + case CLIENT: + metrics.clientReadCount += readCount; + metrics.clientWriteCount += writeCount; + metrics.upBandwidth += writeBytes; + metrics.downBandwidth += readBytes; + break; + case SERVER: + metrics.serverReadCount += readCount; + metrics.serverWriteCount += writeCount; + break; + } + } + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.read(buffer, minBytes, maxBytes) + .then([this](size_t n) { + ++readCount; + readBytes += n; + return n; + }); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.tryRead(buffer, minBytes, maxBytes) + .then([this](size_t n) { + ++readCount; + readBytes += n; + return n; + }); + } + + kj::Maybe tryGetLength() override { + return inner.tryGetLength(); + } + + kj::Promise write(const void* buffer, size_t size) override { + ++writeCount; + writeBytes += size; + return inner.write(buffer, size); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + ++writeCount; + for (auto& piece: pieces) { + writeBytes += piece.size(); + } + return inner.write(pieces); + } + + kj::Promise pumpTo( + kj::AsyncOutputStream& output, uint64_t amount = kj::maxValue) override { + // Our benchmarks don't depend on this currently. If they do we need to think about how to + // apply it. + KJ_UNIMPLEMENTED("pump metrics"); + } + kj::Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + // Our benchmarks don't depend on this currently. If they do we need to think about how to + // apply it. + KJ_UNIMPLEMENTED("pump metrics"); + } + + kj::Promise whenWriteDisconnected() override { + return inner.whenWriteDisconnected(); + } + + void shutdownWrite() override { + inner.shutdownWrite(); + } + void abortRead() override { + inner.abortRead(); + } + + private: + Metrics& metrics; + kj::AsyncIoStream& inner; + Side side; + + size_t readCount = 0; + size_t readBytes = 0; + size_t writeCount = 0; + size_t writeBytes = 0; + }; + + struct StreamPair { + kj::TwoWayPipe pipe; + StreamWrapper client; + StreamWrapper server; + + StreamPair(Metrics& metrics) + : pipe(kj::newTwoWayPipe()), + client(metrics, *pipe.ends[0], CLIENT), + server(metrics, *pipe.ends[1], SERVER) { + metrics.hadStreamPair = true; + } + }; + +private: + size_t startMallocCount KJ_UNUSED; + size_t startMallocBytes KJ_UNUSED; + size_t upBandwidth; + size_t downBandwidth; + size_t clientReadCount; + size_t clientWriteCount; + size_t serverReadCount; + size_t serverWriteCount; + + bool hadStreamPair = false; +}; + +// ======================================================================================= + +static constexpr auto HELLO_WORLD = "Hello, world!"_kj; + +class NullInputStream final: public kj::AsyncInputStream { +public: + NullInputStream(kj::Maybe expectedLength = size_t(0)) + : expectedLength(expectedLength) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return size_t(0); + } + + kj::Maybe tryGetLength() override { + return expectedLength; + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return uint64_t(0); + } + +private: + kj::Maybe expectedLength; +}; + +class VectorOutputStream: public kj::AsyncOutputStream { +public: + kj::String consume() { + chars.add('\0'); + return kj::String(chars.releaseAsArray()); + } + + kj::Promise write(const void* buffer, size_t size) override { + chars.addAll(kj::arrayPtr(reinterpret_cast(buffer), size)); + return kj::READY_NOW; + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + for (auto piece: pieces) { + chars.addAll(piece.asChars()); + } + return kj::READY_NOW; + } + + kj::Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + +private: + kj::Vector chars; +}; + +class MockService: public kj::HttpService { +public: + MockService(kj::HttpHeaderTable::Builder& headerTableBuilder) + : headerTable(headerTableBuilder.getFutureTable()), + customHeaderId(headerTableBuilder.add("X-Custom-Header")) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_ASSERT(method == kj::HttpMethod::GET); + KJ_ASSERT(url == "http://foo"_kj); + KJ_ASSERT(headers.get(customHeaderId) == "corge"_kj); + + kj::HttpHeaders responseHeaders(headerTable); + responseHeaders.set(kj::HttpHeaderId::CONTENT_TYPE, "text/plain"); + responseHeaders.set(customHeaderId, "foobar"_kj); + auto stream = response.send(200, "OK", responseHeaders); + auto promise = stream->write(HELLO_WORLD.begin(), HELLO_WORLD.size()); + return promise.attach(kj::mv(stream)); + } + +private: + const kj::HttpHeaderTable& headerTable; + kj::HttpHeaderId customHeaderId; +}; + +class MockSender: private kj::HttpService::Response { +public: + MockSender(kj::HttpHeaderTable::Builder& headerTableBuilder) + : headerTable(headerTableBuilder.getFutureTable()), + customHeaderId(headerTableBuilder.add("X-Custom-Header")) {} + + kj::Promise sendRequest(kj::HttpClient& client) { + kj::HttpHeaders headers(headerTable); + headers.set(customHeaderId, "corge"_kj); + auto req = client.request(kj::HttpMethod::GET, "http://foo"_kj, headers); + req.body = nullptr; + auto resp = co_await req.response; + KJ_ASSERT(resp.statusCode == 200); + KJ_ASSERT(resp.statusText == "OK"_kj); + KJ_ASSERT(resp.headers->get(customHeaderId) == "foobar"_kj); + + auto body = co_await resp.body->readAllText(); + KJ_ASSERT(body == HELLO_WORLD); + } + + kj::Promise sendRequest(kj::HttpService& service) { + kj::HttpHeaders headers(headerTable); + headers.set(customHeaderId, "corge"_kj); + NullInputStream requestBody; + co_await service.request(kj::HttpMethod::GET, "http://foo"_kj, headers, requestBody, *this); + KJ_ASSERT(responseBody.consume() == HELLO_WORLD); + } + +private: + const kj::HttpHeaderTable& headerTable; + kj::HttpHeaderId customHeaderId; + + VectorOutputStream responseBody; + + kj::Own send( + uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_ASSERT(statusCode == 200); + KJ_ASSERT(statusText == "OK"_kj); + KJ_ASSERT(headers.get(customHeaderId) == "foobar"_kj); + + return kj::attachRef(responseBody); + } + + kj::Own acceptWebSocket(const kj::HttpHeaders& headers) override { + KJ_UNIMPLEMENTED("no WebSockets here"); + } +}; + +KJ_TEST("Benchmark baseline") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + doBenchmark([&]() { + sender.sendRequest(service).wait(waitScope); + }); +} + +KJ_TEST("Benchmark KJ HTTP client wrapper") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + auto client = kj::newHttpClient(service); + + doBenchmark([&]() { + sender.sendRequest(*client).wait(waitScope); + }); +} + +KJ_TEST("Benchmark KJ HTTP full protocol") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + Metrics::StreamPair pair(metrics); + kj::TimerImpl timer(kj::origin()); + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + kj::HttpServer server(timer, *headerTable, service); + auto listenLoop = server.listenHttp({&pair.server, kj::NullDisposer::instance}) + .eagerlyEvaluate([](kj::Exception&& e) noexcept { kj::throwFatalException(kj::mv(e)); }); + auto client = kj::newHttpClient(*headerTable, pair.client); + + doBenchmark([&]() { + sender.sendRequest(*client).wait(waitScope); + }); +} + +KJ_TEST("Benchmark HTTP-over-capnp local call") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + HttpOverCapnpFactory::HeaderIdBundle headerIds(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + // Client and server use different HttpOverCapnpFactory instances to block path-shortening. + ByteStreamFactory bsFactory; + HttpOverCapnpFactory hocFactory(bsFactory, headerIds.clone(), HttpOverCapnpFactory::LEVEL_2); + ByteStreamFactory bsFactory2; + HttpOverCapnpFactory hocFactory2(bsFactory2, kj::mv(headerIds), HttpOverCapnpFactory::LEVEL_2); + + auto cap = hocFactory.kjToCapnp(kj::attachRef(service)); + auto roundTrip = hocFactory2.capnpToKj(kj::mv(cap)); + + doBenchmark([&]() { + sender.sendRequest(*roundTrip).wait(waitScope); + }); +} + +KJ_TEST("Benchmark HTTP-over-capnp full RPC") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + Metrics metrics; + Metrics::StreamPair pair(metrics); + + kj::HttpHeaderTable::Builder headerTableBuilder; + MockService service(headerTableBuilder); + MockSender sender(headerTableBuilder); + HttpOverCapnpFactory::HeaderIdBundle headerIds(headerTableBuilder); + auto headerTable = headerTableBuilder.build(); + + // Client and server use different HttpOverCapnpFactory instances to block path-shortening. + ByteStreamFactory bsFactory; + HttpOverCapnpFactory hocFactory(bsFactory, headerIds.clone(), HttpOverCapnpFactory::LEVEL_2); + ByteStreamFactory bsFactory2; + HttpOverCapnpFactory hocFactory2(bsFactory2, kj::mv(headerIds), HttpOverCapnpFactory::LEVEL_2); + + TwoPartyServer server(hocFactory.kjToCapnp(kj::attachRef(service))); + + auto pipe = kj::newTwoWayPipe(); + auto listenLoop = server.accept(pair.server); + + TwoPartyClient client(pair.client); + + auto roundTrip = hocFactory2.capnpToKj(client.bootstrap().castAs()); + + doBenchmark([&]() { + sender.sendRequest(*roundTrip).wait(waitScope); + }); +} + +} // namespace +} // namespace capnp + +#endif // KJ_HAS_COROUTINE diff --git a/c++/src/capnp/compat/http-over-capnp-test.c++ b/c++/src/capnp/compat/http-over-capnp-test.c++ new file mode 100644 index 0000000000..425014cabb --- /dev/null +++ b/c++/src/capnp/compat/http-over-capnp-test.c++ @@ -0,0 +1,989 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "http-over-capnp.h" +#include + +#ifndef TEST_PEER_OPTIMIZATION_LEVEL +#define TEST_PEER_OPTIMIZATION_LEVEL HttpOverCapnpFactory::LEVEL_2 +#endif + +namespace capnp { +namespace { + +KJ_TEST("KJ and RPC HTTP method enums match") { +#define EXPECT_MATCH(METHOD) \ + KJ_EXPECT(static_cast(kj::HttpMethod::METHOD) == \ + static_cast(capnp::HttpMethod::METHOD)); + + KJ_HTTP_FOR_EACH_METHOD(EXPECT_MATCH); +#undef EXPECT_MATCH +} + +// ======================================================================================= + +kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { + if (expected.size() == 0) return kj::READY_NOW; + + auto buffer = kj::heapArray(expected.size()); + + auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { + if (amount == 0) { + KJ_FAIL_ASSERT("expected data never sent", expected); + } + + auto actual = buffer.slice(0, amount); + if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) { + KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual); + } + + return expectRead(in, expected.slice(amount)); + }); +} + +enum Direction { + CLIENT_TO_SERVER, + SERVER_TO_CLIENT +}; + +struct TestStep { + Direction direction; + kj::StringPtr send; + kj::StringPtr receive; + + constexpr TestStep(Direction direction, kj::StringPtr send, kj::StringPtr receive) + : direction(direction), send(send), receive(receive) {} + constexpr TestStep(Direction direction, kj::StringPtr data) + : direction(direction), send(data), receive(data) {} +}; + +constexpr TestStep TEST_STEPS[] = { + // Test basic request. + { + CLIENT_TO_SERVER, + + "GET / HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 200 OK\r\n" + "Content-Length: 3\r\n" + "\r\n" + "foo"_kj + }, + + // Try PUT, vary path, vary status + { + CLIENT_TO_SERVER, + + "PUT /foo/bar HTTP/1.1\r\n" + "Content-Length: 5\r\n" + "Host: example.com\r\n" + "\r\n" + "corge"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 403 Unauthorized\r\n" + "Content-Length: 4\r\n" + "\r\n" + "nope"_kj + }, + + // HEAD request + { + CLIENT_TO_SERVER, + + "HEAD /foo/bar HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 200 OK\r\n" + "Content-Length: 4\r\n" + "\r\n"_kj + }, + + // Empty-body response + { + CLIENT_TO_SERVER, + + "GET /foo/bar HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 304 Not Modified\r\n" + "Server: foo\r\n" + "\r\n"_kj + }, + + // Chonky body + { + CLIENT_TO_SERVER, + + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "Host: example.com\r\n" + "\r\n" + "3\r\n" + "foo\r\n" + "5\r\n" + "corge\r\n" + "0\r\n" + "\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "barbaz\r\n" + "6\r\n" + "garply\r\n" + "0\r\n" + "\r\n"_kj + }, + + // Streaming + { + CLIENT_TO_SERVER, + + "POST / HTTP/1.1\r\n" + "Content-Length: 9\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + }, + { + CLIENT_TO_SERVER, + + "foo"_kj, + }, + { + CLIENT_TO_SERVER, + + "bar"_kj, + }, + { + CLIENT_TO_SERVER, + + "baz"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "6\r\n" + "barbaz\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "6\r\n" + "garply\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "0\r\n" + "\r\n"_kj + }, + + // Bidirectional. + { + CLIENT_TO_SERVER, + + "POST / HTTP/1.1\r\n" + "Content-Length: 9\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n"_kj, + }, + { + CLIENT_TO_SERVER, + + "foo"_kj, + }, + { + SERVER_TO_CLIENT, + + "6\r\n" + "barbaz\r\n"_kj, + }, + { + CLIENT_TO_SERVER, + + "bar"_kj, + }, + { + SERVER_TO_CLIENT, + + "6\r\n" + "garply\r\n"_kj, + }, + { + CLIENT_TO_SERVER, + + "baz"_kj, + }, + { + SERVER_TO_CLIENT, + + "0\r\n" + "\r\n"_kj + }, + + // Test headers being re-ordered by KJ. This isn't necessary behavior, but it does prove that + // we're not testing a pure streaming pass-through... + { + CLIENT_TO_SERVER, + + "GET / HTTP/1.1\r\n" + "Host: example.com\r\n" + "Accept: text/html\r\n" + "Foo-Header: 123\r\n" + "User-Agent: kj\r\n" + "Accept-Language: en\r\n" + "\r\n"_kj, + + "GET / HTTP/1.1\r\n" + "Host: example.com\r\n" + "Accept-Language: en\r\n" + "Accept: text/html\r\n" + "User-Agent: kj\r\n" + "Foo-Header: 123\r\n" + "\r\n"_kj + }, + { + SERVER_TO_CLIENT, + + "HTTP/1.1 200 OK\r\n" + "Server: kj\r\n" + "Bar: 321\r\n" + "Content-Length: 3\r\n" + "\r\n" + "foo"_kj, + + "HTTP/1.1 200 OK\r\n" + "Content-Length: 3\r\n" + "Server: kj\r\n" + "Bar: 321\r\n" + "\r\n" + "foo"_kj + }, + + // We finish up a request with no response, to test cancellation. + { + CLIENT_TO_SERVER, + + "GET / HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + }, +}; + +class OneConnectNetworkAddress final: public kj::NetworkAddress { +public: + OneConnectNetworkAddress(kj::Own stream) + : stream(kj::mv(stream)) {} + + kj::Promise> connect() override { + auto result = KJ_ASSERT_NONNULL(kj::mv(stream)); + stream = nullptr; + return kj::mv(result); + } + + kj::Own listen() override { KJ_UNIMPLEMENTED("test"); } + kj::Own clone() override { KJ_UNIMPLEMENTED("test"); } + kj::String toString() override { KJ_UNIMPLEMENTED("test"); } + +private: + kj::Maybe> stream; +}; + +void runEndToEndTests(kj::Timer& timer, kj::HttpHeaderTable& headerTable, + HttpOverCapnpFactory& clientFactory, HttpOverCapnpFactory& serverFactory, + kj::WaitScope& waitScope) { + auto clientPipe = kj::newTwoWayPipe(); + auto serverPipe = kj::newTwoWayPipe(); + + OneConnectNetworkAddress oneConnectAddr(kj::mv(serverPipe.ends[0])); + + auto backHttp = kj::newHttpClient(timer, headerTable, oneConnectAddr); + auto backCapnp = serverFactory.kjToCapnp(kj::newHttpService(*backHttp)); + auto frontCapnp = clientFactory.capnpToKj(backCapnp); + kj::HttpServer frontKj(timer, headerTable, *frontCapnp); + auto listenTask = frontKj.listenHttp(kj::mv(clientPipe.ends[1])) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + for (auto& step: TEST_STEPS) { + KJ_CONTEXT(step.send); + + kj::AsyncOutputStream* out; + kj::AsyncInputStream* in; + + switch (step.direction) { + case CLIENT_TO_SERVER: + out = clientPipe.ends[0]; + in = serverPipe.ends[1]; + break; + case SERVER_TO_CLIENT: + out = serverPipe.ends[1]; + in = clientPipe.ends[0]; + break; + } + + auto writePromise = out->write(step.send.begin(), step.send.size()); + auto readPromise = expectRead(*in, step.receive); + if (!writePromise.poll(waitScope)) { + if (readPromise.poll(waitScope)) { + readPromise.wait(waitScope); + KJ_FAIL_ASSERT("write hung, read worked fine"); + } else { + KJ_FAIL_ASSERT("write and read both hung"); + } + } + + writePromise.wait(waitScope); + KJ_ASSERT(readPromise.poll(waitScope), "read hung"); + readPromise.wait(waitScope); + } + + // The last test message was a request with no response. If we now close the client end, this + // should propagate all the way through to close the server end! + clientPipe.ends[0] = nullptr; + auto lastRead = serverPipe.ends[1]->readAllText(); + KJ_ASSERT(lastRead.poll(waitScope), "last read hung"); + KJ_EXPECT(lastRead.wait(waitScope) == nullptr); +} + +KJ_TEST("HTTP-over-Cap'n-Proto E2E, no path shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory1; + ByteStreamFactory streamFactory2; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory1(streamFactory1, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + HttpOverCapnpFactory factory2(streamFactory2, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + auto headerTable = tableBuilder.build(); + + runEndToEndTests(timer, *headerTable, factory1, factory2, waitScope); +} + +KJ_TEST("HTTP-over-Cap'n-Proto E2E, with path shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + auto headerTable = tableBuilder.build(); + + runEndToEndTests(timer, *headerTable, factory, factory, waitScope); +} + +KJ_TEST("HTTP-over-Cap'n-Proto 205 bug with HttpClientAdapter") { + // Test that a 205 with a hanging body doesn't prevent headers from being delivered. (This was + // a bug at one point. See, 205 responses are supposed to have empty bodies. But they must + // explicitly indicate an empty body. http-over-capnp, though, *assumed* an empty body when it + // saw a 205. But, on the client side, when HttpClientAdapter sees an empty body, it blocks + // delivery of the *headers* until the service promise resolves, in order to avoid prematurely + // cancelling the service. But on the server side, the service method is left hanging because + // it's waiting for the 205 to actually produce its empty body. If that didn't make any sense, + // consider yourself lucky.) + + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + auto headerTable = tableBuilder.build(); + + auto pipe = kj::newTwoWayPipe(); + + OneConnectNetworkAddress oneConnectAddr(kj::mv(pipe.ends[0])); + + auto backHttp = kj::newHttpClient(timer, *headerTable, oneConnectAddr); + auto backCapnp = factory.kjToCapnp(kj::newHttpService(*backHttp)); + auto frontCapnp = factory.capnpToKj(backCapnp); + + auto frontClient = kj::newHttpClient(*frontCapnp); + + auto req = frontClient->request(kj::HttpMethod::GET, "/", kj::HttpHeaders(*headerTable)); + + { + auto readPromise = expectRead(*pipe.ends[1], "GET / HTTP/1.1\r\n\r\n"); + KJ_ASSERT(readPromise.poll(waitScope)); + readPromise.wait(waitScope); + } + + KJ_EXPECT(!req.response.poll(waitScope)); + + { + // A 205 response with no content-length or transfer-encoding is terminated by EOF (but also + // the body is required to be empty). We don't send the EOF yet, just the response line and + // empty headers. + kj::StringPtr resp = "HTTP/1.1 205 Reset Content\r\n\r\n"; + pipe.ends[1]->write(resp.begin(), resp.size()).wait(waitScope); + } + + // On the client end, we should get a response now! + KJ_ASSERT(req.response.poll(waitScope)); + + auto resp = req.response.wait(waitScope); + KJ_EXPECT(resp.statusCode == 205); + + // But the body is still blocked. + auto promise = resp.body->readAllText(); + KJ_EXPECT(!promise.poll(waitScope)); + + // OK now send the EOF it's waiting for. + pipe.ends[1]->shutdownWrite(); + + // And now the body is unblocked. + KJ_ASSERT(promise.poll(waitScope)); + KJ_EXPECT(promise.wait(waitScope) == ""); +} + +// ======================================================================================= + +class WebSocketAccepter final: public kj::HttpService { +public: + WebSocketAccepter(kj::HttpHeaderTable& headerTable, + kj::Own>> fulfiller, + kj::Promise done) + : headerTable(headerTable), fulfiller(kj::mv(fulfiller)), done(kj::mv(done)) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) { + kj::HttpHeaders respHeaders(headerTable); + respHeaders.add("X-Foo", "bar"); + fulfiller->fulfill(response.acceptWebSocket(respHeaders)); + return kj::mv(done); + } + +private: + kj::HttpHeaderTable& headerTable; + kj::Own>> fulfiller; + kj::Promise done; +}; + +void runWebSocketBasicTestCase( + kj::WebSocket& clientWs, kj::WebSocket& serverWs, kj::WaitScope& waitScope) { + // Called by `runWebSocketTests()`. + + { + auto promise = clientWs.send("foo"_kj); + auto message = serverWs.receive().wait(waitScope); + promise.wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "foo"); + } + + { + auto promise = serverWs.send("bar"_kj.asBytes()); + auto message = clientWs.receive().wait(waitScope); + promise.wait(waitScope); + KJ_ASSERT(message.is>()); + KJ_EXPECT(kj::str(message.get>().asChars()) == "bar"); + } + + { + auto promise = clientWs.close(1234, "baz"_kj); + auto message = serverWs.receive().wait(waitScope); + promise.wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 1234); + KJ_EXPECT(message.get().reason == "baz"); + } + + { + auto promise = serverWs.disconnect(); + auto receivePromise = clientWs.receive(); + KJ_EXPECT(receivePromise.poll(waitScope)); + KJ_EXPECT_THROW(DISCONNECTED, receivePromise.wait(waitScope)); + promise.wait(waitScope); + } +} + +void runWebSocketAbortTestCase( + kj::WebSocket& clientWs, kj::WebSocket& serverWs, kj::WaitScope& waitScope) { + auto onAbort = serverWs.whenAborted(); + KJ_EXPECT(!onAbort.poll(waitScope)); + clientWs.abort(); + + // At one time, this promise hung forever. + KJ_EXPECT(onAbort.poll(waitScope)); + onAbort.wait(waitScope); +} + +void runWebSocketTests(kj::HttpHeaderTable& headerTable, + HttpOverCapnpFactory& clientFactory, HttpOverCapnpFactory& serverFactory, + kj::WaitScope& waitScope) { + // We take a different approach here, because writing out raw WebSocket frames is a pain. + // It's easier to test WebSockets at the KJ API level. + + for (auto testCase: { + runWebSocketBasicTestCase, + runWebSocketAbortTestCase, + }) { + auto wsPaf = kj::newPromiseAndFulfiller>(); + auto donePaf = kj::newPromiseAndFulfiller(); + + auto back = serverFactory.kjToCapnp(kj::heap( + headerTable, kj::mv(wsPaf.fulfiller), kj::mv(donePaf.promise))); + auto front = clientFactory.capnpToKj(back); + auto client = kj::newHttpClient(*front); + + auto resp = client->openWebSocket("/ws", kj::HttpHeaders(headerTable)).wait(waitScope); + KJ_ASSERT(resp.webSocketOrBody.is>()); + + auto clientWs = kj::mv(resp.webSocketOrBody.get>()); + auto serverWs = wsPaf.promise.wait(waitScope); + + testCase(*clientWs, *serverWs, waitScope); + } +} + +KJ_TEST("HTTP-over-Cap'n Proto WebSocket, no path shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory streamFactory1; + ByteStreamFactory streamFactory2; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory1(streamFactory1, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + HttpOverCapnpFactory factory2(streamFactory2, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + auto headerTable = tableBuilder.build(); + + runWebSocketTests(*headerTable, factory1, factory2, waitScope); +} + +KJ_TEST("HTTP-over-Cap'n Proto WebSocket, with path shortening") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + auto headerTable = tableBuilder.build(); + + runWebSocketTests(*headerTable, factory, factory, waitScope); +} + +// ======================================================================================= +// bug fixes + +class HangingHttpService final: public kj::HttpService { +public: + HangingHttpService(bool& called, bool& destroyed) + : called(called), destroyed(destroyed) {} + ~HangingHttpService() noexcept(false) { + destroyed = true; + } + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) { + called = true; + return kj::NEVER_DONE; + } + +private: + bool& called; + bool& destroyed; +}; + +KJ_TEST("HttpService isn't destroyed while call outstanding") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); + auto headerTable = tableBuilder.build(); + + bool called = false; + bool destroyed = false; + auto service = factory.kjToCapnp(kj::heap(called, destroyed)); + + KJ_EXPECT(!called); + KJ_EXPECT(!destroyed); + + auto req = service.startRequestRequest(); + auto httpReq = req.initRequest(); + httpReq.setMethod(capnp::HttpMethod::GET); + httpReq.setUrl("/"); + auto serverContext = req.send().wait(waitScope).getContext(); + service = nullptr; + + auto promise = serverContext.whenResolved(); + KJ_EXPECT(!promise.poll(waitScope)); + + KJ_EXPECT(called); + KJ_EXPECT(!destroyed); +} + + +class ConnectWriteCloseService final: public kj::HttpService { + // A simple CONNECT server that will accept a connection, write some data and close the + // connection. +public: + ConnectWriteCloseService(kj::HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& io, + kj::HttpService::ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", kj::HttpHeaders(headerTable)); + return io.write("test", 4).then([&io]() mutable { + io.shutdownWrite(); + }); + } + +private: + kj::HttpHeaderTable& headerTable; +}; + +class ConnectWriteRespService final: public kj::HttpService { +public: + ConnectWriteRespService(kj::HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& io, + kj::HttpService::ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", kj::HttpHeaders(headerTable)); + // TODO(later): `io.pumpTo(io).ignoreResult;` doesn't work here, + // it causes startTls to come back in a loop. The below avoids this. + auto buffer = kj::heapArray(4096); + return manualPumpLoop(buffer, io).attach(kj::mv(buffer)); + } + + kj::Promise manualPumpLoop(kj::ArrayPtr buffer, kj::AsyncIoStream& io) { + return io.tryRead(buffer.begin(), 1, buffer.size()).then( + [this,&io,buffer](size_t amount) mutable -> kj::Promise { + if (amount == 0) { return kj::READY_NOW; } + return io.write(buffer.begin(), amount).then([this,&io,buffer]() mutable -> kj::Promise { + return manualPumpLoop(buffer, io); + }); + }); + } + +private: + kj::HttpHeaderTable& headerTable; +}; + +class ConnectRejectService final: public kj::HttpService { + // A simple CONNECT server that will reject a connection. +public: + ConnectRejectService(kj::HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& io, + kj::HttpService::ConnectResponse& response, + kj::HttpConnectSettings settings) override { + auto body = response.reject(500, "Internal Server Error", kj::HttpHeaders(headerTable), 5); + return body->write("Error", 5).attach(kj::mv(body)); + } + +private: + kj::HttpHeaderTable& headerTable; +}; + +KJ_TEST("HTTP-over-Cap'n-Proto Connect with close") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + auto pipe = kj::newTwoWayPipe(); + + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder); + kj::Own table = tableBuilder.build(); + ConnectWriteCloseService service(*table); + kj::HttpServer server(timer, *table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(*table, *pipe.ends[1]); + + capnp::HttpService::Client httpService = factory.kjToCapnp(newHttpService(*client)); + auto frontCapnpHttpService = factory.capnpToKj(httpService); + + struct ResponseImpl final: public kj::HttpService::ConnectResponse { + kj::Own> fulfiller; + ResponseImpl(kj::Own> fulfiller) + : fulfiller(kj::mv(fulfiller)) {} + void accept(uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers) override { + KJ_REQUIRE(statusCode >= 200 && statusCode < 300, "the statusCode must be 2xx for accept"); + fulfiller->fulfill( + kj::HttpClient::ConnectRequest::Status( + statusCode, + kj::str(statusText), + kj::heap(headers.clone()), + nullptr + ) + ); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + KJ_UNREACHABLE; + } + }; + + auto clientPipe = kj::newTwoWayPipe(); + auto paf = kj::newPromiseAndFulfiller(); + ResponseImpl response(kj::mv(paf.fulfiller)); + + auto promise = frontCapnpHttpService->connect( + "https://example.org"_kj, kj::HttpHeaders(*table), *clientPipe.ends[0], + response, {}).attach(kj::mv(clientPipe.ends[0])); + + paf.promise.then( + [io = kj::mv(clientPipe.ends[1])](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + auto buf = kj::heapArray(4); + return io->tryRead(buf.begin(), 4, 4).then( + [buf = kj::mv(buf), io = kj::mv(io)](size_t count) mutable { + KJ_ASSERT(count == 4, "Expecting the stream to read 4 chars."); + return io->tryRead(buf.begin(), 1, 1).then( + [buf = kj::mv(buf)](size_t count) mutable { + KJ_ASSERT(count == 0, "Expecting the stream to get disconnected."); + }).attach(kj::mv(io)); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); +} + + +KJ_TEST("HTTP-over-Cap'n-Proto Connect Reject") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + auto pipe = kj::newTwoWayPipe(); + + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder); + kj::Own table = tableBuilder.build(); + ConnectRejectService service(*table); + kj::HttpServer server(timer, *table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(*table, *pipe.ends[1]); + + capnp::HttpService::Client httpService = factory.kjToCapnp(newHttpService(*client)); + auto frontCapnpHttpService = factory.capnpToKj(httpService); + + struct ResponseImpl final: public kj::HttpService::ConnectResponse { + kj::Own>> fulfiller; + ResponseImpl(kj::Own>> fulfiller) + : fulfiller(kj::mv(fulfiller)) {} + void accept(uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers) override { + KJ_UNREACHABLE; + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + KJ_ASSERT(statusCode == 500); + KJ_ASSERT(statusText == "Internal Server Error"); + KJ_ASSERT(expectedBodySize.orDefault(5)); + auto pipe = kj::newOneWayPipe(); + fulfiller->fulfill(kj::mv(pipe.in)); + return kj::mv(pipe.out); + } + }; + + auto clientPipe = kj::newTwoWayPipe(); + auto paf = kj::newPromiseAndFulfiller>(); + ResponseImpl response(kj::mv(paf.fulfiller)); + + auto promise = frontCapnpHttpService->connect( + "https://example.org"_kj, kj::HttpHeaders(*table), *clientPipe.ends[0], + response, {}).attach(kj::mv(clientPipe.ends[0])); + + paf.promise.then( + [](auto body) mutable { + auto buf = kj::heapArray(5); + return body->tryRead(buf.begin(), 5, 5).then( + [buf = kj::mv(buf), body = kj::mv(body)](size_t count) mutable { + KJ_ASSERT(count == 5, "Expecting the stream to read 5 chars."); + }); + }).attach(kj::mv(promise)).wait(waitScope); + + listenTask.wait(waitScope); +} + +kj::Promise expectEnd(kj::AsyncInputStream& in) { + static char buffer; + + auto promise = in.tryRead(&buffer, 1, 1); + return promise.then([](size_t amount) { + KJ_ASSERT(amount == 0, "expected EOF"); + }); +} + +KJ_TEST("HTTP-over-Cap'n-Proto Connect with startTls") { + kj::EventLoop eventLoop; + kj::WaitScope waitScope(eventLoop); + + auto pipe = kj::newTwoWayPipe(); + + kj::TimerImpl timer(kj::origin()); + + ByteStreamFactory streamFactory; + kj::HttpHeaderTable::Builder tableBuilder; + HttpOverCapnpFactory factory(streamFactory, tableBuilder); + kj::Own table = tableBuilder.build(); + ConnectWriteRespService service(*table); + kj::HttpServer server(timer, *table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(*table, *pipe.ends[1]); + + class WrapperHttpClient final: public kj::HttpClient { + public: + kj::HttpClient& inner; + + WrapperHttpClient(kj::HttpClient& client) : inner(client) {}; + + kj::Promise openWebSocket( + kj::StringPtr url, const kj::HttpHeaders& headers) override { KJ_UNREACHABLE; } + Request request(kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { KJ_UNREACHABLE; } + + ConnectRequest connect(kj::StringPtr host, const kj::HttpHeaders& headers, + kj::HttpConnectSettings settings) override { + KJ_IF_MAYBE(starter, settings.tlsStarter) { + *starter = [](kj::StringPtr) { + return kj::READY_NOW; + }; + } + + return inner.connect(host, headers, settings); + } + }; + + // Only need this wrapper to define a dummy tlsStarter. + auto wrappedClient = kj::heap(*client); + capnp::HttpService::Client httpService = factory.kjToCapnp(newHttpService(*wrappedClient)); + auto frontCapnpHttpService = factory.capnpToKj(httpService); + + auto frontCapnpHttpClient = kj::newHttpClient(*frontCapnpHttpService); + + kj::Own tlsStarter = kj::heap(); + kj::HttpConnectSettings settings = { .useTls = false }; + settings.tlsStarter = tlsStarter; + + auto request = frontCapnpHttpClient->connect( + "https://example.org"_kj, kj::HttpHeaders(*table), settings); + + KJ_ASSERT_NONNULL(*tlsStarter); + + request.status.then( + [io=kj::mv(request.connection), &tlsStarter](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + return KJ_ASSERT_NONNULL(*tlsStarter)("example.com").then([io = kj::mv(io)]() mutable { + return io->write("hello", 5).then([io = kj::mv(io)]() mutable { + auto buffer = kj::heapArray(5); + return io->tryRead(buffer.begin(), 5, 5).then( + [io = kj::mv(io), buffer = kj::mv(buffer)](size_t) mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }); + }); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); +} + +} // namespace +} // namespace capnp diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ new file mode 100644 index 0000000000..a92e7c59d1 --- /dev/null +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -0,0 +1,1078 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "http-over-capnp.h" +#include +#include +#include + +namespace capnp { + +using kj::uint; +using kj::byte; + +class HttpOverCapnpFactory::RequestState final + : public kj::Refcounted, public kj::TaskSet::ErrorHandler { +public: + RequestState() { + tasks.emplace(*this); + } + + template + auto wrap(Func&& func) -> decltype(func()) { + if (tasks == nullptr) { + return KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request"); + } else { + return canceler.wrap(func()); + } + } + + void cancel() { + if (tasks != nullptr) { + if (!canceler.isEmpty()) { + canceler.cancel(KJ_EXCEPTION(DISCONNECTED, "request canceled")); + } + tasks = nullptr; + webSocket = nullptr; + } + } + + void assertNotCanceled() { + if (tasks == nullptr) { + kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request")); + } + } + + void addTask(kj::Promise task) { + KJ_IF_MAYBE(t, tasks) { + t->add(kj::mv(task)); + } else { + // Just drop the task. + } + } + + kj::Promise finishTasks() { + // This is merged into the final promise, so we don't need to worry about wrapping it for + // cancellation. + return KJ_REQUIRE_NONNULL(tasks).onEmpty() + .then([this]() { + KJ_IF_MAYBE(e, error) { + kj::throwRecoverableException(kj::mv(*e)); + } + }); + } + + void taskFailed(kj::Exception&& exception) override { + if (error == nullptr) { + error = kj::mv(exception); + } + } + + void holdWebSocket(kj::Own webSocket) { + // Hold on to this WebSocket until cancellation. + KJ_REQUIRE(this->webSocket == nullptr); + KJ_REQUIRE(tasks != nullptr); + this->webSocket = kj::mv(webSocket); + } + + void disconnectWebSocket() { + KJ_IF_MAYBE(t, tasks) { + t->add(kj::evalNow([&]() { return KJ_ASSERT_NONNULL(webSocket)->disconnect(); })); + } + } + +private: + kj::Maybe error; + kj::Maybe> webSocket; + kj::Canceler canceler; + kj::Maybe tasks; +}; + +// ======================================================================================= + +class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server { +public: + CapnpToKjWebSocketAdapter(kj::Own state, kj::WebSocket& webSocket, + kj::Promise shorteningPromise) + : state(kj::mv(state)), webSocket(webSocket), + shorteningPromise(kj::mv(shorteningPromise)) {} + + ~CapnpToKjWebSocketAdapter() noexcept(false) { + state->disconnectWebSocket(); + } + + kj::Maybe> shortenPath() override { + auto onAbort = webSocket.whenAborted() + .then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "WebSocket was aborted"); + }); + return shorteningPromise.exclusiveJoin(kj::mv(onAbort)); + } + + kj::Promise sendText(SendTextContext context) override { + return state->wrap([&]() { return webSocket.send(context.getParams().getText()); }); + } + kj::Promise sendData(SendDataContext context) override { + return state->wrap([&]() { return webSocket.send(context.getParams().getData()); }); + } + kj::Promise close(CloseContext context) override { + auto params = context.getParams(); + return state->wrap([&]() { return webSocket.close(params.getCode(), params.getReason()); }); + } + +private: + kj::Own state; + kj::WebSocket& webSocket; + kj::Promise shorteningPromise; +}; + +class HttpOverCapnpFactory::KjToCapnpWebSocketAdapter final: public kj::WebSocket { +public: + KjToCapnpWebSocketAdapter( + kj::Maybe> in, capnp::WebSocket::Client out, + kj::Own>> shorteningFulfiller) + : in(kj::mv(in)), out(kj::mv(out)), shorteningFulfiller(kj::mv(shorteningFulfiller)) {} + ~KjToCapnpWebSocketAdapter() noexcept(false) { + if (shorteningFulfiller->isWaiting()) { + // We want to make sure the fulfiller is not rejected with a bogus "PromiseFulfiller + // destroyed" error, so fulfill it with never-done. + shorteningFulfiller->fulfill(kj::NEVER_DONE); + } + } + + kj::Promise send(kj::ArrayPtr message) override { + auto req = KJ_REQUIRE_NONNULL(out, "already called disconnect()").sendDataRequest( + MessageSize { 8 + message.size() / sizeof(word), 0 }); + req.setData(message); + sentBytes += message.size(); + return req.send(); + } + + kj::Promise send(kj::ArrayPtr message) override { + auto req = KJ_REQUIRE_NONNULL(out, "already called disconnect()").sendTextRequest( + MessageSize { 8 + message.size() / sizeof(word), 0 }); + memcpy(req.initText(message.size()).begin(), message.begin(), message.size()); + sentBytes += message.size(); + return req.send(); + } + + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + auto req = KJ_REQUIRE_NONNULL(out, "already called disconnect()").closeRequest(); + req.setCode(code); + req.setReason(reason); + sentBytes += reason.size() + 2; + return req.send().ignoreResult(); + } + + kj::Promise disconnect() override { + out = nullptr; + return kj::READY_NOW; + } + + void abort() override { + KJ_ASSERT_NONNULL(in)->abort(); + } + + kj::Promise whenAborted() override { + return KJ_ASSERT_NONNULL(out).whenResolved() + .then([]() -> kj::Promise { + // It would seem this capability resolved to an implementation of the WebSocket RPC interface + // that does not support further path-shortening (so, it's not the implementation found in + // this file). Since the path-shortening facility is also how we discover disconnects, we + // apparently have no way to be alerted on disconnect. We have to assume the other end + // never aborts. + return kj::NEVER_DONE; + }, [](kj::Exception&& e) -> kj::Promise { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + // Looks like we were aborted! + return kj::READY_NOW; + } else { + // Some other error... propagate it. + return kj::mv(e); + } + }); + } + + kj::Promise receive(size_t maxSize) override { + return KJ_ASSERT_NONNULL(in)->receive(maxSize); + } + + kj::Promise pumpTo(WebSocket& other) override { + KJ_IF_MAYBE(optimized, kj::dynamicDowncastIfAvailable(other)) { + shorteningFulfiller->fulfill( + kj::cp(KJ_REQUIRE_NONNULL(optimized->out, "already called disconnect()"))); + + // We expect the `in` pipe will stop receiving messages after the redirect, but we need to + // pump anything already in-flight. + return KJ_ASSERT_NONNULL(in)->pumpTo(other); + } else KJ_IF_MAYBE(promise, other.tryPumpFrom(*this)) { + // We may have unwrapped some layers around `other` leading to a shorter path. + return kj::mv(*promise); + } else { + return KJ_ASSERT_NONNULL(in)->pumpTo(other); + } + } + + uint64_t sentByteCount() override { return sentBytes; } + uint64_t receivedByteCount() override { return KJ_ASSERT_NONNULL(in)->receivedByteCount(); } + +private: + kj::Maybe> in; // One end of a WebSocketPipe, used only for receiving. + kj::Maybe out; // Used only for sending. + kj::Own>> shorteningFulfiller; + uint64_t sentBytes = 0; +}; + +// ======================================================================================= + +class HttpOverCapnpFactory::ClientRequestContextImpl final + : public capnp::HttpService::ClientRequestContext::Server { +public: + ClientRequestContextImpl(HttpOverCapnpFactory& factory, + kj::Own state, + kj::HttpService::Response& kjResponse) + : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {} + + ~ClientRequestContextImpl() noexcept(false) { + // Note this implicitly cancels the upstream pump task. + } + + kj::Promise startResponse(StartResponseContext context) override { + KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); + sent = true; + state->assertNotCanceled(); + + auto params = context.getParams(); + auto rpcResponse = params.getResponse(); + + auto bodySize = rpcResponse.getBodySize(); + kj::Maybe expectedSize; + bool hasBody = true; + if (bodySize.isFixed()) { + auto size = bodySize.getFixed(); + expectedSize = bodySize.getFixed(); + hasBody = size > 0; + } + + auto bodyStream = kjResponse.send(rpcResponse.getStatusCode(), rpcResponse.getStatusText(), + factory.headersToKj(rpcResponse.getHeaders()), expectedSize); + + auto results = context.getResults(MessageSize { 16, 1 }); + if (hasBody) { + auto pipe = kj::newOneWayPipe(); + results.setBody(factory.streamFactory.kjToCapnp(kj::mv(pipe.out))); + state->addTask(pipe.in->pumpTo(*bodyStream) + .ignoreResult() + .attach(kj::mv(bodyStream), kj::mv(pipe.in))); + } + return kj::READY_NOW; + } + + kj::Promise startWebSocket(StartWebSocketContext context) override { + KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); + sent = true; + state->assertNotCanceled(); + + auto params = context.getParams(); + + auto shorteningPaf = kj::newPromiseAndFulfiller>(); + + auto ownWebSocket = kjResponse.acceptWebSocket(factory.headersToKj(params.getHeaders())); + auto& webSocket = *ownWebSocket; + state->holdWebSocket(kj::mv(ownWebSocket)); + + auto upWrapper = kj::heap( + nullptr, params.getUpSocket(), kj::mv(shorteningPaf.fulfiller)); + state->addTask(webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) + .catch_([&webSocket=webSocket](kj::Exception&& e) -> kj::Promise { + // The pump in the client -> server direction failed. The error may have originated from + // either the client or the server. In case it came from the server, we want to call .abort() + // to propagate the problem back to the client. If the error came from the client, then + // .abort() probably is a noop. + webSocket.abort(); + return kj::mv(e); + })); + + auto results = context.getResults(MessageSize { 16, 1 }); + results.setDownSocket(kj::heap( + kj::addRef(*state), webSocket, kj::mv(shorteningPaf.promise))); + + return kj::READY_NOW; + } + +private: + HttpOverCapnpFactory& factory; + kj::Own state; + bool sent = false; + + kj::HttpService::Response& kjResponse; + // Must check state->assertNotCanceled() before using this. +}; + +class HttpOverCapnpFactory::ConnectClientRequestContextImpl final + : public capnp::HttpService::ConnectClientRequestContext::Server { +public: + ConnectClientRequestContextImpl(HttpOverCapnpFactory& factory, + kj::HttpService::ConnectResponse& connResponse) + : factory(factory), connResponse(connResponse) {} + + kj::Promise startConnect(StartConnectContext context) override { + KJ_REQUIRE(!sent, "already called startConnect() or startError()"); + sent = true; + + auto params = context.getParams(); + auto resp = params.getResponse(); + + auto headers = factory.headersToKj(resp.getHeaders()); + connResponse.accept(resp.getStatusCode(), resp.getStatusText(), headers); + + return kj::READY_NOW; + } + + kj::Promise startError(StartErrorContext context) override { + KJ_REQUIRE(!sent, "already called startConnect() or startError()"); + sent = true; + + auto params = context.getParams(); + auto resp = params.getResponse(); + + auto headers = factory.headersToKj(resp.getHeaders()); + + auto bodySize = resp.getBodySize(); + kj::Maybe expectedSize; + if (bodySize.isFixed()) { + expectedSize = bodySize.getFixed(); + } + + auto stream = connResponse.reject( + resp.getStatusCode(), resp.getStatusText(), headers, expectedSize); + + context.initResults().setBody(factory.streamFactory.kjToCapnp(kj::mv(stream))); + + return kj::READY_NOW; + } + +private: + HttpOverCapnpFactory& factory; + bool sent = false; + + kj::HttpService::ConnectResponse& connResponse; +}; + +class HttpOverCapnpFactory::KjToCapnpHttpServiceAdapter final: public kj::HttpService { +public: + KjToCapnpHttpServiceAdapter(HttpOverCapnpFactory& factory, capnp::HttpService::Client inner) + : factory(factory), inner(kj::mv(inner)) {} + + template + kj::Promise requestImpl( + Request rpcRequest, + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse, + AwaitCompletionFunc&& awaitCompletion) { + // Common implementation calling request() or startRequest(). awaitCompletion() waits for + // final completion in a method-specific way. + // + // TODO(cleanup): When we move to C++17 or newer we can use `if constexpr` instead of a + // callback. + + auto metadata = rpcRequest.initRequest(); + metadata.setMethod(static_cast(method)); + metadata.setUrl(url); + metadata.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining(metadata))); + + kj::Maybe maybeRequestBody; + + KJ_IF_MAYBE(s, requestBody.tryGetLength()) { + metadata.getBodySize().setFixed(*s); + if (*s == 0) { + maybeRequestBody = nullptr; + } else { + maybeRequestBody = requestBody; + } + } else if ((method == kj::HttpMethod::GET || method == kj::HttpMethod::HEAD) && + headers.get(kj::HttpHeaderId::TRANSFER_ENCODING) == nullptr) { + maybeRequestBody = nullptr; + metadata.getBodySize().setFixed(0); + } else { + metadata.getBodySize().setUnknown(); + maybeRequestBody = requestBody; + } + + auto state = kj::refcounted(); + auto deferredCancel = kj::defer([state = kj::addRef(*state)]() mutable { + state->cancel(); + }); + + rpcRequest.setContext( + kj::heap(factory, kj::addRef(*state), kjResponse)); + + auto pipeline = rpcRequest.send(); + + // Pump upstream -- unless we don't expect a request body. + kj::Maybe> pumpRequestTask; + KJ_IF_MAYBE(rb, maybeRequestBody) { + auto bodyOut = factory.streamFactory.capnpToKjExplicitEnd(pipeline.getRequestBody()); + pumpRequestTask = rb->pumpTo(*bodyOut) + .then([&bodyOut = *bodyOut](uint64_t) mutable { + return bodyOut.end(); + }).eagerlyEvaluate([state = kj::addRef(*state), bodyOut = kj::mv(bodyOut)] + (kj::Exception&& e) mutable { + // A DISCONNECTED exception probably means the server decided not to read the whole request + // before responding. In that case we simply want the pump to end, so that on this end it + // also appears that the service simply didn't read everything. So we don't propagate the + // exception in that case. For any other exception, we want to merge the exception with + // the final result. + if (e.getType() != kj::Exception::Type::DISCONNECTED) { + state->taskFailed(kj::mv(e)); + } + }); + } + + // Wait for the server to indicate completion. Meanwhile, if the + // promise is canceled from the client side, we propagate cancellation naturally, and we + // also call state->cancel(). + return awaitCompletion(pipeline) + // Once the server indicates it is done, then we can cancel pumping the request, because + // obviously the server won't use it. We should not cancel pumping the response since there + // could be data in-flight still. + .attach(kj::mv(pumpRequestTask)) + // finishTasks() will wait for the respones to complete. + .then([state = kj::mv(state)]() mutable { return state->finishTasks(); }) + .attach(kj::mv(deferredCancel)); + } + + kj::Promise request( + kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, + kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { + if (factory.peerOptimizationLevel < LEVEL_2) { + return requestImpl(inner.startRequestRequest(), method, url, headers, requestBody, kjResponse, + [](auto& pipeline) { return pipeline.getContext().whenResolved(); }); + } else { + return requestImpl(inner.requestRequest(), method, url, headers, requestBody, kjResponse, + [](auto& pipeline) { return pipeline.ignoreResult(); }); + } + } + + kj::Promise connect( + kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& connection, + ConnectResponse& tunnel, kj::HttpConnectSettings settings) override { + auto rpcRequest = inner.connectRequest(); + auto downPipe = kj::newOneWayPipe(); + rpcRequest.setHost(host); + rpcRequest.setDown(factory.streamFactory.kjToCapnp(kj::mv(downPipe.out))); + rpcRequest.initSettings().setUseTls(settings.useTls); + + auto context = kj::heap(factory, tunnel); + RevocableServer revocableContext(*context); + + auto builder = capnp::Request< + capnp::HttpService::ConnectParams, + capnp::HttpService::ConnectResults>::Builder(rpcRequest); + rpcRequest.adoptHeaders(factory.headersToCapnp(headers, + Orphanage::getForMessageContaining(builder))); + rpcRequest.setContext(revocableContext.getClient()); + RemotePromise pipeline = rpcRequest.send(); + + // We read from `downPipe` (the other side writes into it.) + auto downPumpTask = downPipe.in->pumpTo(connection) + .then([&connection, down = kj::mv(downPipe.in)](uint64_t) -> kj::Promise { + connection.shutdownWrite(); + return kj::NEVER_DONE; + }); + // We write to `up` (the other side reads from it). + auto up = pipeline.getUp(); + + // We need to create a tlsStarter callback which sends a startTls request to the capnp server. + KJ_IF_MAYBE(tlsStarter, settings.tlsStarter) { + kj::Function(kj::StringPtr)> cb = + [upForStartTls = kj::cp(up)] + (kj::StringPtr expectedServerHostname) + mutable -> kj::Promise { + auto startTlsRpcRequest = upForStartTls.startTlsRequest(); + startTlsRpcRequest.setExpectedServerHostname(expectedServerHostname); + return startTlsRpcRequest.send(); + }; + *tlsStarter = kj::mv(cb); + } + + auto upStream = factory.streamFactory.capnpToKjExplicitEnd(up); + auto upPumpTask = connection.pumpTo(*upStream) + .then([&upStream = *upStream](uint64_t) mutable { + return upStream.end(); + }).then([up = kj::mv(up), upStream = kj::mv(upStream)]() mutable + -> kj::Promise { + return kj::NEVER_DONE; + }); + + return pipeline.ignoreResult() + .attach(kj::mv(downPumpTask), kj::mv(upPumpTask), kj::mv(revocableContext)) + // Separate attach to make sure `revocableContext` is destroyed before `context`. + .attach(kj::mv(context)); + } + + +private: + HttpOverCapnpFactory& factory; + capnp::HttpService::Client inner; +}; + +kj::Own HttpOverCapnpFactory::capnpToKj(capnp::HttpService::Client rpcService) { + return kj::heap(*this, kj::mv(rpcService)); +} + +// ======================================================================================= + +namespace { + +class NullInputStream final: public kj::AsyncInputStream { + // TODO(cleanup): This class has been replicated in a bunch of places now, make it public + // somewhere. + +public: + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return kj::constPromise(); + } + + kj::Maybe tryGetLength() override { + return uint64_t(0); + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return kj::constPromise(); + } +}; + +class NullOutputStream final: public kj::AsyncOutputStream { + // TODO(cleanup): This class has been replicated in a bunch of places now, make it public + // somewhere. + +public: + kj::Promise write(const void* buffer, size_t size) override { + return kj::READY_NOW; + } + kj::Promise write(kj::ArrayPtr> pieces) override { + return kj::READY_NOW; + } + kj::Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. +}; + +class ResolvedServerRequestContext final: public capnp::HttpService::ServerRequestContext::Server { +public: + // Nothing! It's done. +}; + +} // namespace + +class HttpOverCapnpFactory::HttpServiceResponseImpl + : public kj::HttpService::Response { +public: + HttpServiceResponseImpl(HttpOverCapnpFactory& factory, + capnp::HttpRequest::Reader request, + capnp::HttpService::ClientRequestContext::Client clientContext) + : factory(factory), + method(validateMethod(request.getMethod())), + url(request.getUrl()), + headers(factory.headersToKj(request.getHeaders())), + clientContext(kj::mv(clientContext)) {} + + kj::Own send( + uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()"); + + auto req = clientContext.startResponseRequest(); + + if (method == kj::HttpMethod::HEAD || + statusCode == 204 || statusCode == 304) { + expectedBodySize = uint64_t(0); + } + + auto rpcResponse = req.initResponse(); + rpcResponse.setStatusCode(statusCode); + rpcResponse.setStatusText(statusText); + rpcResponse.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining(rpcResponse))); + bool hasBody = true; + KJ_IF_MAYBE(s, expectedBodySize) { + rpcResponse.getBodySize().setFixed(*s); + hasBody = *s > 0; + } + + auto logError = [hasBody](kj::Exception&& e) { + KJ_LOG(INFO, "HTTP-over-RPC startResponse() failed", hasBody, e); + }; + if (hasBody) { + auto pipeline = req.send(); + auto result = factory.streamFactory.capnpToKj(pipeline.getBody()); + replyTask = pipeline.ignoreResult().eagerlyEvaluate(kj::mv(logError)); + return result; + } else { + replyTask = req.send().ignoreResult().eagerlyEvaluate(kj::mv(logError)); + return kj::heap(); + } + + // We don't actually wait for replyTask anywhere, because we may be all done with this HTTP + // message before the client gets a chance to respond, and we don't want to force an extra + // network round trip. If the client fails this call that's the client's problem, really. + } + + kj::Own acceptWebSocket(const kj::HttpHeaders& headers) override { + KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()"); + + auto req = clientContext.startWebSocketRequest(); + + req.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining( + capnp::HttpService::ClientRequestContext::StartWebSocketParams::Builder(req)))); + + auto pipe = kj::newWebSocketPipe(); + auto shorteningPaf = kj::newPromiseAndFulfiller>(); + + // We don't need the RequestState mechanism on the server side because + // CapnpToKjWebSocketAdapter wraps a pipe end, and that pipe end can continue to exist beyond + // the lifetime of the request, because the other end will have been dropped. We only create + // a RequestState here so that we can reuse the implementation of CapnpToKjWebSocketAdapter + // that needs this for the client side. + auto dummyState = kj::refcounted(); + auto& pipeEnd0Ref = *pipe.ends[0]; + dummyState->holdWebSocket(kj::mv(pipe.ends[0])); + req.setUpSocket(kj::heap( + kj::mv(dummyState), pipeEnd0Ref, kj::mv(shorteningPaf.promise))); + + auto pipeline = req.send(); + auto result = kj::heap( + kj::mv(pipe.ends[1]), pipeline.getDownSocket(), kj::mv(shorteningPaf.fulfiller)); + + // Note we need eagerlyEvaluate() here to force proactively discarding the response object, + // since it holds a reference to `downSocket`. + replyTask = pipeline.ignoreResult() + .eagerlyEvaluate([](kj::Exception&& e) { + KJ_LOG(INFO, "HTTP-over-RPC startWebSocketRequest() failed", e); + }); + + return result; + } + + HttpOverCapnpFactory& factory; + kj::HttpMethod method; + kj::StringPtr url; + kj::HttpHeaders headers; + capnp::HttpService::ClientRequestContext::Client clientContext; + kj::Maybe> replyTask; + + static kj::HttpMethod validateMethod(capnp::HttpMethod method) { + KJ_REQUIRE(method <= capnp::HttpMethod::UNSUBSCRIBE, "unknown method", method); + return static_cast(method); + } +}; + +class HttpOverCapnpFactory::HttpOverCapnpConnectResponseImpl final : + public kj::HttpService::ConnectResponse { +public: + HttpOverCapnpConnectResponseImpl( + HttpOverCapnpFactory& factory, + capnp::HttpService::ConnectClientRequestContext::Client context) : + context(context), factory(factory) {} + + void accept(uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers) override { + KJ_REQUIRE(replyTask == nullptr, "already called accept() or reject()"); + + auto req = context.startConnectRequest(); + auto rpcResponse = req.initResponse(); + rpcResponse.setStatusCode(statusCode); + rpcResponse.setStatusText(statusText); + rpcResponse.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining(rpcResponse))); + + replyTask = req.send().ignoreResult(); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const kj::HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_REQUIRE(replyTask == nullptr, "already called accept() or reject()"); + auto pipe = kj::newOneWayPipe(expectedBodySize); + + auto req = context.startErrorRequest(); + auto rpcResponse = req.initResponse(); + rpcResponse.setStatusCode(statusCode); + rpcResponse.setStatusText(statusText); + rpcResponse.adoptHeaders(factory.headersToCapnp( + headers, Orphanage::getForMessageContaining(rpcResponse))); + + auto errorBody = kj::mv(pipe.in); + // Set the body size if the error body exists. + KJ_IF_MAYBE(size, errorBody->tryGetLength()) { + rpcResponse.getBodySize().setFixed(*size); + } + + replyTask = req.send().then( + [this, errorBody = kj::mv(errorBody)](auto resp) mutable -> kj::Promise { + auto body = factory.streamFactory.capnpToKjExplicitEnd(resp.getBody()); + return errorBody->pumpTo(*body) + .then([&body = *body](uint64_t) mutable { + return body.end(); + }).attach(kj::mv(errorBody), kj::mv(body)); + }); + + return kj::mv(pipe.out); + } + + capnp::HttpService::ConnectClientRequestContext::Client context; + capnp::HttpOverCapnpFactory& factory; + kj::Maybe> replyTask; +}; + + +class HttpOverCapnpFactory::ServerRequestContextImpl final + : public capnp::HttpService::ServerRequestContext::Server, + public HttpServiceResponseImpl { +public: + ServerRequestContextImpl(HttpOverCapnpFactory& factory, + HttpService::Client serviceCap, + kj::Own request, + capnp::HttpService::ClientRequestContext::Client clientContext, + kj::Own requestBodyIn, + kj::HttpService& kjService) + : HttpServiceResponseImpl(factory, *request, kj::mv(clientContext)), + request(kj::mv(request)), + serviceCap(kj::mv(serviceCap)), + // Note we attach `requestBodyIn` to `task` so that we will implicitly cancel reading + // the request body as soon as the service returns. This is important in particular when + // the request body is not fully consumed, in order to propagate cancellation. + task(kjService.request(method, url, headers, *requestBodyIn, *this) + .attach(kj::mv(requestBodyIn))) {} + + kj::Maybe> shortenPath() override { + return task.then([]() -> Capability::Client { + // If all went well, resolve to a settled capability. + // TODO(perf): Could save a message by resolving to a capability hosted by the client, or + // some special "null" capability that isn't an error but is still transmitted by value. + // Otherwise we need a Release message from client -> server just to drop this... + return kj::heap(); + }); + } + + KJ_DISALLOW_COPY_AND_MOVE(ServerRequestContextImpl); + +private: + kj::Own request; + HttpService::Client serviceCap; // ensures the inner kj::HttpService isn't destroyed + kj::Promise task; +}; + +class HttpOverCapnpFactory::CapnpToKjHttpServiceAdapter final: public capnp::HttpService::Server { +public: + CapnpToKjHttpServiceAdapter(HttpOverCapnpFactory& factory, kj::Own inner) + : factory(factory), inner(kj::mv(inner)) {} + + template + kj::Promise requestImpl(CallContext context, Callback&& callback) { + // Common implementation of request() and startRequest(). callback() performs the + // method-specific stuff at the end. + // + // TODO(cleanup): When we move to C++17 or newer we can use `if constexpr` instead of a + // callback. + + auto params = context.getParams(); + auto metadata = params.getRequest(); + + auto bodySize = metadata.getBodySize(); + kj::Maybe expectedSize; + bool hasBody = true; + if (bodySize.isFixed()) { + auto size = bodySize.getFixed(); + expectedSize = bodySize.getFixed(); + hasBody = size > 0; + } + + auto results = context.getResults(MessageSize {8, 2}); + kj::Own requestBody; + if (hasBody) { + auto pipe = kj::newOneWayPipe(expectedSize); + auto requestBodyCap = factory.streamFactory.kjToCapnp(kj::mv(pipe.out)); + + if (kj::isSameType()) { + // For request(), use context.setPipeline() to enable pipelined calls to the request body + // stream before this RPC completes. (We don't bother when using startRequest() because + // it returns immediately anyway, so this would just waste effort.) + PipelineBuilder pipeline; + pipeline.setRequestBody(kj::cp(requestBodyCap)); + context.setPipeline(pipeline.build()); + } + + results.setRequestBody(kj::mv(requestBodyCap)); + requestBody = kj::mv(pipe.in); + } else { + requestBody = kj::heap(); + } + + return callback(results, metadata, params, requestBody); + } + + kj::Promise request(RequestContext context) override { + return requestImpl(kj::mv(context), + [&](auto& results, auto& metadata, auto& params, auto& requestBody) { + class FinalHttpServiceResponseImpl final: public HttpServiceResponseImpl { + public: + using HttpServiceResponseImpl::HttpServiceResponseImpl; + }; + auto impl = kj::heap(factory, metadata, params.getContext()); + auto promise = inner->request(impl->method, impl->url, impl->headers, *requestBody, *impl); + return promise.attach(kj::mv(requestBody), kj::mv(impl)); + }); + } + + kj::Promise startRequest(StartRequestContext context) override { + return requestImpl(kj::mv(context), + [&](auto& results, auto& metadata, auto& params, auto& requestBody) { + results.setContext(kj::heap( + factory, thisCap(), capnp::clone(metadata), params.getContext(), kj::mv(requestBody), + *inner)); + + return kj::READY_NOW; + }); + } + + kj::Promise connect(ConnectContext context) override { + auto params = context.getParams(); + auto host = params.getHost(); + kj::Own tlsStarter = kj::heap(); + kj::HttpConnectSettings settings = { .useTls = params.getSettings().getUseTls()}; + settings.tlsStarter = tlsStarter; + auto headers = factory.headersToKj(params.getHeaders()); + auto pipe = kj::newTwoWayPipe(); + + class EofDetector final: public kj::AsyncOutputStream { + public: + EofDetector(kj::Own inner) + : inner(kj::mv(inner)) {} + ~EofDetector() { + inner->shutdownWrite(); + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return inner->tryPumpFrom(input, amount); + } + + kj::Promise write(const void* buffer, size_t size) override { + return inner->write(buffer, size); + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + return inner->write(pieces); + } + + kj::Promise whenWriteDisconnected() override { + return inner->whenWriteDisconnected(); + } + private: + kj::Own inner; + }; + + auto stream = factory.streamFactory.capnpToKjExplicitEnd(context.getParams().getDown()); + + // We want to keep the stream alive even after EofDetector is destroyed, so we need to create + // a refcounted AsyncIoStream. + auto refcounted = kj::refcountedWrapper(kj::mv(pipe.ends[1])); + kj::Own ref1 = refcounted->addWrappedRef(); + kj::Own ref2 = refcounted->addWrappedRef(); + + // We write to the `down` pipe. + auto pumpTask = ref1->pumpTo(*stream) + .then([&stream = *stream](uint64_t) mutable { + return stream.end(); + }).then([httpProxyStream = kj::mv(ref1), stream = kj::mv(stream)]() mutable + -> kj::Promise { + return kj::NEVER_DONE; + }); + + PipelineBuilder pb; + auto eofWrapper = kj::heap(kj::mv(ref2)); + auto up = factory.streamFactory.kjToCapnp(kj::mv(eofWrapper), kj::mv(tlsStarter)); + pb.setUp(kj::cp(up)); + + context.setPipeline(pb.build()); + context.initResults(capnp::MessageSize { 4, 1 }).setUp(kj::mv(up)); + + auto response = kj::heap( + factory, context.getParams().getContext()); + + return inner->connect(host, headers, *pipe.ends[0], *response, settings).attach( + kj::mv(host), kj::mv(headers), kj::mv(response), kj::mv(pipe)) + .exclusiveJoin(kj::mv(pumpTask)); + } + +private: + HttpOverCapnpFactory& factory; + kj::Own inner; +}; + +capnp::HttpService::Client HttpOverCapnpFactory::kjToCapnp(kj::Own service) { + return kj::heap(*this, kj::mv(service)); +} + +// ======================================================================================= + +static constexpr uint64_t COMMON_TEXT_ANNOTATION = 0x857745131db6fc83ull; +// Type ID of `commonText` from `http.capnp`. +// TODO(cleanup): Cap'n Proto should auto-generate constants for these. + +HttpOverCapnpFactory::HeaderIdBundle::HeaderIdBundle(kj::HttpHeaderTable::Builder& builder) + : table(builder.getFutureTable()) { + auto commonHeaderNames = Schema::from().getEnumerants(); + nameCapnpToKj = kj::heapArray(commonHeaderNames.size()); + for (size_t i = 1; i < commonHeaderNames.size(); i++) { + kj::StringPtr nameText; + for (auto ann: commonHeaderNames[i].getProto().getAnnotations()) { + if (ann.getId() == COMMON_TEXT_ANNOTATION) { + nameText = ann.getValue().getText(); + break; + } + } + KJ_ASSERT(nameText != nullptr); + kj::HttpHeaderId headerId = builder.add(nameText); + nameCapnpToKj[i] = headerId; + maxHeaderId = kj::max(maxHeaderId, headerId.hashCode()); + } +} + +HttpOverCapnpFactory::HeaderIdBundle::HeaderIdBundle( + const kj::HttpHeaderTable& table, kj::Array nameCapnpToKj, size_t maxHeaderId) + : table(table), nameCapnpToKj(kj::mv(nameCapnpToKj)), maxHeaderId(maxHeaderId) {} + +HttpOverCapnpFactory::HeaderIdBundle HttpOverCapnpFactory::HeaderIdBundle::clone() const { + return HeaderIdBundle(table, kj::heapArray(nameCapnpToKj), maxHeaderId); +} + +HttpOverCapnpFactory::HttpOverCapnpFactory(ByteStreamFactory& streamFactory, + HeaderIdBundle headerIds, + OptimizationLevel peerOptimizationLevel) + : streamFactory(streamFactory), headerTable(headerIds.table), + peerOptimizationLevel(peerOptimizationLevel), + nameCapnpToKj(kj::mv(headerIds.nameCapnpToKj)) { + auto commonHeaderNames = Schema::from().getEnumerants(); + nameKjToCapnp = kj::heapArray(headerIds.maxHeaderId + 1); + for (auto& slot: nameKjToCapnp) slot = capnp::CommonHeaderName::INVALID; + + for (size_t i = 1; i < commonHeaderNames.size(); i++) { + auto& slot = nameKjToCapnp[nameCapnpToKj[i].hashCode()]; + KJ_ASSERT(slot == capnp::CommonHeaderName::INVALID); + slot = static_cast(i); + } + + auto commonHeaderValues = Schema::from().getEnumerants(); + valueCapnpToKj = kj::heapArray(commonHeaderValues.size()); + for (size_t i = 1; i < commonHeaderValues.size(); i++) { + kj::StringPtr valueText; + for (auto ann: commonHeaderValues[i].getProto().getAnnotations()) { + if (ann.getId() == COMMON_TEXT_ANNOTATION) { + valueText = ann.getValue().getText(); + break; + } + } + KJ_ASSERT(valueText != nullptr); + valueCapnpToKj[i] = valueText; + valueKjToCapnp.insert(valueText, static_cast(i)); + } +} + +Orphan> HttpOverCapnpFactory::headersToCapnp( + const kj::HttpHeaders& headers, Orphanage orphanage) { + auto result = orphanage.newOrphan>(headers.size()); + auto rpcHeaders = result.get(); + uint i = 0; + headers.forEach([&](kj::HttpHeaderId id, kj::StringPtr value) { + auto capnpName = id.hashCode() < nameKjToCapnp.size() + ? nameKjToCapnp[id.hashCode()] + : capnp::CommonHeaderName::INVALID; + if (capnpName == capnp::CommonHeaderName::INVALID) { + auto header = rpcHeaders[i++].initUncommon(); + header.setName(id.toString()); + header.setValue(value); + } else { + auto header = rpcHeaders[i++].initCommon(); + header.setName(capnpName); + header.setValue(value); + } + }, [&](kj::StringPtr name, kj::StringPtr value) { + auto header = rpcHeaders[i++].initUncommon(); + header.setName(name); + header.setValue(value); + }); + KJ_ASSERT(i == rpcHeaders.size()); + return result; +} + +kj::HttpHeaders HttpOverCapnpFactory::headersToKj( + List::Reader capnpHeaders) const { + kj::HttpHeaders result(headerTable); + + for (auto header: capnpHeaders) { + switch (header.which()) { + case capnp::HttpHeader::COMMON: { + auto nv = header.getCommon(); + auto nameInt = static_cast(nv.getName()); + KJ_REQUIRE(nameInt < nameCapnpToKj.size(), "unknown common header name", nv.getName()); + + switch (nv.which()) { + case capnp::HttpHeader::Common::COMMON_VALUE: { + auto cvInt = static_cast(nv.getCommonValue()); + KJ_REQUIRE(nameInt < valueCapnpToKj.size(), + "unknown common header value", nv.getCommonValue()); + result.set(nameCapnpToKj[nameInt], valueCapnpToKj[cvInt]); + break; + } + case capnp::HttpHeader::Common::VALUE: { + auto headerId = nameCapnpToKj[nameInt]; + if (result.get(headerId) == nullptr) { + result.set(headerId, nv.getValue()); + } else { + // Unusual: This is a duplicate header, so fall back to add(), which may trigger + // comma-concatenation, except in certain cases where comma-concatentaion would + // be problematic. + result.add(headerId.toString(), nv.getValue()); + } + break; + } + } + break; + } + case capnp::HttpHeader::UNCOMMON: { + auto nv = header.getUncommon(); + result.add(nv.getName(), nv.getValue()); + } + } + } + + return result; +} + +} // namespace capnp diff --git a/c++/src/capnp/compat/http-over-capnp.capnp b/c++/src/capnp/compat/http-over-capnp.capnp new file mode 100644 index 0000000000..8b4afba043 --- /dev/null +++ b/c++/src/capnp/compat/http-over-capnp.capnp @@ -0,0 +1,268 @@ +# Copyright (c) 2019 Cloudflare, Inc. and contributors +# Licensed under the MIT License: +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +@0xb665280aaff2e632; +# Cap'n Proto interface for HTTP. + +using import "byte-stream.capnp".ByteStream; + +using Cxx = import "/capnp/c++.capnp"; +$Cxx.namespace("capnp"); +$Cxx.allowCancellation; + +interface HttpService { + request @1 (request :HttpRequest, context :ClientRequestContext) + -> (requestBody :ByteStream); + # Perform an HTTP request. + # + # The client sends the request method/url/headers. The server responds with a `ByteStream` where + # the client can make calls to stream up the request body. `requestBody` will be null in the case + # that request.bodySize.fixed == 0. + # + # The server will send a response by invoking a method on `callback`. + # + # `request()` does not return until the server is completely done processing the request, + # including sending the response. The client therefore must use promise pipelining to send the + # request body. The client may request cancellation of the HTTP request by canceling the + # `request()` call itself. + + startRequest @0 (request :HttpRequest, context :ClientRequestContext) + -> (requestBody :ByteStream, context :ServerRequestContext); + # DEPRECATED: Older form of `request()`. In this version, the server immediately returns a + # `ServerRequestContext` before it begins processing the request. This version was designed + # before `CallContext::setPipeline()` was introduced. At that time, it was impossible for the + # server to receive data sent to the `requestBody` stream until `startRequest()` had returned + # a stream capability to use, hence the ongoing call on the server side had to be represented + # using a separate capability. Now that we have `CallContext::setPipeline()`, the server can + # begin receiving the request body without returning from the top-level RPC, so we can now use + # `request()` instead of `startRequest()`. The new approach is more intuitive and avoids some + # unnecessary bookkeeping. + # + # `HttpOverCapnpFactory` will continue to support both methods. Use the `peerOptimizationLevel` + # constructor parameter to specify which method to use, for backwards-compatibiltiy purposes. + + connect @2 (host :Text, headers :List(HttpHeader), down :ByteStream, + context :ConnectClientRequestContext, settings :ConnectSettings) + -> (up :ByteStream); + # Setup an HTTP CONNECT proxy tunnel. + # + # The client sends the request host/headers together with a `down` ByteStream that will be used + # for communication across the tunnel. The server will respond with the other side of that + # ByteStream for two-way communication. The `context` includes callbacks which are used to + # supply the client with headers. + + interface ClientRequestContext { + # Provides callbacks for the server to send the response. + + startResponse @0 (response :HttpResponse) -> (body :ByteStream); + # Server calls this method to send the response status and headers and to begin streaming the + # response body. `body` will be null in the case that response.bodySize.fixed == 0, which is + # required for HEAD responses and status codes 204, 205, and 304. + + startWebSocket @1 (headers :List(HttpHeader), upSocket :WebSocket) + -> (downSocket :WebSocket); + # Server calls this method to indicate that the request is a valid WebSocket handshake and it + # wishes to accept it as a WebSocket. + # + # Client -> Server WebSocket frames will be sent via method calls on `upSocket`, while + # Server -> Client will be sent as calls to `downSocket`. + } + + interface ConnectClientRequestContext { + # Provides callbacks for the server to send the response. + + startConnect @0 (response :HttpResponse); + # Server calls this method to let the client know that the CONNECT request has been + # accepted. It also includes status code and header information. + + startError @1 (response :HttpResponse) -> (body :ByteStream); + # Server calls this method to let the client know that the CONNECT request has been rejected. + } + + interface ServerRequestContext { + # DEPRECATED: Used only with startRequest(); see comments there. + # + # Represents execution of a particular request on the server side. + # + # Dropping this object before the request completes will cancel the request. + # + # ServerRequestContext is always a promise capability. The client must wait for it to + # resolve using whenMoreResolved() in order to find out when the server is really done + # processing the request. This will throw an exception if the server failed in some way that + # could not be captured in the HTTP response. Note that it's possible for such an exception to + # be thrown even after the response body has been completely transmitted. + } +} + +struct ConnectSettings { + useTls @0 :Bool; +} + +interface WebSocket { + sendText @0 (text :Text) -> stream; + sendData @1 (data :Data) -> stream; + # Send a text or data frame. + + close @2 (code :UInt16, reason :Text); + # Send a close frame. +} + +struct HttpRequest { + # Standard HTTP request metadata. + + method @0 :HttpMethod; + url @1 :Text; + headers @2 :List(HttpHeader); + bodySize :union { + unknown @3 :Void; # e.g. due to transfer-encoding: chunked + fixed @4 :UInt64; # e.g. due to content-length + } +} + +struct HttpResponse { + # Standard HTTP response metadata. + + statusCode @0 :UInt16; + statusText @1 :Text; # leave null if it matches the default for statusCode + headers @2 :List(HttpHeader); + bodySize :union { + unknown @3 :Void; # e.g. due to transfer-encoding: chunked + fixed @4 :UInt64; # e.g. due to content-length + } +} + +enum HttpMethod { + # This enum aligns precisely with the kj::HttpMethod enum. However, the backwards-compat + # constraints of a public-facing C++ enum vs. an internal Cap'n Proto interface differ in + # several ways, which could possibly lead to divergence someday. For now, a unit test verifies + # that they match exactly; if that test ever fails, we'll have to figure out what to do about it. + + get @0; + head @1; + post @2; + put @3; + delete @4; + patch @5; + purge @6; + options @7; + trace @8; + + copy @9; + lock @10; + mkcol @11; + move @12; + propfind @13; + proppatch @14; + search @15; + unlock @16; + acl @17; + + report @18; + mkactivity @19; + checkout @20; + merge @21; + + msearch @22; + notify @23; + subscribe @24; + unsubscribe @25; +} + +annotation commonText @0x857745131db6fc83(enumerant) :Text; + +enum CommonHeaderName { + invalid @0; + # Dummy to serve as default value. Should never actually appear on wire. + + acceptCharset @1 $commonText("Accept-Charset"); + acceptEncoding @2 $commonText("Accept-Encoding"); + acceptLanguage @3 $commonText("Accept-Language"); + acceptRanges @4 $commonText("Accept-Ranges"); + accept @5 $commonText("Accept"); + accessControlAllowOrigin @6 $commonText("Access-Control-Allow-Origin"); + age @7 $commonText("Age"); + allow @8 $commonText("Allow"); + authorization @9 $commonText("Authorization"); + cacheControl @10 $commonText("Cache-Control"); + contentDisposition @11 $commonText("Content-Disposition"); + contentEncoding @12 $commonText("Content-Encoding"); + contentLanguage @13 $commonText("Content-Language"); + contentLength @14 $commonText("Content-Length"); + contentLocation @15 $commonText("Content-Location"); + contentRange @16 $commonText("Content-Range"); + contentType @17 $commonText("Content-Type"); + cookie @18 $commonText("Cookie"); + date @19 $commonText("Date"); + etag @20 $commonText("ETag"); + expect @21 $commonText("Expect"); + expires @22 $commonText("Expires"); + from @23 $commonText("From"); + host @24 $commonText("Host"); + ifMatch @25 $commonText("If-Match"); + ifModifiedSince @26 $commonText("If-Modified-Since"); + ifNoneMatch @27 $commonText("If-None-Match"); + ifRange @28 $commonText("If-Range"); + ifUnmodifiedSince @29 $commonText("If-Unmodified-Since"); + lastModified @30 $commonText("Last-Modified"); + link @31 $commonText("Link"); + location @32 $commonText("Location"); + maxForwards @33 $commonText("Max-Forwards"); + proxyAuthenticate @34 $commonText("Proxy-Authenticate"); + proxyAuthorization @35 $commonText("Proxy-Authorization"); + range @36 $commonText("Range"); + referer @37 $commonText("Referer"); + refresh @38 $commonText("Refresh"); + retryAfter @39 $commonText("Retry-After"); + server @40 $commonText("Server"); + setCookie @41 $commonText("Set-Cookie"); + strictTransportSecurity @42 $commonText("Strict-Transport-Security"); + transferEncoding @43 $commonText("Transfer-Encoding"); + userAgent @44 $commonText("User-Agent"); + vary @45 $commonText("Vary"); + via @46 $commonText("Via"); + wwwAuthenticate @47 $commonText("WWW-Authenticate"); +} + +enum CommonHeaderValue { + invalid @0; + + gzipDeflate @1 $commonText("gzip, deflate"); + + # TODO(someday): "gzip, deflate" is the only common header value recognized by HPACK. +} + +struct HttpHeader { + union { + common :group { + name @0 :CommonHeaderName; + union { + commonValue @1 :CommonHeaderValue; + value @2 :Text; + } + } + uncommon @3 :NameValue; + } + + struct NameValue { + name @0 :Text; + value @1 :Text; + } +} diff --git a/c++/src/capnp/compat/http-over-capnp.h b/c++/src/capnp/compat/http-over-capnp.h new file mode 100644 index 0000000000..6b16749118 --- /dev/null +++ b/c++/src/capnp/compat/http-over-capnp.h @@ -0,0 +1,106 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once +// Bridges from KJ HTTP to Cap'n Proto HTTP-over-RPC. + +#include +#include +#include +#include "byte-stream.h" + +CAPNP_BEGIN_HEADER + +namespace capnp { + +class HttpOverCapnpFactory { +public: + class HeaderIdBundle { + public: + HeaderIdBundle(kj::HttpHeaderTable::Builder& builder); + + HeaderIdBundle clone() const; + + private: + HeaderIdBundle(const kj::HttpHeaderTable& table, kj::Array nameCapnpToKj, + size_t maxHeaderId); + // Constructor for clone(). + + const kj::HttpHeaderTable& table; + + kj::Array nameCapnpToKj; + size_t maxHeaderId = 0; + + friend class HttpOverCapnpFactory; + }; + + enum OptimizationLevel { + // Specifies the protocol optimization level supported by the remote peer. Setting this higher + // will improve efficiency but breaks compatibility with older peers that don't implement newer + // levels. + + LEVEL_1, + // Use startRequest(), the original version of the protocol. + + LEVEL_2 + // Use request(). This is more efficient than startRequest() but won't work with old peers that + // only implement startRequest(). + }; + + HttpOverCapnpFactory(ByteStreamFactory& streamFactory, HeaderIdBundle headerIds, + OptimizationLevel peerOptimizationLevel = LEVEL_1); + + kj::Own capnpToKj(capnp::HttpService::Client rpcService); + capnp::HttpService::Client kjToCapnp(kj::Own service); + +private: + ByteStreamFactory& streamFactory; + const kj::HttpHeaderTable& headerTable; + OptimizationLevel peerOptimizationLevel; + kj::Array nameKjToCapnp; + kj::Array nameCapnpToKj; + kj::Array valueCapnpToKj; + kj::HashMap valueKjToCapnp; + + class RequestState; + + class CapnpToKjWebSocketAdapter; + class KjToCapnpWebSocketAdapter; + + class ClientRequestContextImpl; + class ConnectClientRequestContextImpl; + class KjToCapnpHttpServiceAdapter; + + class HttpServiceResponseImpl; + class HttpOverCapnpConnectResponseImpl; + class ServerRequestContextImpl; + class CapnpToKjHttpServiceAdapter; + + kj::HttpHeaders headersToKj(capnp::List::Reader capnpHeaders) const; + // Returned headers may alias into `capnpHeaders`. + + capnp::Orphan> headersToCapnp( + const kj::HttpHeaders& headers, capnp::Orphanage orphanage); +}; + +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compat/json-rpc-test.c++ b/c++/src/capnp/compat/json-rpc-test.c++ new file mode 100644 index 0000000000..bf78bb577b --- /dev/null +++ b/c++/src/capnp/compat/json-rpc-test.c++ @@ -0,0 +1,100 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "json-rpc.h" +#include +#include + +namespace capnp { +namespace _ { // private +namespace { + +KJ_TEST("json-rpc basics") { + auto io = kj::setupAsyncIo(); + auto pipe = kj::newTwoWayPipe(); + + JsonRpc::ContentLengthTransport clientTransport(*pipe.ends[0]); + JsonRpc::ContentLengthTransport serverTransport(*pipe.ends[1]); + + int callCount = 0; + + JsonRpc client(clientTransport); + JsonRpc server(serverTransport, toDynamic(kj::heap(callCount))); + + auto cap = client.getPeer(); + auto req = cap.fooRequest(); + req.setI(123); + req.setJ(true); + auto resp = req.send().wait(io.waitScope); + KJ_EXPECT(resp.getX() == "foo"); + + KJ_EXPECT(callCount == 1); +} + +KJ_TEST("json-rpc error") { + auto io = kj::setupAsyncIo(); + auto pipe = kj::newTwoWayPipe(); + + JsonRpc::ContentLengthTransport clientTransport(*pipe.ends[0]); + JsonRpc::ContentLengthTransport serverTransport(*pipe.ends[1]); + + int callCount = 0; + + JsonRpc client(clientTransport); + JsonRpc server(serverTransport, toDynamic(kj::heap(callCount))); + + auto cap = client.getPeer(); + KJ_EXPECT_THROW_MESSAGE("Method not implemented", cap.barRequest().send().wait(io.waitScope)); +} + +KJ_TEST("json-rpc multiple calls") { + auto io = kj::setupAsyncIo(); + auto pipe = kj::newTwoWayPipe(); + + JsonRpc::ContentLengthTransport clientTransport(*pipe.ends[0]); + JsonRpc::ContentLengthTransport serverTransport(*pipe.ends[1]); + + int callCount = 0; + + JsonRpc client(clientTransport); + JsonRpc server(serverTransport, toDynamic(kj::heap(callCount))); + + auto cap = client.getPeer(); + auto req1 = cap.fooRequest(); + req1.setI(123); + req1.setJ(true); + auto promise1 = req1.send(); + + auto req2 = cap.bazRequest(); + initTestMessage(req2.initS()); + auto promise2 = req2.send(); + + auto resp1 = promise1.wait(io.waitScope); + KJ_EXPECT(resp1.getX() == "foo"); + + auto resp2 = promise2.wait(io.waitScope); + + KJ_EXPECT(callCount == 2); +} + +} // namespace +} // namespace _ (private) +} // namespace capnp diff --git a/c++/src/capnp/compat/json-rpc.c++ b/c++/src/capnp/compat/json-rpc.c++ new file mode 100644 index 0000000000..bd1fbf125b --- /dev/null +++ b/c++/src/capnp/compat/json-rpc.c++ @@ -0,0 +1,340 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "json-rpc.h" +#include +#include + +namespace capnp { + +static constexpr uint64_t JSON_NAME_ANNOTATION_ID = 0xfa5b1fd61c2e7c3dull; +static constexpr uint64_t JSON_NOTIFICATION_ANNOTATION_ID = 0xa0a054dea32fd98cull; + +class JsonRpc::CapabilityImpl final: public DynamicCapability::Server { +public: + CapabilityImpl(JsonRpc& parent, InterfaceSchema schema) + : DynamicCapability::Server(schema), parent(parent) {} + + kj::Promise call(InterfaceSchema::Method method, + CallContext context) override { + auto proto = method.getProto(); + bool isNotification = false; + kj::StringPtr name = proto.getName(); + for (auto annotation: proto.getAnnotations()) { + switch (annotation.getId()) { + case JSON_NAME_ANNOTATION_ID: + name = annotation.getValue().getText(); + break; + case JSON_NOTIFICATION_ANNOTATION_ID: + isNotification = true; + break; + } + } + + capnp::MallocMessageBuilder message; + auto value = message.getRoot(); + auto list = value.initObject(3 + !isNotification); + + uint index = 0; + + auto jsonrpc = list[index++]; + jsonrpc.setName("jsonrpc"); + jsonrpc.initValue().setString("2.0"); + + uint callId = parent.callCount++; + + if (!isNotification) { + auto id = list[index++]; + id.setName("id"); + id.initValue().setNumber(callId); + } + + auto methodName = list[index++]; + methodName.setName("method"); + methodName.initValue().setString(name); + + auto params = list[index++]; + params.setName("params"); + parent.codec.encode(context.getParams(), params.initValue()); + + auto writePromise = parent.queueWrite(parent.codec.encode(value)); + + if (isNotification) { + auto sproto = context.getResultsType().getProto().getStruct(); + MessageSize size { sproto.getDataWordCount(), sproto.getPointerCount() }; + context.initResults(size); + return kj::mv(writePromise); + } else { + auto paf = kj::newPromiseAndFulfiller(); + parent.awaitedResponses.insert(callId, AwaitedResponse { context, kj::mv(paf.fulfiller) }); + auto promise = writePromise.then([p = kj::mv(paf.promise)]() mutable { return kj::mv(p); }); + auto& parentRef = parent; + return promise.attach(kj::defer([&parentRef,callId]() { + parentRef.awaitedResponses.erase(callId); + })); + } + } + +private: + JsonRpc& parent; +}; + +JsonRpc::JsonRpc(Transport& transport, DynamicCapability::Client interface) + : JsonRpc(transport, kj::mv(interface), kj::newPromiseAndFulfiller()) {} +JsonRpc::JsonRpc(Transport& transport, DynamicCapability::Client interfaceParam, + kj::PromiseFulfillerPair paf) + : transport(transport), + interface(kj::mv(interfaceParam)), + errorPromise(paf.promise.fork()), + errorFulfiller(kj::mv(paf.fulfiller)), + readTask(readLoop().eagerlyEvaluate([this](kj::Exception&& e) { + errorFulfiller->reject(kj::mv(e)); + })), + tasks(*this) { + codec.handleByAnnotation(interface.getSchema()); + codec.handleByAnnotation(); + + for (auto method: interface.getSchema().getMethods()) { + auto proto = method.getProto(); + kj::StringPtr name = proto.getName(); + for (auto annotation: proto.getAnnotations()) { + switch (annotation.getId()) { + case JSON_NAME_ANNOTATION_ID: + name = annotation.getValue().getText(); + break; + } + } + methodMap.insert(name, method); + } +} + +DynamicCapability::Client JsonRpc::getPeer(InterfaceSchema schema) { + codec.handleByAnnotation(interface.getSchema()); + return kj::heap(*this, schema); +} + +static kj::HttpHeaderTable& staticHeaderTable() { + static kj::HttpHeaderTable HEADER_TABLE; + return HEADER_TABLE; +} + +kj::Promise JsonRpc::queueWrite(kj::String text) { + auto fork = writeQueue.then([this, text = kj::mv(text)]() mutable { + auto promise = transport.send(text); + return promise.attach(kj::mv(text)); + }).eagerlyEvaluate([this](kj::Exception&& e) { + errorFulfiller->reject(kj::mv(e)); + }).fork(); + writeQueue = fork.addBranch(); + return fork.addBranch(); +} + +void JsonRpc::queueError(kj::Maybe id, int code, kj::StringPtr message) { + MallocMessageBuilder capnpMessage; + auto jsonResponse = capnpMessage.getRoot(); + jsonResponse.setJsonrpc("2.0"); + KJ_IF_MAYBE(i, id) { + jsonResponse.setId(*i); + } else { + jsonResponse.initId().setNull(); + } + auto error = jsonResponse.initError(); + error.setCode(code); + error.setMessage(message); + + // OK to discard result of queueWrite() since it's just one branch of a fork. + queueWrite(codec.encode(jsonResponse)); +} + +kj::Promise JsonRpc::readLoop() { + return transport.receive().then([this](kj::String message) -> kj::Promise { + MallocMessageBuilder capnpMessage; + auto rpcMessageBuilder = capnpMessage.getRoot(); + + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + codec.decode(message, rpcMessageBuilder); + })) { + queueError(nullptr, -32700, kj::str("Parse error: ", exception->getDescription())); + return readLoop(); + } + + KJ_CONTEXT("decoding JSON-RPC message", message); + + auto rpcMessage = rpcMessageBuilder.asReader(); + + if (!rpcMessage.hasJsonrpc()) { + queueError(nullptr, -32700, kj::str("Missing 'jsonrpc' field.")); + return readLoop(); + } else if (rpcMessage.getJsonrpc() != "2.0") { + queueError(nullptr, -32700, + kj::str("Unknown JSON-RPC version. This peer implements version '2.0'.")); + return readLoop(); + } + + switch (rpcMessage.which()) { + case json::RpcMessage::NONE: + queueError(nullptr, -32700, kj::str("message has none of params, result, or error")); + break; + + case json::RpcMessage::PARAMS: { + // a call + KJ_IF_MAYBE(method, methodMap.find(rpcMessage.getMethod())) { + auto req = interface.newRequest(*method); + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + codec.decode(rpcMessage.getParams(), req); + })) { + kj::Maybe id; + if (rpcMessage.hasId()) id = rpcMessage.getId(); + queueError(id, -32602, + kj::str("Type error in method params: ", exception->getDescription())); + break; + } + + if (rpcMessage.hasId()) { + auto id = rpcMessage.getId(); + auto idCopy = kj::heapArray(id.totalSize().wordCount + 1); + memset(idCopy.begin(), 0, idCopy.asBytes().size()); + copyToUnchecked(id, idCopy); + auto idPtr = readMessageUnchecked(idCopy.begin()); + + auto promise = req.send() + .then([this,idPtr](Response response) mutable { + MallocMessageBuilder capnpMessage; + auto jsonResponse = capnpMessage.getRoot(); + jsonResponse.setJsonrpc("2.0"); + jsonResponse.setId(idPtr); + codec.encode(DynamicStruct::Reader(response), jsonResponse.initResult()); + return queueWrite(codec.encode(jsonResponse)); + }, [this,idPtr](kj::Exception&& e) { + MallocMessageBuilder capnpMessage; + auto jsonResponse = capnpMessage.getRoot(); + jsonResponse.setJsonrpc("2.0"); + jsonResponse.setId(idPtr); + auto error = jsonResponse.initError(); + switch (e.getType()) { + case kj::Exception::Type::FAILED: + error.setCode(-32000); + break; + case kj::Exception::Type::DISCONNECTED: + error.setCode(-32001); + break; + case kj::Exception::Type::OVERLOADED: + error.setCode(-32002); + break; + case kj::Exception::Type::UNIMPLEMENTED: + error.setCode(-32601); // method not found + break; + } + error.setMessage(e.getDescription()); + return queueWrite(codec.encode(jsonResponse)); + }); + tasks.add(promise.attach(kj::mv(idCopy))); + } else { + // No 'id', so this is a notification. + tasks.add(req.send().ignoreResult().catch_([](kj::Exception&& exception) { + if (exception.getType() != kj::Exception::Type::UNIMPLEMENTED) { + KJ_LOG(ERROR, "JSON-RPC notification threw exception into the abyss", exception); + } + })); + } + } else { + if (rpcMessage.hasId()) { + queueError(rpcMessage.getId(), -32601, "Method not found"); + } else { + // Ignore notification for unknown method. + } + } + break; + } + + case json::RpcMessage::RESULT: { + auto id = rpcMessage.getId(); + if (!id.isNumber()) { + // JSON-RPC doesn't define what to do if receiving a response with an invalid id. + KJ_LOG(ERROR, "JSON-RPC response has invalid ID"); + } else KJ_IF_MAYBE(awaited, awaitedResponses.find((uint)id.getNumber())) { + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + codec.decode(rpcMessage.getResult(), awaited->context.getResults()); + awaited->fulfiller->fulfill(); + })) { + // Errors always propagate from callee to caller, so we don't want to throw this error + // back to the server. + awaited->fulfiller->reject(kj::mv(*exception)); + } + } else { + // Probably, this is the response to a call that was canceled. + } + break; + } + + case json::RpcMessage::ERROR: { + auto id = rpcMessage.getId(); + if (id.isNull()) { + // Error message will be logged by KJ_CONTEXT, above. + KJ_LOG(ERROR, "peer reports JSON-RPC protocol error"); + } else if (!id.isNumber()) { + // JSON-RPC doesn't define what to do if receiving a response with an invalid id. + KJ_LOG(ERROR, "JSON-RPC response has invalid ID"); + } else KJ_IF_MAYBE(awaited, awaitedResponses.find((uint)id.getNumber())) { + auto error = rpcMessage.getError(); + auto code = error.getCode(); + kj::Exception::Type type = + code == -32601 ? kj::Exception::Type::UNIMPLEMENTED + : kj::Exception::Type::FAILED; + awaited->fulfiller->reject(kj::Exception( + type, __FILE__, __LINE__, kj::str(error.getMessage()))); + } else { + // Probably, this is the response to a call that was canceled. + } + break; + } + } + + return readLoop(); + }); +} + +void JsonRpc::taskFailed(kj::Exception&& exception) { + errorFulfiller->reject(kj::mv(exception)); +} + +// ======================================================================================= + +JsonRpc::ContentLengthTransport::ContentLengthTransport(kj::AsyncIoStream& stream) + : stream(stream), input(kj::newHttpInputStream(stream, staticHeaderTable())) {} +JsonRpc::ContentLengthTransport::~ContentLengthTransport() noexcept(false) {} + +kj::Promise JsonRpc::ContentLengthTransport::send(kj::StringPtr text) { + auto headers = kj::str("Content-Length: ", text.size(), "\r\n\r\n"); + parts[0] = headers.asBytes(); + parts[1] = text.asBytes(); + return stream.write(parts).attach(kj::mv(headers)); +} + +kj::Promise JsonRpc::ContentLengthTransport::receive() { + return input->readMessage() + .then([](kj::HttpInputStream::Message&& message) { + auto promise = message.body->readAllText(); + return promise.attach(kj::mv(message.body)); + }); +} + +} // namespace capnp diff --git a/c++/src/capnp/compat/json-rpc.capnp b/c++/src/capnp/compat/json-rpc.capnp new file mode 100644 index 0000000000..9380788cd7 --- /dev/null +++ b/c++/src/capnp/compat/json-rpc.capnp @@ -0,0 +1,43 @@ +@0xd04299800d6725ba; + +$import "/capnp/c++.capnp".namespace("capnp::json"); + +using Json = import "json.capnp"; + +struct RpcMessage { + jsonrpc @0 :Text; + # Must always be "2.0". + + id @1 :Json.Value; + # Correlates a request to a response. Technically must be a string or number. Our implementation + # will always use a number for calls it initiates, and will reflect IDs of any type for calls + # it receives. + # + # May be omitted when caller doesn't care about the response. The implementation will omit `id` + # and return immediately when calling methods with the annotation `@notification` (defined in + # `json.capnp`). The `@notification` annotation only matters for outgoing calls; for incoming + # calls, it's the client's decision whether it wants to receive the response. + + method @2 :Text; + # Method name. Only expected when `params` is sent. + + union { + none @3 :Void $Json.name("!missing params, result, or error"); + # Dummy default value of union, to detect when none of the fields below were received. + + params @4 :Json.Value; + # Initiates a call. + + result @5 :Json.Value; + # Completes a call. + + error @6 :Error; + # Completes a call throwing an exception. + } + + struct Error { + code @0 :Int32; + message @1 :Text; + data @2 :Json.Value; + } +} diff --git a/c++/src/capnp/compat/json-rpc.h b/c++/src/capnp/compat/json-rpc.h new file mode 100644 index 0000000000..c4d3b99700 --- /dev/null +++ b/c++/src/capnp/compat/json-rpc.h @@ -0,0 +1,116 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "json.h" +#include +#include +#include + +CAPNP_BEGIN_HEADER + +namespace kj { class HttpInputStream; } + +namespace capnp { + +class JsonRpc: private kj::TaskSet::ErrorHandler { + // An implementation of JSON-RPC 2.0: https://www.jsonrpc.org/specification + // + // This allows you to use Cap'n Proto interface declarations to implement JSON-RPC protocols. + // Of course, JSON-RPC does not support capabilities. So, the client and server each expose + // exactly one object to the other. + +public: + class Transport; + class ContentLengthTransport; + + JsonRpc(Transport& transport, DynamicCapability::Client interface = {}); + KJ_DISALLOW_COPY_AND_MOVE(JsonRpc); + + DynamicCapability::Client getPeer(InterfaceSchema schema); + + template + typename T::Client getPeer() { + return getPeer(Schema::from()).template castAs(); + } + + kj::Promise onError() { return errorPromise.addBranch(); } + +private: + JsonCodec codec; + Transport& transport; + DynamicCapability::Client interface; + kj::HashMap methodMap; + uint callCount = 0; + kj::Promise writeQueue = kj::READY_NOW; + kj::ForkedPromise errorPromise; + kj::Own> errorFulfiller; + kj::Promise readTask; + + struct AwaitedResponse { + CallContext context; + kj::Own> fulfiller; + }; + kj::HashMap awaitedResponses; + + kj::TaskSet tasks; + + class CapabilityImpl; + + kj::Promise queueWrite(kj::String text); + void queueError(kj::Maybe id, int code, kj::StringPtr message); + + kj::Promise readLoop(); + + void taskFailed(kj::Exception&& exception) override; + + JsonRpc(Transport& transport, DynamicCapability::Client interface, + kj::PromiseFulfillerPair paf); +}; + +class JsonRpc::Transport { +public: + virtual kj::Promise send(kj::StringPtr text) = 0; + virtual kj::Promise receive() = 0; +}; + +class JsonRpc::ContentLengthTransport: public Transport { + // The transport used by Visual Studio Code: Each message is composed like an HTTP message + // without the first line. That is, a list of headers, followed by a blank line, followed by the + // content whose length is determined by the content-length header. +public: + explicit ContentLengthTransport(kj::AsyncIoStream& stream); + ~ContentLengthTransport() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(ContentLengthTransport); + + kj::Promise send(kj::StringPtr text) override; + kj::Promise receive() override; + +private: + kj::AsyncIoStream& stream; + kj::Own input; + kj::ArrayPtr parts[2]; +}; + +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compat/json-test.c++ b/c++/src/capnp/compat/json-test.c++ index d1b0031bb9..550ecafa7d 100644 --- a/c++/src/capnp/compat/json-test.c++ +++ b/c++/src/capnp/compat/json-test.c++ @@ -22,6 +22,7 @@ #include "json.h" #include #include +#include #include #include #include @@ -49,6 +50,17 @@ KJ_TEST("basic json encoding") { KJ_EXPECT(json.encode(Data::Reader(bytes, 3)) == "[12, 34, 56]"); } +KJ_TEST("raw encoding") { + JsonCodec json; + + auto text = kj::str("{\"field\":\"value\"}"); + MallocMessageBuilder message; + auto value = message.initRoot(); + value.setRaw(text); + + KJ_EXPECT(json.encodeRaw(value) == text); +} + const char ALL_TYPES_JSON[] = "{ \"voidField\": null,\n" " \"boolField\": true,\n" @@ -185,13 +197,20 @@ KJ_TEST("encode union") { KJ_TEST("decode all types") { JsonCodec json; -#define CASE(s, f) \ + json.setHasMode(HasMode::NON_DEFAULT); + +#define CASE_MAYBE_ROUNDTRIP(s, f, roundtrip) \ { \ MallocMessageBuilder message; \ auto root = message.initRoot(); \ - json.decode(s, root); \ - KJ_EXPECT((f)) \ + kj::StringPtr input = s; \ + json.decode(input, root); \ + KJ_EXPECT((f), input, root); \ + auto reencoded = json.encode(root); \ + KJ_EXPECT(roundtrip == (input == reencoded), roundtrip, input, reencoded); \ } +#define CASE_NO_ROUNDTRIP(s, f) CASE_MAYBE_ROUNDTRIP(s, f, false) +#define CASE(s, f) CASE_MAYBE_ROUNDTRIP(s, f, true) #define CASE_THROW(s, errorMessage) \ { \ MallocMessageBuilder message; \ @@ -206,113 +225,129 @@ KJ_TEST("decode all types") { } CASE(R"({})", root.getBoolField() == false); - CASE(R"({"unknownField":7})", root.getBoolField() == false); + CASE_NO_ROUNDTRIP(R"({"unknownField":7})", root.getBoolField() == false); CASE(R"({"boolField":true})", root.getBoolField() == true); CASE(R"({"int8Field":-128})", root.getInt8Field() == -128); - CASE(R"({"int8Field":"127"})", root.getInt8Field() == 127); + CASE_NO_ROUNDTRIP(R"({"int8Field":"127"})", root.getInt8Field() == 127); CASE_THROW_RECOVERABLE(R"({"int8Field":"-129"})", "Value out-of-range"); CASE_THROW_RECOVERABLE(R"({"int8Field":128})", "Value out-of-range"); CASE(R"({"int16Field":-32768})", root.getInt16Field() == -32768); - CASE(R"({"int16Field":"32767"})", root.getInt16Field() == 32767); + CASE_NO_ROUNDTRIP(R"({"int16Field":"32767"})", root.getInt16Field() == 32767); CASE_THROW_RECOVERABLE(R"({"int16Field":"-32769"})", "Value out-of-range"); CASE_THROW_RECOVERABLE(R"({"int16Field":32768})", "Value out-of-range"); CASE(R"({"int32Field":-2147483648})", root.getInt32Field() == -2147483648); - CASE(R"({"int32Field":"2147483647"})", root.getInt32Field() == 2147483647); - CASE(R"({"int64Field":-9007199254740992})", root.getInt64Field() == -9007199254740992LL); - CASE(R"({"int64Field":9007199254740991})", root.getInt64Field() == 9007199254740991LL); + CASE_NO_ROUNDTRIP(R"({"int32Field":"2147483647"})", root.getInt32Field() == 2147483647); + CASE_NO_ROUNDTRIP(R"({"int64Field":-9007199254740992})", root.getInt64Field() == -9007199254740992LL); + CASE_NO_ROUNDTRIP(R"({"int64Field":9007199254740991})", root.getInt64Field() == 9007199254740991LL); CASE(R"({"int64Field":"-9223372036854775808"})", root.getInt64Field() == -9223372036854775808ULL); CASE(R"({"int64Field":"9223372036854775807"})", root.getInt64Field() == 9223372036854775807LL); CASE_THROW_RECOVERABLE(R"({"int64Field":"-9223372036854775809"})", "Value out-of-range"); CASE_THROW_RECOVERABLE(R"({"int64Field":"9223372036854775808"})", "Value out-of-range"); CASE(R"({"uInt8Field":255})", root.getUInt8Field() == 255); - CASE(R"({"uInt8Field":"0"})", root.getUInt8Field() == 0); + CASE_NO_ROUNDTRIP(R"({"uInt8Field":"0"})", root.getUInt8Field() == 0); CASE_THROW_RECOVERABLE(R"({"uInt8Field":"256"})", "Value out-of-range"); CASE_THROW_RECOVERABLE(R"({"uInt8Field":-1})", "Value out-of-range"); CASE(R"({"uInt16Field":65535})", root.getUInt16Field() == 65535); - CASE(R"({"uInt16Field":"0"})", root.getUInt16Field() == 0); + CASE_NO_ROUNDTRIP(R"({"uInt16Field":"0"})", root.getUInt16Field() == 0); CASE_THROW_RECOVERABLE(R"({"uInt16Field":"655356"})", "Value out-of-range"); CASE_THROW_RECOVERABLE(R"({"uInt16Field":-1})", "Value out-of-range"); CASE(R"({"uInt32Field":4294967295})", root.getUInt32Field() == 4294967295); - CASE(R"({"uInt32Field":"0"})", root.getUInt32Field() == 0); + CASE_NO_ROUNDTRIP(R"({"uInt32Field":"0"})", root.getUInt32Field() == 0); CASE_THROW_RECOVERABLE(R"({"uInt32Field":"42949672956"})", "Value out-of-range"); CASE_THROW_RECOVERABLE(R"({"uInt32Field":-1})", "Value out-of-range"); - CASE(R"({"uInt64Field":9007199254740991})", root.getUInt64Field() == 9007199254740991ULL); + CASE_NO_ROUNDTRIP(R"({"uInt64Field":9007199254740991})", root.getUInt64Field() == 9007199254740991ULL); CASE(R"({"uInt64Field":"18446744073709551615"})", root.getUInt64Field() == 18446744073709551615ULL); - CASE(R"({"uInt64Field":"0"})", root.getUInt64Field() == 0); + CASE_NO_ROUNDTRIP(R"({"uInt64Field":"0"})", root.getUInt64Field() == 0); CASE_THROW_RECOVERABLE(R"({"uInt64Field":"18446744073709551616"})", "Value out-of-range"); - CASE(R"({"float32Field":0})", root.getFloat32Field() == 0); + CASE_NO_ROUNDTRIP(R"({"float32Field":0})", root.getFloat32Field() == 0); CASE(R"({"float32Field":4.5})", root.getFloat32Field() == 4.5); - CASE(R"({"float32Field":null})", kj::isNaN(root.getFloat32Field())); - CASE(R"({"float32Field":"nan"})", kj::isNaN(root.getFloat32Field())); - CASE(R"({"float32Field":"nan"})", kj::isNaN(root.getFloat32Field())); + CASE_NO_ROUNDTRIP(R"({"float32Field":null})", kj::isNaN(root.getFloat32Field())); + CASE(R"({"float32Field":"NaN"})", kj::isNaN(root.getFloat32Field())); + CASE_NO_ROUNDTRIP(R"({"float32Field":"nan"})", kj::isNaN(root.getFloat32Field())); CASE(R"({"float32Field":"Infinity"})", root.getFloat32Field() == kj::inf()); CASE(R"({"float32Field":"-Infinity"})", root.getFloat32Field() == -kj::inf()); - CASE(R"({"float64Field":0})", root.getFloat64Field() == 0); + CASE_NO_ROUNDTRIP(R"({"float32Field":"infinity"})", root.getFloat32Field() == kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float32Field":"-infinity"})", root.getFloat32Field() == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float32Field":"INF"})", root.getFloat32Field() == kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float32Field":"-INF"})", root.getFloat32Field() == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float32Field":1e39})", root.getFloat32Field() == kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float32Field":-1e39})", root.getFloat32Field() == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":0})", root.getFloat64Field() == 0); CASE(R"({"float64Field":4.5})", root.getFloat64Field() == 4.5); - CASE(R"({"float64Field":null})", kj::isNaN(root.getFloat64Field())); - CASE(R"({"float64Field":"nan"})", kj::isNaN(root.getFloat64Field())); - CASE(R"({"float64Field":"nan"})", kj::isNaN(root.getFloat64Field())); + CASE_NO_ROUNDTRIP(R"({"float64Field":null})", kj::isNaN(root.getFloat64Field())); + CASE(R"({"float64Field":"NaN"})", kj::isNaN(root.getFloat64Field())); + CASE_NO_ROUNDTRIP(R"({"float64Field":"nan"})", kj::isNaN(root.getFloat64Field())); CASE(R"({"float64Field":"Infinity"})", root.getFloat64Field() == kj::inf()); - CASE(R"({"float64Field":"-Infinity"})", root.getFloat64Field() == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":"infinity"})", root.getFloat64Field() == kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":"-infinity"})", root.getFloat64Field() == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":"INF"})", root.getFloat64Field() == kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":"-INF"})", root.getFloat64Field() == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":1e309})", root.getFloat64Field() == kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64Field":-1e309})", root.getFloat64Field() == -kj::inf()); CASE(R"({"textField":"hello"})", kj::str("hello") == root.getTextField()); CASE(R"({"dataField":[7,0,122]})", kj::heapArray({7,0,122}).asPtr() == root.getDataField()); - CASE(R"({"structField":null})", root.hasStructField() == false); CASE(R"({"structField":{}})", root.hasStructField() == true); CASE(R"({"structField":{}})", root.getStructField().getBoolField() == false); - CASE(R"({"structField":{"boolField":false}})", root.getStructField().getBoolField() == false); + CASE_NO_ROUNDTRIP(R"({"structField":{"boolField":false}})", root.getStructField().getBoolField() == false); CASE(R"({"structField":{"boolField":true}})", root.getStructField().getBoolField() == true); CASE(R"({"enumField":"bar"})", root.getEnumField() == TestEnum::BAR); + CASE_NO_ROUNDTRIP(R"({"textField":"foo\u1234bar"})", + kj::str(u8"foo\u1234bar") == root.getTextField()); + + CASE_THROW_RECOVERABLE(R"({"structField":null})", "Expected object value"); + CASE_THROW_RECOVERABLE(R"({"structList":null})", "Expected list value"); + CASE_THROW_RECOVERABLE(R"({"boolList":null})", "Expected list value"); + CASE_THROW_RECOVERABLE(R"({"structList":[null]})", "Expected object value"); CASE_THROW_RECOVERABLE(R"({"int64Field":"177a"})", "String does not contain valid"); CASE_THROW_RECOVERABLE(R"({"uInt64Field":"177a"})", "String does not contain valid"); CASE_THROW_RECOVERABLE(R"({"float64Field":"177a"})", "String does not contain valid"); CASE(R"({})", root.hasBoolList() == false); - CASE(R"({"boolList":null})", root.hasBoolList() == false); CASE(R"({"boolList":[]})", root.hasBoolList() == true); CASE(R"({"boolList":[]})", root.getBoolList().size() == 0); CASE(R"({"boolList":[false]})", root.getBoolList().size() == 1); CASE(R"({"boolList":[false]})", root.getBoolList()[0] == false); CASE(R"({"boolList":[true]})", root.getBoolList()[0] == true); CASE(R"({"int8List":[7]})", root.getInt8List()[0] == 7); - CASE(R"({"int8List":["7"]})", root.getInt8List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"int8List":["7"]})", root.getInt8List()[0] == 7); CASE(R"({"int16List":[7]})", root.getInt16List()[0] == 7); - CASE(R"({"int16List":["7"]})", root.getInt16List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"int16List":["7"]})", root.getInt16List()[0] == 7); CASE(R"({"int32List":[7]})", root.getInt32List()[0] == 7); - CASE(R"({"int32List":["7"]})", root.getInt32List()[0] == 7); - CASE(R"({"int64List":[7]})", root.getInt64List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"int32List":["7"]})", root.getInt32List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"int64List":[7]})", root.getInt64List()[0] == 7); CASE(R"({"int64List":["7"]})", root.getInt64List()[0] == 7); CASE(R"({"uInt8List":[7]})", root.getUInt8List()[0] == 7); - CASE(R"({"uInt8List":["7"]})", root.getUInt8List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"uInt8List":["7"]})", root.getUInt8List()[0] == 7); CASE(R"({"uInt16List":[7]})", root.getUInt16List()[0] == 7); - CASE(R"({"uInt16List":["7"]})", root.getUInt16List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"uInt16List":["7"]})", root.getUInt16List()[0] == 7); CASE(R"({"uInt32List":[7]})", root.getUInt32List()[0] == 7); - CASE(R"({"uInt32List":["7"]})", root.getUInt32List()[0] == 7); - CASE(R"({"uInt64List":[7]})", root.getUInt64List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"uInt32List":["7"]})", root.getUInt32List()[0] == 7); + CASE_NO_ROUNDTRIP(R"({"uInt64List":[7]})", root.getUInt64List()[0] == 7); CASE(R"({"uInt64List":["7"]})", root.getUInt64List()[0] == 7); CASE(R"({"float32List":[4.5]})", root.getFloat32List()[0] == 4.5); - CASE(R"({"float32List":["4.5"]})", root.getFloat32List()[0] == 4.5); - CASE(R"({"float32List":[null]})", kj::isNaN(root.getFloat32List()[0])); - CASE(R"({"float32List":["nan"]})", kj::isNaN(root.getFloat32List()[0])); - CASE(R"({"float32List":["infinity"]})", root.getFloat32List()[0] == kj::inf()); - CASE(R"({"float32List":["-infinity"]})", root.getFloat32List()[0] == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float32List":["4.5"]})", root.getFloat32List()[0] == 4.5); + CASE_NO_ROUNDTRIP(R"({"float32List":[null]})", kj::isNaN(root.getFloat32List()[0])); + CASE(R"({"float32List":["NaN"]})", kj::isNaN(root.getFloat32List()[0])); + CASE(R"({"float32List":["Infinity"]})", root.getFloat32List()[0] == kj::inf()); + CASE(R"({"float32List":["-Infinity"]})", root.getFloat32List()[0] == -kj::inf()); CASE(R"({"float64List":[4.5]})", root.getFloat64List()[0] == 4.5); - CASE(R"({"float64List":["4.5"]})", root.getFloat64List()[0] == 4.5); - CASE(R"({"float64List":[null]})", kj::isNaN(root.getFloat64List()[0])); - CASE(R"({"float64List":["nan"]})", kj::isNaN(root.getFloat64List()[0])); - CASE(R"({"float64List":["infinity"]})", root.getFloat64List()[0] == kj::inf()); - CASE(R"({"float64List":["-infinity"]})", root.getFloat64List()[0] == -kj::inf()); + CASE_NO_ROUNDTRIP(R"({"float64List":["4.5"]})", root.getFloat64List()[0] == 4.5); + CASE_NO_ROUNDTRIP(R"({"float64List":[null]})", kj::isNaN(root.getFloat64List()[0])); + CASE(R"({"float64List":["NaN"]})", kj::isNaN(root.getFloat64List()[0])); + CASE(R"({"float64List":["Infinity"]})", root.getFloat64List()[0] == kj::inf()); + CASE(R"({"float64List":["-Infinity"]})", root.getFloat64List()[0] == -kj::inf()); CASE(R"({"textList":["hello"]})", kj::str("hello") == root.getTextList()[0]); CASE(R"({"dataList":[[7,0,122]]})", kj::heapArray({7,0,122}).asPtr() == root.getDataList()[0]); - CASE(R"({"structList":null})", root.hasStructList() == false); - CASE(R"({"structList":[null]})", root.hasStructList() == true); - CASE(R"({"structList":[null]})", root.getStructList()[0].getBoolField() == false); + CASE(R"({"structList":[{}]})", root.hasStructList() == true); CASE(R"({"structList":[{}]})", root.getStructList()[0].getBoolField() == false); - CASE(R"({"structList":[{"boolField":false}]})", root.getStructList()[0].getBoolField() == false); + CASE_NO_ROUNDTRIP(R"({"structList":[{"boolField":false}]})", root.getStructList()[0].getBoolField() == false); CASE(R"({"structList":[{"boolField":true}]})", root.getStructList()[0].getBoolField() == true); CASE(R"({"enumList":["bar"]})", root.getEnumList()[0] == TestEnum::BAR); +#undef CASE_MAYBE_ROUNDTRIP +#undef CASE_NO_ROUNDTRIP #undef CASE #undef CASE_THROW #undef CASE_THROW_RECOVERABLE @@ -515,20 +550,6 @@ KJ_TEST("basic json decoding") { KJ_EXPECT_THROW_MESSAGE("Unexpected", json.decodeRaw("+123", root)); } - { - MallocMessageBuilder message; - auto root = message.initRoot(); - - KJ_EXPECT_THROW_MESSAGE("Overflow", json.decodeRaw("1e1024", root)); - } - - { - MallocMessageBuilder message; - auto root = message.initRoot(); - - KJ_EXPECT_THROW_MESSAGE("Underflow", json.decodeRaw("1e-1023", root)); - } - { MallocMessageBuilder message; auto root = message.initRoot(); @@ -596,6 +617,17 @@ KJ_TEST("basic json decoding") { KJ_EXPECT_THROW_MESSAGE("Unexpected input", json.decodeRaw("\f{}", root)); KJ_EXPECT_THROW_MESSAGE("Unexpected input", json.decodeRaw("{\v}", root)); } + + { + MallocMessageBuilder message; + auto root = message.initRoot(); + + json.decodeRaw(R"("\u007f")", root); + KJ_EXPECT(root.which() == JsonValue::STRING); + + char utf_buffer[] = {127, 0}; + KJ_EXPECT(kj::str(utf_buffer) == root.getString()); + } } KJ_TEST("maximum nesting depth") { @@ -637,7 +669,24 @@ KJ_TEST("maximum nesting depth") { } } -class TestHandler: public JsonCodec::Handler { +KJ_TEST("unknown fields") { + JsonCodec json; + MallocMessageBuilder message; + auto root = message.initRoot(); + auto valid = R"({"foo": "a"})"_kj; + auto unknown = R"({"foo": "a", "unknown-field": "b"})"_kj; + json.decode(valid, root); + json.decode(unknown, root); + json.setRejectUnknownFields(true); + json.decode(valid, root); + KJ_EXPECT_THROW_MESSAGE("Unknown field", json.decode(unknown, root)); + + // Verify unknown field rejection still works when handling by annotation. + json.handleByAnnotation(); + KJ_EXPECT_THROW_MESSAGE("Unknown field", json.decode(unknown, root)); +} + +class TestCallHandler: public JsonCodec::Handler { public: void encode(const JsonCodec& codec, Text::Reader input, JsonValue::Builder output) const override { @@ -654,31 +703,144 @@ public: } }; -KJ_TEST("register handler") { - MallocMessageBuilder message; - auto root = message.getRoot(); +class TestDynamicStructHandler: public JsonCodec::Handler { +public: + void encode(const JsonCodec& codec, DynamicStruct::Reader input, + JsonValue::Builder output) const override { + auto fields = input.getSchema().getFields(); + auto items = output.initArray(fields.size()); + for (auto field: fields) { + KJ_REQUIRE(field.getIndex() < items.size()); + auto item = items[field.getIndex()]; + if (input.has(field)) { + codec.encode(input.get(field), field.getType(), item); + } else { + item.setNull(); + } + } + } + + void decode(const JsonCodec& codec, JsonValue::Reader input, + DynamicStruct::Builder output) const override { + auto orphanage = Orphanage::getForMessageContaining(output); + auto fields = output.getSchema().getFields(); + auto items = input.getArray(); + for (auto field: fields) { + KJ_REQUIRE(field.getIndex() < items.size()); + auto item = items[field.getIndex()]; + if (!item.isNull()) { + output.adopt(field, codec.decode(item, field.getType(), orphanage)); + } + } + } +}; + +class TestStructHandler: public JsonCodec::Handler { +public: + void encode(const JsonCodec& codec, test::TestOldVersion::Reader input, JsonValue::Builder output) const override { + dynamicHandler.encode(codec, input, output); + } + + void decode(const JsonCodec& codec, JsonValue::Reader input, test::TestOldVersion::Builder output) const override { + dynamicHandler.decode(codec, input, output); + } + +private: + TestDynamicStructHandler dynamicHandler; +}; - TestHandler handler; +KJ_TEST("register custom encoding handlers") { JsonCodec json; - json.addTypeHandler(handler); + TestStructHandler structHandler; + json.addTypeHandler(structHandler); + + // JSON decoder can't parse calls back, so test only encoder here + TestCallHandler callHandler; + json.addTypeHandler(callHandler); + + MallocMessageBuilder message; + auto root = message.getRoot(); root.setOld1(123); root.setOld2("foo"); - KJ_EXPECT(json.encode(root) == "{\"old1\":\"123\",\"old2\":Frob(123,\"foo\")}"); + + KJ_EXPECT(json.encode(root) == "[\"123\",Frob(123,\"foo\"),null]"); } -KJ_TEST("register field handler") { - MallocMessageBuilder message; - auto root = message.getRoot(); +KJ_TEST("register custom roundtrip handler") { + for (auto i = 1; i <= 2; i++) { + JsonCodec json; + TestStructHandler staticHandler; + TestDynamicStructHandler dynamicHandler; + kj::String encoded; + + if (i == 1) { + // first iteration: test with explicit struct handler + json.addTypeHandler(staticHandler); + } else { + // second iteration: same checks, but with DynamicStruct handler + json.addTypeHandler(StructSchema::from(), dynamicHandler); + } - TestHandler handler; + { + MallocMessageBuilder message; + auto root = message.getRoot(); + root.setOld1(123); + root.initOld3().setOld2("foo"); + + encoded = json.encode(root); + + KJ_EXPECT(encoded == "[\"123\",null,[\"0\",\"foo\",null]]"); + } + + { + MallocMessageBuilder message; + auto root = message.getRoot(); + json.decode(encoded, root); + + KJ_EXPECT(root.getOld1() == 123); + KJ_EXPECT(!root.hasOld2()); + auto nested = root.getOld3(); + KJ_EXPECT(nested.getOld1() == 0); + KJ_EXPECT("foo" == nested.getOld2()); + KJ_EXPECT(!nested.hasOld3()); + } + } +} + +KJ_TEST("register field handler") { + TestStructHandler handler; JsonCodec json; - json.addFieldHandler(StructSchema::from().getFieldByName("corge"), + json.addFieldHandler(StructSchema::from().getFieldByName("old3"), handler); - root.setBaz("abcd"); - root.setCorge("efg"); - KJ_EXPECT(json.encode(root) == "{\"corge\":Frob(123,\"efg\"),\"baz\":\"abcd\"}"); + kj::String encoded; + + { + MallocMessageBuilder message; + auto root = message.getRoot(); + root.setOld1(123); + root.setOld2("foo"); + auto nested = root.initOld3(); + nested.setOld2("bar"); + + encoded = json.encode(root); + + KJ_EXPECT(encoded == "{\"old1\":\"123\",\"old2\":\"foo\",\"old3\":[\"0\",\"bar\",null]}") + } + + { + MallocMessageBuilder message; + auto root = message.getRoot(); + json.decode(encoded, root); + + KJ_EXPECT(root.getOld1() == 123); + KJ_EXPECT("foo" == root.getOld2()); + auto nested = root.getOld3(); + KJ_EXPECT(nested.getOld1() == 0); + KJ_EXPECT("bar" == nested.getOld2()); + KJ_EXPECT(!nested.hasOld3()); + } } class TestCapabilityHandler: public JsonCodec::Handler { @@ -703,27 +865,168 @@ KJ_TEST("register capability handler") { json.addTypeHandler(handler); } -class TestDynamicStructHandler: public JsonCodec::Handler { +static constexpr kj::StringPtr GOLDEN_ANNOTATED = +R"({ "names-can_contain!anything Really": "foo", + "flatFoo": 123, + "flatBar": "abc", + "renamed-flatBaz": {"hello": true}, + "flatQux": "cba", + "pfx.foo": "this is a long string in order to force multi-line pretty printing", + "pfx.renamed-bar": 321, + "pfx.baz": {"hello": true}, + "pfx.xfp.qux": "fed", + "union-type": "renamed-bar", + "barMember": 789, + "multiMember": "ghi", + "dependency": {"renamed-foo": "corge"}, + "simpleGroup": {"renamed-grault": "garply"}, + "enums": ["qux", "renamed-bar", "foo", "renamed-baz"], + "innerJson": [123, "hello", {"object": true}], + "customFieldHandler": "add-prefix-waldo", + "testBase64": "ZnJlZA==", + "testHex": "706c756768", + "bUnion": "renamed-bar", + "bValue": 678, + "externalUnion": {"type": "bar", "value": "cba"}, + "unionWithVoid": {"type": "voidValue"} })"_kj; + +static constexpr kj::StringPtr GOLDEN_ANNOTATED_REVERSE = +R"({ + "unionWithVoid": {"type": "voidValue"}, + "externalUnion": {"type": "bar", "value": "cba"}, + "bValue": 678, + "bUnion": "renamed-bar", + "testHex": "706c756768", + "testBase64": "ZnJlZA==", + "customFieldHandler": "add-prefix-waldo", + "innerJson": [123, "hello", {"object": true}], + "enums": ["qux", "renamed-bar", "foo", "renamed-baz"], + "simpleGroup": { "renamed-grault": "garply" }, + "dependency": { "renamed-foo": "corge" }, + "multiMember": "ghi", + "barMember": 789, + "union-type": "renamed-bar", + "pfx.xfp.qux": "fed", + "pfx.baz": {"hello": true}, + "pfx.renamed-bar": 321, + "pfx.foo": "this is a long string in order to force multi-line pretty printing", + "flatQux": "cba", + "renamed-flatBaz": {"hello": true}, + "flatBar": "abc", + "flatFoo": 123, + "names-can_contain!anything Really": "foo" +})"_kj; + +class PrefixAdder: public JsonCodec::Handler { public: - void encode(const JsonCodec& codec, DynamicStruct::Reader input, - JsonValue::Builder output) const override { - KJ_UNIMPLEMENTED("TestDynamicStructHandler::encode"); + void encode(const JsonCodec& codec, capnp::Text::Reader input, JsonValue::Builder output) const { + output.setString(kj::str("add-prefix-", input)); } - void decode(const JsonCodec& codec, JsonValue::Reader input, - DynamicStruct::Builder output) const override { - KJ_UNIMPLEMENTED("TestDynamicStructHandler::decode"); + Orphan decode(const JsonCodec& codec, JsonValue::Reader input, + Orphanage orphanage) const { + return orphanage.newOrphanCopy(capnp::Text::Reader(input.getString().slice(11))); } }; +KJ_TEST("rename fields") { + JsonCodec json; + json.handleByAnnotation(); + json.setPrettyPrint(true); + + PrefixAdder customHandler; + json.addFieldHandler(Schema::from().getFieldByName("customFieldHandler"), + customHandler); -KJ_TEST("register DynamicStruct handler") { - // This test currently only checks that this compiles, which at one point wasn't the caes. - // TODO(test): Actually run some code here. + kj::String goldenText; + + { + MallocMessageBuilder message; + auto root = message.getRoot(); + root.setSomeField("foo"); + + auto aGroup = root.getAGroup(); + aGroup.setFlatFoo(123); + aGroup.setFlatBar("abc"); + aGroup.getFlatBaz().setHello(true); + aGroup.getDoubleFlat().setFlatQux("cba"); + + auto prefixedGroup = root.getPrefixedGroup(); + prefixedGroup.setFoo("this is a long string in order to force multi-line pretty printing"); + prefixedGroup.setBar(321); + prefixedGroup.getBaz().setHello(true); + prefixedGroup.getMorePrefix().setQux("fed"); + + auto unionBar = root.getAUnion().initBar(); + unionBar.setBarMember(789); + unionBar.setMultiMember("ghi"); + + root.initDependency().setFoo("corge"); + root.initSimpleGroup().setGrault("garply"); + + root.setEnums({ + TestJsonAnnotatedEnum::QUX, + TestJsonAnnotatedEnum::BAR, + TestJsonAnnotatedEnum::FOO, + TestJsonAnnotatedEnum::BAZ + }); + + auto val = root.initInnerJson(); + auto arr = val.initArray(3); + arr[0].setNumber(123); + arr[1].setString("hello"); + auto field = arr[2].initObject(1)[0]; + field.setName("object"); + field.initValue().setBoolean(true); + + root.setCustomFieldHandler("waldo"); + + root.setTestBase64("fred"_kj.asBytes()); + root.setTestHex("plugh"_kj.asBytes()); + + root.getBUnion().setBar(678); + + root.initExternalUnion().initBar().setValue("cba"); + + root.initUnionWithVoid().setVoidValue(); + + auto encoded = json.encode(root.asReader()); + KJ_EXPECT(encoded == GOLDEN_ANNOTATED, encoded); + + goldenText = kj::str(root); + } + + { + MallocMessageBuilder message; + auto root = message.getRoot(); + json.decode(GOLDEN_ANNOTATED, root); + + KJ_EXPECT(kj::str(root) == goldenText, root, goldenText); + } + + { + // Try parsing in reverse, mostly to test that union tags can come after content. + MallocMessageBuilder message; + auto root = message.getRoot(); + json.decode(GOLDEN_ANNOTATED_REVERSE, root); + + KJ_EXPECT(kj::str(root) == goldenText, root, goldenText); + } +} + +KJ_TEST("base64 union encoded correctly") { + // At one point field handlers were not correctly applied when the field was a member of a union + // in a type that was handled by annotation. - TestDynamicStructHandler handler; JsonCodec json; - json.addTypeHandler(Schema::from(), handler); + json.handleByAnnotation(); + json.setPrettyPrint(true); + + MallocMessageBuilder message; + auto root = message.getRoot(); + root.initFoo(5); + + KJ_EXPECT(json.encode(root) == "{\"foo\": \"AAAAAAA=\"}", json.encode(root)); } } // namespace diff --git a/c++/src/capnp/compat/json-test.capnp b/c++/src/capnp/compat/json-test.capnp new file mode 100644 index 0000000000..4406fd24b5 --- /dev/null +++ b/c++/src/capnp/compat/json-test.capnp @@ -0,0 +1,123 @@ +# Copyright (c) 2018 Cloudflare, Inc. and contributors +# Licensed under the MIT License: +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +@0xc9d405cf4333e4c9; + +using Json = import "/capnp/compat/json.capnp"; + +$import "/capnp/c++.capnp".namespace("capnp"); + +struct TestJsonAnnotations { + someField @0 :Text $Json.name("names-can_contain!anything Really"); + + aGroup :group $Json.flatten() { + flatFoo @1 :UInt32; + flatBar @2 :Text; + flatBaz :group $Json.name("renamed-flatBaz") { + hello @3 :Bool; + } + doubleFlat :group $Json.flatten() { + flatQux @4 :Text; + } + } + + prefixedGroup :group $Json.flatten(prefix = "pfx.") { + foo @5 :Text; + bar @6 :UInt32 $Json.name("renamed-bar"); + baz :group { + hello @7 :Bool; + } + morePrefix :group $Json.flatten(prefix = "xfp.") { + qux @8 :Text; + } + } + + aUnion :union $Json.flatten() $Json.discriminator(name = "union-type") { + foo :group $Json.flatten() { + fooMember @9 :Text; + multiMember @10 :UInt32; + } + bar :group $Json.flatten() $Json.name("renamed-bar") { + barMember @11 :UInt32; + multiMember @12 :Text; + } + } + + dependency @13 :TestJsonAnnotations2; + # To test that dependencies are loaded even if not flattened. + + simpleGroup :group { + # To test that group types are loaded even if not flattened. + grault @14 :Text $Json.name("renamed-grault"); + } + + enums @15 :List(TestJsonAnnotatedEnum); + + innerJson @16 :Json.Value; + + customFieldHandler @17 :Text; + + testBase64 @18 :Data $Json.base64; + testHex @19 :Data $Json.hex; + + bUnion :union $Json.flatten() $Json.discriminator(valueName = "bValue") { + foo @20 :Text; + bar @21 :UInt32 $Json.name("renamed-bar"); + } + + externalUnion @22 :TestJsonAnnotations3; + + unionWithVoid :union $Json.discriminator(name = "type") { + intValue @23 :UInt32; + voidValue @24 :Void; + textValue @25 :Text; + } +} + +struct TestJsonAnnotations2 { + foo @0 :Text $Json.name("renamed-foo"); + cycle @1 :TestJsonAnnotations; +} + +struct TestJsonAnnotations3 $Json.discriminator(name = "type") { + union { + foo @0 :UInt32; + bar @1 :TestFlattenedStruct $Json.flatten(); + } +} + +struct TestFlattenedStruct { + value @0 :Text; +} + +enum TestJsonAnnotatedEnum { + foo @0; + bar @1 $Json.name("renamed-bar"); + baz @2 $Json.name("renamed-baz"); + qux @3; +} + +struct TestBase64Union { + union { + foo @0 :Data $Json.base64; + bar @1 :Text; + } +} diff --git a/c++/src/capnp/compat/json.c++ b/c++/src/capnp/compat/json.c++ index 1aca63fad2..83a522f4ee 100644 --- a/c++/src/capnp/compat/json.c++ +++ b/c++/src/capnp/compat/json.c++ @@ -20,39 +20,26 @@ // THE SOFTWARE. #include "json.h" -#include // for HUGEVAL to check for overflow in strtod -#include // strtod -#include // for strtod errors -#include #include #include #include #include +#include +#include +#include namespace capnp { -namespace { - -struct TypeHash { - size_t operator()(const Type& type) const { - return type.hashCode(); - } -}; - -struct FieldHash { - size_t operator()(const StructSchema::Field& field) const { - return field.getIndex() ^ field.getContainingStruct().getProto().getId(); - } -}; - -} // namespace - struct JsonCodec::Impl { bool prettyPrint = false; + HasMode hasMode = HasMode::NON_NULL; size_t maxNestingDepth = 64; + bool rejectUnknownFields = false; - std::unordered_map typeHandlers; - std::unordered_map fieldHandlers; + kj::HashMap typeHandlers; + kj::HashMap fieldHandlers; + kj::HashMap>> annotatedHandlers; + kj::HashMap> annotatedEnumHandlers; kj::StringTree encodeRaw(JsonValue::Reader value, uint indent, bool& multiline, bool hasPrefix) const { @@ -106,6 +93,10 @@ struct JsonCodec::Impl { return kj::strTree(call.getFunction(), '(', encodeList( kj::mv(encodedElements), childMultiline, indent, multiline, true), ')'); } + + case JsonValue::RAW: { + return kj::strTree(value.getRaw()); + } } KJ_FAIL_ASSERT("unknown JsonValue type", static_cast(value.which())); @@ -120,14 +111,13 @@ struct JsonCodec::Impl { switch (c) { case '\"': escaped.addAll(kj::StringPtr("\\\"")); break; case '\\': escaped.addAll(kj::StringPtr("\\\\")); break; - case '/' : escaped.addAll(kj::StringPtr("\\/" )); break; case '\b': escaped.addAll(kj::StringPtr("\\b")); break; case '\f': escaped.addAll(kj::StringPtr("\\f")); break; case '\n': escaped.addAll(kj::StringPtr("\\n")); break; case '\r': escaped.addAll(kj::StringPtr("\\r")); break; case '\t': escaped.addAll(kj::StringPtr("\\t")); break; default: - if (c >= 0 && c < 0x20) { + if (static_cast(c) < 0x20) { escaped.addAll(kj::StringPtr("\\u00")); uint8_t c2 = c; escaped.add(HEXDIGITS[c2 / 16]); @@ -195,6 +185,10 @@ void JsonCodec::setMaxNestingDepth(size_t maxNestingDepth) { impl->maxNestingDepth = maxNestingDepth; } +void JsonCodec::setHasMode(HasMode mode) { impl->hasMode = mode; } + +void JsonCodec::setRejectUnknownFields(bool enabled) { impl->rejectUnknownFields = enabled; } + kj::String JsonCodec::encode(DynamicValue::Reader value, Type type) const { MallocMessageBuilder message; auto json = message.getRoot(); @@ -223,12 +217,11 @@ kj::String JsonCodec::encodeRaw(JsonValue::Reader value) const { } void JsonCodec::encode(DynamicValue::Reader input, Type type, JsonValue::Builder output) const { - // TODO(0.7): For interfaces, check for handlers on superclasses, per documentation... - // TODO(0.7): For branded types, should we check for handlers on the generic? + // TODO(someday): For interfaces, check for handlers on superclasses, per documentation... + // TODO(someday): For branded types, should we check for handlers on the generic? // TODO(someday): Allow registering handlers for "all structs", "all lists", etc? - auto iter = impl->typeHandlers.find(type); - if (iter != impl->typeHandlers.end()) { - iter->second->encodeBase(*this, input, output); + KJ_IF_MAYBE(handler, impl->typeHandlers.find(type)) { + (*handler)->encodeBase(*this, input, output); return; } @@ -308,7 +301,7 @@ void JsonCodec::encode(DynamicValue::Reader input, Type type, JsonValue::Builder uint fieldCount = 0; for (auto i: kj::indices(nonUnionFields)) { - fieldCount += (hasField[i] = structValue.has(nonUnionFields[i])); + fieldCount += (hasField[i] = structValue.has(nonUnionFields[i], impl->hasMode)); } // We try to write the union field, if any, in proper order with the rest. @@ -318,7 +311,7 @@ void JsonCodec::encode(DynamicValue::Reader input, Type type, JsonValue::Builder KJ_IF_MAYBE(field, which) { // Even if the union field is null, if it is not the default field of the union then we // have to print it anyway. - unionFieldIsNull = !structValue.has(*field); + unionFieldIsNull = !structValue.has(*field, impl->hasMode); if (field->getProto().getDiscriminantValue() != 0 || !unionFieldIsNull) { ++fieldCount; } else { @@ -374,198 +367,163 @@ void JsonCodec::encode(DynamicValue::Reader input, Type type, JsonValue::Builder void JsonCodec::encodeField(StructSchema::Field field, DynamicValue::Reader input, JsonValue::Builder output) const { - auto iter = impl->fieldHandlers.find(field); - if (iter != impl->fieldHandlers.end()) { - iter->second->encodeBase(*this, input, output); + KJ_IF_MAYBE(handler, impl->fieldHandlers.find(field)) { + (*handler)->encodeBase(*this, input, output); return; } encode(input, field.getType(), output); } -namespace { +Orphan JsonCodec::decodeArray(List::Reader input, ListSchema type, Orphanage orphanage) const { + auto orphan = orphanage.newOrphan(type, input.size()); + auto output = orphan.get(); + for (auto i: kj::indices(input)) { + output.adopt(i, decode(input[i], type.getElementType(), orphanage)); + } + return orphan; +} + +void JsonCodec::decodeObject(JsonValue::Reader input, StructSchema type, Orphanage orphanage, DynamicStruct::Builder output) const { + KJ_REQUIRE(input.isObject(), "Expected object value") { return; } + for (auto field: input.getObject()) { + KJ_IF_MAYBE(fieldSchema, type.findFieldByName(field.getName())) { + decodeField(*fieldSchema, field.getValue(), orphanage, output); + } else { + KJ_REQUIRE(!impl->rejectUnknownFields, "Unknown field", field.getName()); + } + } +} + +void JsonCodec::decodeField(StructSchema::Field fieldSchema, JsonValue::Reader fieldValue, + Orphanage orphanage, DynamicStruct::Builder output) const { + auto fieldType = fieldSchema.getType(); + + KJ_IF_MAYBE(handler, impl->fieldHandlers.find(fieldSchema)) { + output.adopt(fieldSchema, (*handler)->decodeBase(*this, fieldValue, fieldType, orphanage)); + } else { + output.adopt(fieldSchema, decode(fieldValue, fieldType, orphanage)); + } +} + +void JsonCodec::decode(JsonValue::Reader input, DynamicStruct::Builder output) const { + auto type = output.getSchema(); + + KJ_IF_MAYBE(handler, impl->typeHandlers.find(type)) { + return (*handler)->decodeStructBase(*this, input, output); + } + + decodeObject(input, type, Orphanage::getForMessageContaining(output), output); +} + +Orphan JsonCodec::decode( + JsonValue::Reader input, Type type, Orphanage orphanage) const { + KJ_IF_MAYBE(handler, impl->typeHandlers.find(type)) { + return (*handler)->decodeBase(*this, input, type, orphanage); + } -template -void decodeField(Type type, JsonValue::Reader value, SetFn setFn, DecodeArrayFn decodeArrayFn, - DecodeObjectFn decodeObjectFn) { - // This code relies on conversions in DynamicValue::Reader::as. switch(type.which()) { case schema::Type::VOID: - break; + return capnp::VOID; case schema::Type::BOOL: - switch (value.which()) { + switch (input.which()) { case JsonValue::BOOLEAN: - setFn(value.getBoolean()); - break; + return input.getBoolean(); default: KJ_FAIL_REQUIRE("Expected boolean value"); } - break; case schema::Type::INT8: case schema::Type::INT16: case schema::Type::INT32: case schema::Type::INT64: // Relies on range check in DynamicValue::Reader::as - switch (value.which()) { + switch (input.which()) { case JsonValue::NUMBER: - setFn(value.getNumber()); - break; + return input.getNumber(); case JsonValue::STRING: - setFn(value.getString().parseAs()); - break; + return input.getString().parseAs(); default: KJ_FAIL_REQUIRE("Expected integer value"); } - break; case schema::Type::UINT8: case schema::Type::UINT16: case schema::Type::UINT32: case schema::Type::UINT64: // Relies on range check in DynamicValue::Reader::as - switch (value.which()) { + switch (input.which()) { case JsonValue::NUMBER: - setFn(value.getNumber()); - break; + return input.getNumber(); case JsonValue::STRING: - setFn(value.getString().parseAs()); - break; + return input.getString().parseAs(); default: KJ_FAIL_REQUIRE("Expected integer value"); } - break; case schema::Type::FLOAT32: case schema::Type::FLOAT64: - switch (value.which()) { + switch (input.which()) { case JsonValue::NULL_: - setFn(kj::nan()); - break; + return kj::nan(); case JsonValue::NUMBER: - setFn(value.getNumber()); - break; + return input.getNumber(); case JsonValue::STRING: - setFn(value.getString().parseAs()); - break; + return input.getString().parseAs(); default: KJ_FAIL_REQUIRE("Expected float value"); } - break; case schema::Type::TEXT: - switch (value.which()) { + switch (input.which()) { case JsonValue::STRING: - setFn(value.getString()); - break; + return orphanage.newOrphanCopy(input.getString()); default: KJ_FAIL_REQUIRE("Expected text value"); } - break; case schema::Type::DATA: - switch (value.which()) { + switch (input.which()) { case JsonValue::ARRAY: { - auto array = value.getArray(); - kj::Vector data(array.size()); - for (auto arrayObject : array) { - auto x = arrayObject.getNumber(); + auto array = input.getArray(); + auto orphan = orphanage.newOrphan(array.size()); + auto data = orphan.get(); + for (auto i: kj::indices(array)) { + auto x = array[i].getNumber(); KJ_REQUIRE(byte(x) == x, "Number in byte array is not an integer in [0, 255]"); - data.add(byte(x)); + data[i] = x; } - setFn(Data::Reader(data.asPtr())); - break; + return kj::mv(orphan); } default: KJ_FAIL_REQUIRE("Expected data value"); } - break; case schema::Type::LIST: - switch (value.which()) { - case JsonValue::NULL_: - // nothing to do - break; + switch (input.which()) { case JsonValue::ARRAY: - decodeArrayFn(value.getArray()); - break; + return decodeArray(input.getArray(), type.asList(), orphanage); default: - KJ_FAIL_REQUIRE("Expected list value"); + KJ_FAIL_REQUIRE("Expected list value") { break; } + return orphanage.newOrphan(type.asList(), 0); } - break; case schema::Type::ENUM: - switch (value.which()) { + switch (input.which()) { case JsonValue::STRING: - setFn(value.getString()); - break; + return DynamicEnum(type.asEnum().getEnumerantByName(input.getString())); default: - KJ_FAIL_REQUIRE("Expected enum value"); + KJ_FAIL_REQUIRE("Expected enum value") { break; } + return DynamicEnum(type.asEnum(), 0); } - break; - case schema::Type::STRUCT: - switch (value.which()) { - case JsonValue::NULL_: - // nothing to do - break; - case JsonValue::OBJECT: - decodeObjectFn(value.getObject()); - break; - default: - KJ_FAIL_REQUIRE("Expected object value"); - } - break; + case schema::Type::STRUCT: { + auto structType = type.asStruct(); + auto orphan = orphanage.newOrphan(structType); + decodeObject(input, structType, orphanage, orphan.get()); + return kj::mv(orphan); + } case schema::Type::INTERFACE: KJ_FAIL_REQUIRE("don't know how to JSON-decode capabilities; " - "JsonCodec::Handler not implemented yet :("); + "please register a JsonCodec::Handler for this"); case schema::Type::ANY_POINTER: KJ_FAIL_REQUIRE("don't know how to JSON-decode AnyPointer; " - "JsonCodec::Handler not implemented yet :("); - } -} -} // namespace - -void JsonCodec::decodeArray(List::Reader input, DynamicList::Builder output) const { - KJ_ASSERT(input.size() == output.size(), "Builder was not initialized to input size"); - auto type = output.getSchema().getElementType(); - for (auto i = 0; i < input.size(); i++) { - decodeField(type, input[i], - [&](DynamicValue::Reader value) { output.set(i, value); }, - [&](List::Reader array) { - decodeArray(array, output.init(i, array.size()).as()); - }, - [&](List::Reader object) { - decodeObject(object, output[i].as()); - }); - } -} - -void JsonCodec::decodeObject(List::Reader input, DynamicStruct::Builder output) - const { - for (auto field : input) { - KJ_IF_MAYBE(fieldSchema, output.getSchema().findFieldByName(field.getName())) { - decodeField((*fieldSchema).getType(), field.getValue(), - [&](DynamicValue::Reader value) { output.set(*fieldSchema, value); }, - [&](List::Reader array) { - decodeArray(array, output.init(*fieldSchema, array.size()).as()); - }, - [&](List::Reader object) { - decodeObject(object, output.init(*fieldSchema).as()); - }); - } else { - // Unknown json fields are ignored to allow schema evolution - } + "please register a JsonCodec::Handler for this"); } -} - -void JsonCodec::decode(JsonValue::Reader input, DynamicStruct::Builder output) const { - // TODO(0.7): type and field handlers - switch (input.which()) { - case JsonValue::OBJECT: - decodeObject(input.getObject(), output); - break; - default: - KJ_FAIL_REQUIRE("Top level json value must be object"); - }; -} -Orphan JsonCodec::decode( - JsonValue::Reader input, Type type, Orphanage orphanage) const { - // TODO(0.7) - KJ_FAIL_ASSERT("JSON decode into orphanage not implement yet. :("); + KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT; } // ----------------------------------------------------------------------------- @@ -694,19 +652,7 @@ public: } void parseNumber(JsonValue::Builder& output) { - auto numberStr = consumeNumber(); - char *endPtr; - - errno = 0; - double value = strtod(numberStr.begin(), &endPtr); - - KJ_ASSERT(endPtr != numberStr.begin(), "strtod should not fail! Is consumeNumber wrong?"); - KJ_REQUIRE((value != HUGE_VAL && value != -HUGE_VAL) || errno != ERANGE, - "Overflow in JSON number."); - KJ_REQUIRE(value != 0.0 || errno != ERANGE, - "Underflow in JSON number."); - - output.setNumber(value); + output.setNumber(consumeNumber().parseAs()); } void parseString(JsonValue::Builder& output) { @@ -799,7 +745,7 @@ public: private: kj::String consumeQuotedString() { input.consume('"'); - // TODO(perf): Avoid copy / alloc if no escapes encoutered. + // TODO(perf): Avoid copy / alloc if no escapes encountered. // TODO(perf): Get statistics on string size and preallocate? kj::Vector decoded; @@ -877,17 +823,21 @@ private: if ('0' <= c && c <= '9') { codePoint |= c - '0'; } else if ('a' <= c && c <= 'f') { - codePoint |= c - 'a'; + codePoint |= c - 'a' + 10; } else if ('A' <= c && c <= 'F') { - codePoint |= c - 'A'; + codePoint |= c - 'A' + 10; } else { KJ_FAIL_REQUIRE("Invalid hex digit in unicode escape.", c); } } - // TODO(0.7): Support at least basic multi-lingual plane, ie ignore surrogates. - KJ_REQUIRE(codePoint < 128, "non-ASCII unicode escapes are not supported (yet!)"); - target.add(0x7f & static_cast(codePoint)); + if (codePoint < 128) { + target.add(0x7f & static_cast(codePoint)); + } else { + // TODO(perf): This is sorta malloc-heavy... + char16_t u = codePoint; + target.addAll(kj::decodeUtf16(kj::arrayPtr(&u, 1))); + } } const size_t maxNestingDepth; @@ -919,13 +869,613 @@ void JsonCodec::HandlerBase::decodeStructBase( } void JsonCodec::addTypeHandlerImpl(Type type, HandlerBase& handler) { - impl->typeHandlers[type] = &handler; + impl->typeHandlers.upsert(type, &handler, [](HandlerBase*& existing, HandlerBase* replacement) { + KJ_REQUIRE(existing == replacement, "type already has a different registered handler"); + }); } void JsonCodec::addFieldHandlerImpl(StructSchema::Field field, Type type, HandlerBase& handler) { KJ_REQUIRE(type == field.getType(), "handler type did not match field type for addFieldHandler()"); - impl->fieldHandlers[field] = &handler; + impl->fieldHandlers.upsert(field, &handler, [](HandlerBase*& existing, HandlerBase* replacement) { + KJ_REQUIRE(existing == replacement, "field already has a different registered handler"); + }); +} + +// ======================================================================================= + +static constexpr uint64_t JSON_NAME_ANNOTATION_ID = 0xfa5b1fd61c2e7c3dull; +static constexpr uint64_t JSON_FLATTEN_ANNOTATION_ID = 0x82d3e852af0336bfull; +static constexpr uint64_t JSON_DISCRIMINATOR_ANNOTATION_ID = 0xcfa794e8d19a0162ull; +static constexpr uint64_t JSON_BASE64_ANNOTATION_ID = 0xd7d879450a253e4bull; +static constexpr uint64_t JSON_HEX_ANNOTATION_ID = 0xf061e22f0ae5c7b5ull; + +class JsonCodec::Base64Handler final: public JsonCodec::Handler { +public: + void encode(const JsonCodec& codec, capnp::Data::Reader input, JsonValue::Builder output) const { + output.setString(kj::encodeBase64(input)); + } + + Orphan decode(const JsonCodec& codec, JsonValue::Reader input, + Orphanage orphanage) const { + return orphanage.newOrphanCopy(capnp::Data::Reader(kj::decodeBase64(input.getString()))); + } +}; + +class JsonCodec::HexHandler final: public JsonCodec::Handler { +public: + void encode(const JsonCodec& codec, capnp::Data::Reader input, JsonValue::Builder output) const { + output.setString(kj::encodeHex(input)); + } + + Orphan decode(const JsonCodec& codec, JsonValue::Reader input, + Orphanage orphanage) const { + return orphanage.newOrphanCopy(capnp::Data::Reader(kj::decodeHex(input.getString()))); + } +}; + +class JsonCodec::AnnotatedHandler final: public JsonCodec::Handler { +public: + AnnotatedHandler(JsonCodec& codec, StructSchema schema, + kj::Maybe discriminator, + kj::Maybe unionDeclName, + kj::Vector& dependencies) + : schema(schema) { + auto schemaProto = schema.getProto(); + auto typeName = schemaProto.getDisplayName(); + + if (discriminator == nullptr) { + // There are two cases of unions: + // * Named unions, which are special cases of named groups. In this case, the union may be + // annotated by annotating the field. In this case, we receive a non-null `discriminator` + // as a constructor parameter, and schemaProto.getAnnotations() must be empty because + // it's not possible to annotate a group's type (because the type is anonymous). + // * Unnamed unions, of which there can only be one in any particular scope. In this case, + // the parent struct type itself is annotated. + // So if we received `null` as the constructor parameter, check for annotations on the struct + // type. + for (auto anno: schemaProto.getAnnotations()) { + switch (anno.getId()) { + case JSON_DISCRIMINATOR_ANNOTATION_ID: + discriminator = anno.getValue().getStruct().getAs(); + break; + } + } + } + + KJ_IF_MAYBE(d, discriminator) { + if (d->hasName()) { + unionTagName = d->getName(); + } else { + unionTagName = unionDeclName; + } + KJ_IF_MAYBE(u, unionTagName) { + fieldsByName.insert(*u, FieldNameInfo { + FieldNameInfo::UNION_TAG, 0, 0, nullptr + }); + } + + if (d->hasValueName()) { + fieldsByName.insert(d->getValueName(), FieldNameInfo { + FieldNameInfo::UNION_VALUE, 0, 0, nullptr + }); + } + } + + discriminantOffset = schemaProto.getStruct().getDiscriminantOffset(); + + fields = KJ_MAP(field, schema.getFields()) { + auto fieldProto = field.getProto(); + auto type = field.getType(); + auto fieldName = fieldProto.getName(); + + FieldNameInfo nameInfo; + nameInfo.index = field.getIndex(); + nameInfo.type = FieldNameInfo::NORMAL; + nameInfo.prefixLength = 0; + + FieldInfo info; + info.name = fieldName; + + kj::Maybe subDiscriminator; + bool flattened = false; + for (auto anno: field.getProto().getAnnotations()) { + switch (anno.getId()) { + case JSON_NAME_ANNOTATION_ID: + info.name = anno.getValue().getText(); + break; + case JSON_FLATTEN_ANNOTATION_ID: + KJ_REQUIRE(type.isStruct(), "only struct types can be flattened", fieldName, typeName); + flattened = true; + info.prefix = anno.getValue().getStruct().getAs().getPrefix(); + break; + case JSON_DISCRIMINATOR_ANNOTATION_ID: + KJ_REQUIRE(fieldProto.isGroup(), "only unions can have discriminator"); + subDiscriminator = anno.getValue().getStruct().getAs(); + break; + case JSON_BASE64_ANNOTATION_ID: { + KJ_REQUIRE(field.getType().isData(), "only Data can be marked for base64 encoding"); + static Base64Handler handler; + codec.addFieldHandler(field, handler); + break; + } + case JSON_HEX_ANNOTATION_ID: { + KJ_REQUIRE(field.getType().isData(), "only Data can be marked for hex encoding"); + static HexHandler handler; + codec.addFieldHandler(field, handler); + break; + } + } + } + + if (fieldProto.isGroup()) { + // Load group type handler now, even if not flattened, so that we can pass its + // `subDiscriminator`. + kj::Maybe subFieldName; + if (flattened) { + // If the group was flattened, then we allow its field name to be used as the + // discriminator name, so that the discriminator doesn't have to explicitly specify a + // name. + subFieldName = fieldName; + } + auto& subHandler = codec.loadAnnotatedHandler( + type.asStruct(), subDiscriminator, subFieldName, dependencies); + if (flattened) { + info.flattenHandler = subHandler; + } + } else if (type.isStruct()) { + if (flattened) { + info.flattenHandler = codec.loadAnnotatedHandler( + type.asStruct(), nullptr, nullptr, dependencies); + } + } + + bool isUnionMember = fieldProto.getDiscriminantValue() != schema::Field::NO_DISCRIMINANT; + + KJ_IF_MAYBE(fh, info.flattenHandler) { + // Set up fieldsByName for each of the child's fields. + for (auto& entry: fh->fieldsByName) { + kj::StringPtr flattenedName; + kj::String ownName; + if (info.prefix.size() > 0) { + ownName = kj::str(info.prefix, entry.key); + flattenedName = ownName; + } else { + flattenedName = entry.key; + } + + fieldsByName.upsert(flattenedName, FieldNameInfo { + isUnionMember ? FieldNameInfo::FLATTENED_FROM_UNION : FieldNameInfo::FLATTENED, + field.getIndex(), (uint)info.prefix.size(), kj::mv(ownName) + }, [&](FieldNameInfo& existing, FieldNameInfo&& replacement) { + KJ_REQUIRE(existing.type == FieldNameInfo::FLATTENED_FROM_UNION && + replacement.type == FieldNameInfo::FLATTENED_FROM_UNION, + "flattened members have the same name and are not mutually exclusive"); + }); + } + } + + info.nameForDiscriminant = info.name; + + if (!flattened) { + bool isUnionWithValueName = false; + if (isUnionMember) { + KJ_IF_MAYBE(d, discriminator) { + if (d->hasValueName()) { + info.name = d->getValueName(); + isUnionWithValueName = true; + } + } + } + + if (!isUnionWithValueName) { + fieldsByName.insert(info.name, kj::mv(nameInfo)); + } + } + + if (isUnionMember) { + unionTagValues.insert(info.nameForDiscriminant, field); + } + + // Look for dependencies that we need to add. + while (type.isList()) type = type.asList().getElementType(); + if (codec.impl->typeHandlers.find(type) == nullptr) { + switch (type.which()) { + case schema::Type::STRUCT: + dependencies.add(type.asStruct()); + break; + case schema::Type::ENUM: + dependencies.add(type.asEnum()); + break; + case schema::Type::INTERFACE: + dependencies.add(type.asInterface()); + break; + default: + break; + } + } + + return info; + }; + } + + const StructSchema schema; + + void encode(const JsonCodec& codec, DynamicStruct::Reader input, + JsonValue::Builder output) const override { + kj::Vector flattenedFields; + gatherForEncode(codec, input, nullptr, nullptr, flattenedFields); + + auto outs = output.initObject(flattenedFields.size()); + for (auto i: kj::indices(flattenedFields)) { + auto& in = flattenedFields[i]; + auto out = outs[i]; + out.setName(in.name); + KJ_SWITCH_ONEOF(in.type) { + KJ_CASE_ONEOF(type, Type) { + codec.encode(in.value, type, out.initValue()); + } + KJ_CASE_ONEOF(field, StructSchema::Field) { + codec.encodeField(field, in.value, out.initValue()); + } + } + } + } + + void decode(const JsonCodec& codec, JsonValue::Reader input, + DynamicStruct::Builder output) const override { + KJ_REQUIRE(input.isObject()); + kj::HashSet unionsSeen; + kj::Vector retries; + for (auto field: input.getObject()) { + if (!decodeField(codec, field.getName(), field.getValue(), output, unionsSeen)) { + retries.add(field); + } + } + while (!retries.empty()) { + auto retriesCopy = kj::mv(retries); + KJ_ASSERT(retries.empty()); + for (auto field: retriesCopy) { + if (!decodeField(codec, field.getName(), field.getValue(), output, unionsSeen)) { + retries.add(field); + } + } + if (retries.size() == retriesCopy.size()) { + // We made no progress in this iteration. Give up on the remaining fields. + break; + } + } + } + +private: + struct FieldInfo { + kj::StringPtr name; + kj::StringPtr nameForDiscriminant; + kj::Maybe flattenHandler; + kj::StringPtr prefix; + }; + + kj::Array fields; + // Maps field index -> info about the field + + struct FieldNameInfo { + enum { + NORMAL, + // This is a normal field with the given `index`. + + FLATTENED, + // This is a field of a flattened inner struct or group (that is not in a union). `index` + // is the field index of the particular struct/group field. + + UNION_TAG, + // The parent struct is a flattened union, and this field is the discriminant tag. It is a + // string field whose name determines the union type. `index` is not used. + + FLATTENED_FROM_UNION, + // The parent struct is a flattened union, and some of the union's members are flattened + // structs or groups, and this field is possibly a member of one or more of them. `index` + // is not used, because it's possible that the same field name appears in multiple variants. + // Instead, the parser must find the union tag, and then can descend and attempt to parse + // the field in the context of whichever variant is selected. + + UNION_VALUE + // This field is the value of a discriminated union that has `valueName` set. + } type; + + uint index; + // For `NORMAL` and `FLATTENED`, the index of the field in schema.getFields(). + + uint prefixLength; + kj::String ownName; + }; + + kj::HashMap fieldsByName; + // Maps JSON names to info needed to parse them. + + kj::HashMap unionTagValues; + // If the parent struct is a flattened union, it has a tag field which is a string with one of + // these values. The map maps to the union member to set. + + kj::Maybe unionTagName; + // If the parent struct is a flattened union, the name of the "tag" field. + + uint discriminantOffset; + // Shortcut for schema.getProto().getStruct().getDiscriminantOffset(), used in a hack to identify + // which unions have been seen. + + struct FlattenedField { + kj::String ownName; + kj::StringPtr name; + kj::OneOf type; + DynamicValue::Reader value; + + FlattenedField(kj::StringPtr prefix, kj::StringPtr name, + kj::OneOf type, DynamicValue::Reader value) + : ownName(prefix.size() > 0 ? kj::str(prefix, name) : nullptr), + name(prefix.size() > 0 ? ownName : name), + type(type), value(value) {} + }; + + void gatherForEncode(const JsonCodec& codec, DynamicValue::Reader input, + kj::StringPtr prefix, kj::StringPtr morePrefix, + kj::Vector& flattenedFields) const { + kj::String ownPrefix; + if (morePrefix.size() > 0) { + if (prefix.size() > 0) { + ownPrefix = kj::str(prefix, morePrefix); + prefix = ownPrefix; + } else { + prefix = morePrefix; + } + } + + auto reader = input.as(); + auto schema = reader.getSchema(); + for (auto field: schema.getNonUnionFields()) { + auto& info = fields[field.getIndex()]; + if (!reader.has(field, codec.impl->hasMode)) { + // skip + } else KJ_IF_MAYBE(handler, info.flattenHandler) { + handler->gatherForEncode(codec, reader.get(field), prefix, info.prefix, flattenedFields); + } else { + flattenedFields.add(FlattenedField { + prefix, info.name, field, reader.get(field) }); + } + } + + KJ_IF_MAYBE(which, reader.which()) { + auto& info = fields[which->getIndex()]; + KJ_IF_MAYBE(tag, unionTagName) { + flattenedFields.add(FlattenedField { + prefix, *tag, Type(schema::Type::TEXT), Text::Reader(info.nameForDiscriminant) }); + } + + KJ_IF_MAYBE(handler, info.flattenHandler) { + handler->gatherForEncode(codec, reader.get(*which), prefix, info.prefix, flattenedFields); + } else { + auto type = which->getType(); + if (type.which() == schema::Type::VOID && unionTagName != nullptr) { + // When we have an explicit union discriminant, we don't need to encode void fields. + } else { + flattenedFields.add(FlattenedField { + prefix, info.name, *which, reader.get(*which) }); + } + } + } + } + + bool decodeField(const JsonCodec& codec, kj::StringPtr name, JsonValue::Reader value, + DynamicStruct::Builder output, kj::HashSet& unionsSeen) const { + KJ_ASSERT(output.getSchema() == schema); + + KJ_IF_MAYBE(info, fieldsByName.find(name)) { + switch (info->type) { + case FieldNameInfo::NORMAL: { + auto field = output.getSchema().getFields()[info->index]; + codec.decodeField(field, value, Orphanage::getForMessageContaining(output), output); + return true; + } + case FieldNameInfo::FLATTENED: + return KJ_ASSERT_NONNULL(fields[info->index].flattenHandler) + .decodeField(codec, name.slice(info->prefixLength), value, + output.get(output.getSchema().getFields()[info->index]).as(), + unionsSeen); + case FieldNameInfo::UNION_TAG: { + KJ_REQUIRE(value.isString(), "Expected string value."); + + // Mark that we've seen a union tag for this struct. + const void* ptr = getUnionInstanceIdentifier(output); + KJ_IF_MAYBE(field, unionTagValues.find(value.getString())) { + // clear() has the side-effect of activating this member of the union, without + // allocating any objects. + output.clear(*field); + unionsSeen.insert(ptr); + } + return true; + } + case FieldNameInfo::FLATTENED_FROM_UNION: { + const void* ptr = getUnionInstanceIdentifier(output); + if (unionsSeen.contains(ptr)) { + auto variant = KJ_ASSERT_NONNULL(output.which()); + return KJ_ASSERT_NONNULL(fields[variant.getIndex()].flattenHandler) + .decodeField(codec, name.slice(info->prefixLength), value, + output.get(variant).as(), unionsSeen); + } else { + // We haven't seen the union tag yet, so we can't parse this field yet. Try again later. + return false; + } + } + case FieldNameInfo::UNION_VALUE: { + const void* ptr = getUnionInstanceIdentifier(output); + if (unionsSeen.contains(ptr)) { + auto variant = KJ_ASSERT_NONNULL(output.which()); + codec.decodeField(variant, value, Orphanage::getForMessageContaining(output), output); + return true; + } else { + // We haven't seen the union tag yet, so we can't parse this field yet. Try again later. + return false; + } + } + } + + KJ_UNREACHABLE; + } else { + // Ignore undefined field -- unless the flag is set to reject them. + KJ_REQUIRE(!codec.impl->rejectUnknownFields, "Unknown field", name); + return true; + } + } + + const void* getUnionInstanceIdentifier(DynamicStruct::Builder obj) const { + // Gets a value uniquely identifying an instance of a union. + // HACK: We return a pointer to the union's discriminant within the underlying buffer. + return reinterpret_cast( + AnyStruct::Reader(obj.asReader()).getDataSection().begin()) + discriminantOffset; + } +}; + +class JsonCodec::AnnotatedEnumHandler final: public JsonCodec::Handler { +public: + AnnotatedEnumHandler(EnumSchema schema): schema(schema) { + auto enumerants = schema.getEnumerants(); + auto builder = kj::heapArrayBuilder(enumerants.size()); + + for (auto e: enumerants) { + auto proto = e.getProto(); + kj::StringPtr name = proto.getName(); + + for (auto anno: proto.getAnnotations()) { + switch (anno.getId()) { + case JSON_NAME_ANNOTATION_ID: + name = anno.getValue().getText(); + break; + } + } + + builder.add(name); + nameToValue.insert(name, e.getIndex()); + } + + valueToName = builder.finish(); + } + + void encode(const JsonCodec& codec, DynamicEnum input, JsonValue::Builder output) const override { + KJ_IF_MAYBE(e, input.getEnumerant()) { + KJ_ASSERT(e->getIndex() < valueToName.size()); + output.setString(valueToName[e->getIndex()]); + } else { + output.setNumber(input.getRaw()); + } + } + + DynamicEnum decode(const JsonCodec& codec, JsonValue::Reader input) const override { + if (input.isNumber()) { + return DynamicEnum(schema, static_cast(input.getNumber())); + } else { + uint16_t val = KJ_REQUIRE_NONNULL(nameToValue.find(input.getString()), + "invalid enum value", input.getString()); + return DynamicEnum(schema.getEnumerants()[val]); + } + } + +private: + EnumSchema schema; + kj::Array valueToName; + kj::HashMap nameToValue; +}; + +class JsonCodec::JsonValueHandler final: public JsonCodec::Handler { +public: + void encode(const JsonCodec& codec, DynamicStruct::Reader input, + JsonValue::Builder output) const override { +#if _MSC_VER && !defined(__clang__) + // TODO(msvc): Hack to work around missing AnyStruct::Builder constructor on MSVC. + rawCopy(input, toDynamic(output)); +#else + rawCopy(input, kj::mv(output)); +#endif + } + + void decode(const JsonCodec& codec, JsonValue::Reader input, + DynamicStruct::Builder output) const override { + rawCopy(input, kj::mv(output)); + } + +private: + void rawCopy(AnyStruct::Reader input, AnyStruct::Builder output) const { + // HACK: Manually copy using AnyStruct, so that if JsonValue's definition changes, this code + // doesn't need to be updated. However, note that if JsonValue ever adds new fields that + // change its size, and the input struct is a newer version than the output, we may lose + // the new fields. Technically the "correct" thing to do would be to allocate the output + // struct to be exactly the same size as the input, but JsonCodec's Handler interface is + // not designed to allow that -- it passes in an already-allocated builder. Oops. + auto dataIn = input.getDataSection(); + auto dataOut = output.getDataSection(); + memcpy(dataOut.begin(), dataIn.begin(), kj::min(dataOut.size(), dataIn.size())); + + auto ptrIn = input.getPointerSection(); + auto ptrOut = output.getPointerSection(); + for (auto i: kj::zeroTo(kj::min(ptrIn.size(), ptrOut.size()))) { + ptrOut[i].set(ptrIn[i]); + } + } +}; + +JsonCodec::AnnotatedHandler& JsonCodec::loadAnnotatedHandler( + StructSchema schema, kj::Maybe discriminator, + kj::Maybe unionDeclName, kj::Vector& dependencies) { + auto& entry = impl->annotatedHandlers.upsert(schema, nullptr, + [&](kj::Maybe>& existing, auto dummy) { + KJ_ASSERT(existing != nullptr, + "cyclic JSON flattening detected", schema.getProto().getDisplayName()); + }); + + KJ_IF_MAYBE(v, entry.value) { + // Already exists. + return **v; + } else { + // Not seen before. + auto newHandler = kj::heap( + *this, schema, discriminator, unionDeclName, dependencies); + auto& result = *newHandler; + + // Map may have changed, so we have to look up again. + KJ_ASSERT_NONNULL(impl->annotatedHandlers.find(schema)) = kj::mv(newHandler); + + addTypeHandler(schema, result); + return result; + }; +} + +void JsonCodec::handleByAnnotation(Schema schema) { + switch (schema.getProto().which()) { + case schema::Node::STRUCT: { + if (schema.getProto().getId() == capnp::typeId()) { + // Special handler for JsonValue. + static JsonValueHandler GLOBAL_HANDLER; + addTypeHandler(schema.asStruct(), GLOBAL_HANDLER); + } else { + kj::Vector dependencies; + loadAnnotatedHandler(schema.asStruct(), nullptr, nullptr, dependencies); + for (auto dep: dependencies) { + handleByAnnotation(dep); + } + } + break; + } + case schema::Node::ENUM: { + auto enumSchema = schema.asEnum(); + impl->annotatedEnumHandlers.findOrCreate(enumSchema, [&]() { + auto handler = kj::heap(enumSchema); + addTypeHandler(enumSchema, *handler); + return kj::HashMap>::Entry { + enumSchema, kj::mv(handler) }; + }); + break; + } + default: + break; + } } } // namespace capnp diff --git a/c++/src/capnp/compat/json.capnp b/c++/src/capnp/compat/json.capnp index 55188736f8..e5d1870c00 100644 --- a/c++/src/capnp/compat/json.capnp +++ b/c++/src/capnp/compat/json.capnp @@ -21,15 +21,15 @@ @0x8ef99297a43a5e34; -$import "/capnp/c++.capnp".namespace("capnp"); +$import "/capnp/c++.capnp".namespace("capnp::json"); -struct JsonValue { +struct Value { union { null @0 :Void; boolean @1 :Bool; number @2 :Float64; string @3 :Text; - array @4 :List(JsonValue); + array @4 :List(Value); object @5 :List(Field); # Standard JSON values. @@ -44,15 +44,84 @@ struct JsonValue { # "binary" and "date" types in text, since JSON has no analog of these. This is basically the # reason this extension exists. We do NOT recommend using `call` unless you specifically need # to be compatible with some silly format that uses this syntax. + + raw @7 :Text; + # Used to indicate that the text should be written directly to the output without + # modifications. Use this if you have an already serialized JSON value and don't want + # to feel the cost of deserializing the value just to serialize it again. + # + # The parser will never produce a `raw` value -- this is only useful for serialization. + # + # WARNING: You MUST ensure that the value is valid stand-alone JSOn. It will not be verified. + # Invalid JSON could mjake the whole message unparsable. Worse, a malicious raw value could + # perform JSON injection attacks. Make sure that the value was produced by a trustworthy JSON + # encoder. } struct Field { name @0 :Text; - value @1 :JsonValue; + value @1 :Value; } struct Call { function @0 :Text; - params @1 :List(JsonValue); + params @1 :List(Value); } } + +# ======================================================================================== +# Annotations to control parsing. Typical usage: +# +# using Json = import "/capnp/compat/json.capnp"; +# +# And then later on: +# +# myField @0 :Text $Json.name("my_field"); + +annotation name @0xfa5b1fd61c2e7c3d (field, enumerant, method, group, union) :Text; +# Define an alternative name to use when encoding the given item in JSON. This can be used, for +# example, to use snake_case names where needed, even though Cap'n Proto uses strictly camelCase. +# +# (However, because JSON is derived from JavaScript, you *should* use camelCase names when +# defining JSON-based APIs. But, when supporting a pre-existing API you may not have a choice.) + +annotation flatten @0x82d3e852af0336bf (field, group, union) :FlattenOptions; +# Specifies that an aggregate field should be flattened into its parent. +# +# In order to flatten a member of a union, the union (or, for an anonymous union, the parent +# struct type) must have the $jsonDiscriminator annotation. +# +# TODO(someday): Maybe support "flattening" a List(Value.Field) as a way to support unknown JSON +# fields? + +struct FlattenOptions { + prefix @0 :Text = ""; + # Optional: Adds the given prefix to flattened field names. +} + +annotation discriminator @0xcfa794e8d19a0162 (struct, union) :DiscriminatorOptions; +# Specifies that a union's variant will be decided not by which fields are present, but instead +# by a special discriminator field. The value of the discriminator field is a string naming which +# variant is active. This allows the members of the union to have the $jsonFlatten annotation, or +# to all have the same name. + +struct DiscriminatorOptions { + name @0 :Text; + # The name of the discriminator field. Defaults to matching the name of the union. + + valueName @1 :Text; + # If non-null, specifies that the union's value shall have the given field name, rather than the + # value's name. In this case the union's variant can only be determined by looking at the + # discriminant field, not by inspecting which value field is present. + # + # It is an error to use `valueName` while also declaring some variants as $flatten. +} + +annotation base64 @0xd7d879450a253e4b (field) :Void; +# Place on a field of type `Data` to indicate that its JSON representation is a Base64 string. + +annotation hex @0xf061e22f0ae5c7b5 (field) :Void; +# Place on a field of type `Data` to indicate that its JSON representation is a hex string. + +annotation notification @0xa0a054dea32fd98c (method) :Void; +# Indicates that this method is a JSON-RPC "notification", meaning it expects no response. diff --git a/c++/src/capnp/compat/json.capnp.c++ b/c++/src/capnp/compat/json.capnp.c++ index a7f0eedd9c..faf41e5066 100644 --- a/c++/src/capnp/compat/json.capnp.c++ +++ b/c++/src/capnp/compat/json.capnp.c++ @@ -5,81 +5,87 @@ namespace capnp { namespace schemas { -static const ::capnp::_::AlignedData<138> b_8825ffaa852cda72 = { +static const ::capnp::_::AlignedData<152> b_a3fa7845f919dd83 = { { 0, 0, 0, 0, 5, 0, 6, 0, - 114, 218, 44, 133, 170, 255, 37, 136, + 131, 221, 25, 249, 69, 120, 250, 163, 24, 0, 0, 0, 1, 0, 2, 0, 52, 94, 58, 164, 151, 146, 249, 142, - 1, 0, 7, 0, 0, 0, 7, 0, + 1, 0, 7, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 21, 0, 0, 0, 18, 1, 0, 0, - 37, 0, 0, 0, 39, 0, 0, 0, + 21, 0, 0, 0, 242, 0, 0, 0, + 33, 0, 0, 0, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 57, 0, 0, 0, 143, 1, 0, 0, + 53, 0, 0, 0, 199, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 99, 111, 109, 112, 97, 116, 47, 106, 115, 111, 110, 46, 99, 97, 112, 110, 112, 58, - 74, 115, 111, 110, 86, 97, 108, 117, - 101, 0, 0, 0, 0, 0, 0, 0, + 86, 97, 108, 117, 101, 0, 0, 0, 8, 0, 0, 0, 1, 0, 1, 0, - 204, 55, 169, 83, 216, 85, 120, 194, + 223, 157, 214, 53, 231, 38, 16, 227, 9, 0, 0, 0, 50, 0, 0, 0, - 96, 187, 212, 61, 21, 132, 191, 155, + 72, 61, 201, 161, 236, 246, 217, 160, 5, 0, 0, 0, 42, 0, 0, 0, 70, 105, 101, 108, 100, 0, 0, 0, 67, 97, 108, 108, 0, 0, 0, 0, - 28, 0, 0, 0, 3, 0, 4, 0, + 32, 0, 0, 0, 3, 0, 4, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 181, 0, 0, 0, 42, 0, 0, 0, + 209, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 176, 0, 0, 0, 3, 0, 1, 0, - 188, 0, 0, 0, 2, 0, 1, 0, + 204, 0, 0, 0, 3, 0, 1, 0, + 216, 0, 0, 0, 2, 0, 1, 0, 1, 0, 254, 255, 16, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 185, 0, 0, 0, 66, 0, 0, 0, + 213, 0, 0, 0, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 180, 0, 0, 0, 3, 0, 1, 0, - 192, 0, 0, 0, 2, 0, 1, 0, + 208, 0, 0, 0, 3, 0, 1, 0, + 220, 0, 0, 0, 2, 0, 1, 0, 2, 0, 253, 255, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 189, 0, 0, 0, 58, 0, 0, 0, + 217, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 184, 0, 0, 0, 3, 0, 1, 0, - 196, 0, 0, 0, 2, 0, 1, 0, + 212, 0, 0, 0, 3, 0, 1, 0, + 224, 0, 0, 0, 2, 0, 1, 0, 3, 0, 252, 255, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 193, 0, 0, 0, 58, 0, 0, 0, + 221, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 188, 0, 0, 0, 3, 0, 1, 0, - 200, 0, 0, 0, 2, 0, 1, 0, + 216, 0, 0, 0, 3, 0, 1, 0, + 228, 0, 0, 0, 2, 0, 1, 0, 4, 0, 251, 255, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 197, 0, 0, 0, 50, 0, 0, 0, + 225, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 192, 0, 0, 0, 3, 0, 1, 0, - 220, 0, 0, 0, 2, 0, 1, 0, + 220, 0, 0, 0, 3, 0, 1, 0, + 248, 0, 0, 0, 2, 0, 1, 0, 5, 0, 250, 255, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 217, 0, 0, 0, 58, 0, 0, 0, + 245, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 212, 0, 0, 0, 3, 0, 1, 0, - 240, 0, 0, 0, 2, 0, 1, 0, + 240, 0, 0, 0, 3, 0, 1, 0, + 12, 1, 0, 0, 2, 0, 1, 0, 6, 0, 249, 255, 0, 0, 0, 0, 0, 0, 1, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 237, 0, 0, 0, 42, 0, 0, 0, + 9, 1, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 232, 0, 0, 0, 3, 0, 1, 0, - 244, 0, 0, 0, 2, 0, 1, 0, + 4, 1, 0, 0, 3, 0, 1, 0, + 16, 1, 0, 0, 2, 0, 1, 0, + 7, 0, 248, 255, 0, 0, 0, 0, + 0, 0, 1, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 13, 1, 0, 0, 34, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 8, 1, 0, 0, 3, 0, 1, 0, + 20, 1, 0, 0, 2, 0, 1, 0, 110, 117, 108, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -118,7 +124,7 @@ static const ::capnp::_::AlignedData<138> b_8825ffaa852cda72 = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 16, 0, 0, 0, 0, 0, 0, 0, - 114, 218, 44, 133, 170, 255, 37, 136, + 131, 221, 25, 249, 69, 120, 250, 163, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, @@ -130,7 +136,7 @@ static const ::capnp::_::AlignedData<138> b_8825ffaa852cda72 = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 16, 0, 0, 0, 0, 0, 0, 0, - 204, 55, 169, 83, 216, 85, 120, 194, + 223, 157, 214, 53, 231, 38, 16, 227, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, @@ -138,35 +144,43 @@ static const ::capnp::_::AlignedData<138> b_8825ffaa852cda72 = { 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 108, 108, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, - 96, 187, 212, 61, 21, 132, 191, 155, + 72, 61, 201, 161, 236, 246, 217, 160, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 114, 97, 119, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; -::capnp::word const* const bp_8825ffaa852cda72 = b_8825ffaa852cda72.words; +::capnp::word const* const bp_a3fa7845f919dd83 = b_a3fa7845f919dd83.words; #if !CAPNP_LITE -static const ::capnp::_::RawSchema* const d_8825ffaa852cda72[] = { - &s_8825ffaa852cda72, - &s_9bbf84153dd4bb60, - &s_c27855d853a937cc, +static const ::capnp::_::RawSchema* const d_a3fa7845f919dd83[] = { + &s_a0d9f6eca1c93d48, + &s_a3fa7845f919dd83, + &s_e31026e735d69ddf, }; -static const uint16_t m_8825ffaa852cda72[] = {4, 1, 6, 0, 2, 5, 3}; -static const uint16_t i_8825ffaa852cda72[] = {0, 1, 2, 3, 4, 5, 6}; -const ::capnp::_::RawSchema s_8825ffaa852cda72 = { - 0x8825ffaa852cda72, b_8825ffaa852cda72.words, 138, d_8825ffaa852cda72, m_8825ffaa852cda72, - 3, 7, i_8825ffaa852cda72, nullptr, nullptr, { &s_8825ffaa852cda72, nullptr, nullptr, 0, 0, nullptr } +static const uint16_t m_a3fa7845f919dd83[] = {4, 1, 6, 0, 2, 5, 7, 3}; +static const uint16_t i_a3fa7845f919dd83[] = {0, 1, 2, 3, 4, 5, 6, 7}; +const ::capnp::_::RawSchema s_a3fa7845f919dd83 = { + 0xa3fa7845f919dd83, b_a3fa7845f919dd83.words, 152, d_a3fa7845f919dd83, m_a3fa7845f919dd83, + 3, 8, i_a3fa7845f919dd83, nullptr, nullptr, { &s_a3fa7845f919dd83, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<49> b_c27855d853a937cc = { +static const ::capnp::_::AlignedData<49> b_e31026e735d69ddf = { { 0, 0, 0, 0, 5, 0, 6, 0, - 204, 55, 169, 83, 216, 85, 120, 194, - 34, 0, 0, 0, 1, 0, 0, 0, - 114, 218, 44, 133, 170, 255, 37, 136, + 223, 157, 214, 53, 231, 38, 16, 227, + 30, 0, 0, 0, 1, 0, 0, 0, + 131, 221, 25, 249, 69, 120, 250, 163, 2, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 21, 0, 0, 0, 66, 1, 0, 0, + 21, 0, 0, 0, 34, 1, 0, 0, 37, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33, 0, 0, 0, 119, 0, 0, 0, @@ -175,8 +189,8 @@ static const ::capnp::_::AlignedData<49> b_c27855d853a937cc = { 99, 97, 112, 110, 112, 47, 99, 111, 109, 112, 97, 116, 47, 106, 115, 111, 110, 46, 99, 97, 112, 110, 112, 58, - 74, 115, 111, 110, 86, 97, 108, 117, - 101, 46, 70, 105, 101, 108, 100, 0, + 86, 97, 108, 117, 101, 46, 70, 105, + 101, 108, 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 8, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -203,33 +217,33 @@ static const ::capnp::_::AlignedData<49> b_c27855d853a937cc = { 0, 0, 0, 0, 0, 0, 0, 0, 118, 97, 108, 117, 101, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, - 114, 218, 44, 133, 170, 255, 37, 136, + 131, 221, 25, 249, 69, 120, 250, 163, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; -::capnp::word const* const bp_c27855d853a937cc = b_c27855d853a937cc.words; +::capnp::word const* const bp_e31026e735d69ddf = b_e31026e735d69ddf.words; #if !CAPNP_LITE -static const ::capnp::_::RawSchema* const d_c27855d853a937cc[] = { - &s_8825ffaa852cda72, +static const ::capnp::_::RawSchema* const d_e31026e735d69ddf[] = { + &s_a3fa7845f919dd83, }; -static const uint16_t m_c27855d853a937cc[] = {0, 1}; -static const uint16_t i_c27855d853a937cc[] = {0, 1}; -const ::capnp::_::RawSchema s_c27855d853a937cc = { - 0xc27855d853a937cc, b_c27855d853a937cc.words, 49, d_c27855d853a937cc, m_c27855d853a937cc, - 1, 2, i_c27855d853a937cc, nullptr, nullptr, { &s_c27855d853a937cc, nullptr, nullptr, 0, 0, nullptr } +static const uint16_t m_e31026e735d69ddf[] = {0, 1}; +static const uint16_t i_e31026e735d69ddf[] = {0, 1}; +const ::capnp::_::RawSchema s_e31026e735d69ddf = { + 0xe31026e735d69ddf, b_e31026e735d69ddf.words, 49, d_e31026e735d69ddf, m_e31026e735d69ddf, + 1, 2, i_e31026e735d69ddf, nullptr, nullptr, { &s_e31026e735d69ddf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<54> b_9bbf84153dd4bb60 = { +static const ::capnp::_::AlignedData<54> b_a0d9f6eca1c93d48 = { { 0, 0, 0, 0, 5, 0, 6, 0, - 96, 187, 212, 61, 21, 132, 191, 155, - 34, 0, 0, 0, 1, 0, 0, 0, - 114, 218, 44, 133, 170, 255, 37, 136, + 72, 61, 201, 161, 236, 246, 217, 160, + 30, 0, 0, 0, 1, 0, 0, 0, + 131, 221, 25, 249, 69, 120, 250, 163, 2, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 21, 0, 0, 0, 58, 1, 0, 0, + 21, 0, 0, 0, 26, 1, 0, 0, 37, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33, 0, 0, 0, 119, 0, 0, 0, @@ -238,8 +252,8 @@ static const ::capnp::_::AlignedData<54> b_9bbf84153dd4bb60 = { 99, 97, 112, 110, 112, 47, 99, 111, 109, 112, 97, 116, 47, 106, 115, 111, 110, 46, 99, 97, 112, 110, 112, 58, - 74, 115, 111, 110, 86, 97, 108, 117, - 101, 46, 67, 97, 108, 108, 0, 0, + 86, 97, 108, 117, 101, 46, 67, 97, + 108, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 8, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -271,23 +285,313 @@ static const ::capnp::_::AlignedData<54> b_9bbf84153dd4bb60 = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 16, 0, 0, 0, 0, 0, 0, 0, - 114, 218, 44, 133, 170, 255, 37, 136, + 131, 221, 25, 249, 69, 120, 250, 163, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; -::capnp::word const* const bp_9bbf84153dd4bb60 = b_9bbf84153dd4bb60.words; +::capnp::word const* const bp_a0d9f6eca1c93d48 = b_a0d9f6eca1c93d48.words; +#if !CAPNP_LITE +static const ::capnp::_::RawSchema* const d_a0d9f6eca1c93d48[] = { + &s_a3fa7845f919dd83, +}; +static const uint16_t m_a0d9f6eca1c93d48[] = {0, 1}; +static const uint16_t i_a0d9f6eca1c93d48[] = {0, 1}; +const ::capnp::_::RawSchema s_a0d9f6eca1c93d48 = { + 0xa0d9f6eca1c93d48, b_a0d9f6eca1c93d48.words, 54, d_a0d9f6eca1c93d48, m_a0d9f6eca1c93d48, + 1, 2, i_a0d9f6eca1c93d48, nullptr, nullptr, { &s_a0d9f6eca1c93d48, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<21> b_fa5b1fd61c2e7c3d = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 61, 124, 46, 28, 214, 31, 91, 250, + 24, 0, 0, 0, 5, 0, 232, 2, + 52, 94, 58, 164, 151, 146, 249, 142, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 234, 0, 0, 0, + 33, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 110, 97, 109, 101, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_fa5b1fd61c2e7c3d = b_fa5b1fd61c2e7c3d.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_fa5b1fd61c2e7c3d = { + 0xfa5b1fd61c2e7c3d, b_fa5b1fd61c2e7c3d.words, 21, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_fa5b1fd61c2e7c3d, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<21> b_82d3e852af0336bf = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 191, 54, 3, 175, 82, 232, 211, 130, + 24, 0, 0, 0, 5, 0, 224, 0, + 52, 94, 58, 164, 151, 146, 249, 142, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 2, 1, 0, 0, + 33, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 102, 108, 97, 116, 116, 101, 110, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 16, 0, 0, 0, 0, 0, 0, 0, + 97, 234, 194, 123, 37, 19, 223, 196, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_82d3e852af0336bf = b_82d3e852af0336bf.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_82d3e852af0336bf = { + 0x82d3e852af0336bf, b_82d3e852af0336bf.words, 21, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_82d3e852af0336bf, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<35> b_c4df13257bc2ea61 = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 97, 234, 194, 123, 37, 19, 223, 196, + 24, 0, 0, 0, 1, 0, 0, 0, + 52, 94, 58, 164, 151, 146, 249, 142, + 1, 0, 7, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 58, 1, 0, 0, + 37, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 33, 0, 0, 0, 63, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 70, 108, 97, 116, 116, 101, 110, 79, + 112, 116, 105, 111, 110, 115, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 4, 0, 0, 0, 3, 0, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 13, 0, 0, 0, 58, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 3, 0, 1, 0, + 20, 0, 0, 0, 2, 0, 1, 0, + 112, 114, 101, 102, 105, 120, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 10, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_c4df13257bc2ea61 = b_c4df13257bc2ea61.words; +#if !CAPNP_LITE +static const uint16_t m_c4df13257bc2ea61[] = {0}; +static const uint16_t i_c4df13257bc2ea61[] = {0}; +const ::capnp::_::RawSchema s_c4df13257bc2ea61 = { + 0xc4df13257bc2ea61, b_c4df13257bc2ea61.words, 35, nullptr, m_c4df13257bc2ea61, + 0, 1, i_c4df13257bc2ea61, nullptr, nullptr, { &s_c4df13257bc2ea61, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<22> b_cfa794e8d19a0162 = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 98, 1, 154, 209, 232, 148, 167, 207, + 24, 0, 0, 0, 5, 0, 80, 0, + 52, 94, 58, 164, 151, 146, 249, 142, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 50, 1, 0, 0, + 37, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 32, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 100, 105, 115, 99, 114, 105, 109, 105, + 110, 97, 116, 111, 114, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 16, 0, 0, 0, 0, 0, 0, 0, + 25, 83, 62, 41, 12, 194, 248, 194, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_cfa794e8d19a0162 = b_cfa794e8d19a0162.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_cfa794e8d19a0162 = { + 0xcfa794e8d19a0162, b_cfa794e8d19a0162.words, 22, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_cfa794e8d19a0162, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<51> b_c2f8c20c293e5319 = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 25, 83, 62, 41, 12, 194, 248, 194, + 24, 0, 0, 0, 1, 0, 0, 0, + 52, 94, 58, 164, 151, 146, 249, 142, + 2, 0, 7, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 106, 1, 0, 0, + 41, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 37, 0, 0, 0, 119, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 68, 105, 115, 99, 114, 105, 109, 105, + 110, 97, 116, 111, 114, 79, 112, 116, + 105, 111, 110, 115, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 8, 0, 0, 0, 3, 0, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 41, 0, 0, 0, 42, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 36, 0, 0, 0, 3, 0, 1, 0, + 48, 0, 0, 0, 2, 0, 1, 0, + 1, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 45, 0, 0, 0, 82, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 44, 0, 0, 0, 3, 0, 1, 0, + 56, 0, 0, 0, 2, 0, 1, 0, + 110, 97, 109, 101, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 118, 97, 108, 117, 101, 78, 97, 109, + 101, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_c2f8c20c293e5319 = b_c2f8c20c293e5319.words; #if !CAPNP_LITE -static const ::capnp::_::RawSchema* const d_9bbf84153dd4bb60[] = { - &s_8825ffaa852cda72, +static const uint16_t m_c2f8c20c293e5319[] = {0, 1}; +static const uint16_t i_c2f8c20c293e5319[] = {0, 1}; +const ::capnp::_::RawSchema s_c2f8c20c293e5319 = { + 0xc2f8c20c293e5319, b_c2f8c20c293e5319.words, 51, nullptr, m_c2f8c20c293e5319, + 0, 2, i_c2f8c20c293e5319, nullptr, nullptr, { &s_c2f8c20c293e5319, nullptr, nullptr, 0, 0, nullptr }, false }; -static const uint16_t m_9bbf84153dd4bb60[] = {0, 1}; -static const uint16_t i_9bbf84153dd4bb60[] = {0, 1}; -const ::capnp::_::RawSchema s_9bbf84153dd4bb60 = { - 0x9bbf84153dd4bb60, b_9bbf84153dd4bb60.words, 54, d_9bbf84153dd4bb60, m_9bbf84153dd4bb60, - 1, 2, i_9bbf84153dd4bb60, nullptr, nullptr, { &s_9bbf84153dd4bb60, nullptr, nullptr, 0, 0, nullptr } +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<21> b_d7d879450a253e4b = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 75, 62, 37, 10, 69, 121, 216, 215, + 24, 0, 0, 0, 5, 0, 32, 0, + 52, 94, 58, 164, 151, 146, 249, 142, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 250, 0, 0, 0, + 33, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 98, 97, 115, 101, 54, 52, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_d7d879450a253e4b = b_d7d879450a253e4b.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_d7d879450a253e4b = { + 0xd7d879450a253e4b, b_d7d879450a253e4b.words, 21, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_d7d879450a253e4b, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<21> b_f061e22f0ae5c7b5 = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 181, 199, 229, 10, 47, 226, 97, 240, + 24, 0, 0, 0, 5, 0, 32, 0, + 52, 94, 58, 164, 151, 146, 249, 142, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 226, 0, 0, 0, + 33, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 104, 101, 120, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_f061e22f0ae5c7b5 = b_f061e22f0ae5c7b5.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_f061e22f0ae5c7b5 = { + 0xf061e22f0ae5c7b5, b_f061e22f0ae5c7b5.words, 21, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_f061e22f0ae5c7b5, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<22> b_a0a054dea32fd98c = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 140, 217, 47, 163, 222, 84, 160, 160, + 24, 0, 0, 0, 5, 0, 0, 2, + 52, 94, 58, 164, 151, 146, 249, 142, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 42, 1, 0, 0, + 37, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 32, 0, 0, 0, 3, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 99, 111, + 109, 112, 97, 116, 47, 106, 115, 111, + 110, 46, 99, 97, 112, 110, 112, 58, + 110, 111, 116, 105, 102, 105, 99, 97, + 116, 105, 111, 110, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_a0a054dea32fd98c = b_a0a054dea32fd98c.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_a0a054dea32fd98c = { + 0xa0a054dea32fd98c, b_a0a054dea32fd98c.words, 22, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_a0a054dea32fd98c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -296,31 +600,69 @@ const ::capnp::_::RawSchema s_9bbf84153dd4bb60 = { // ======================================================================================= namespace capnp { +namespace json { -// JsonValue -constexpr uint16_t JsonValue::_capnpPrivate::dataWordSize; -constexpr uint16_t JsonValue::_capnpPrivate::pointerCount; +// Value +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t Value::_capnpPrivate::dataWordSize; +constexpr uint16_t Value::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE -constexpr ::capnp::Kind JsonValue::_capnpPrivate::kind; -constexpr ::capnp::_::RawSchema const* JsonValue::_capnpPrivate::schema; +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind Value::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* Value::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE -// JsonValue::Field -constexpr uint16_t JsonValue::Field::_capnpPrivate::dataWordSize; -constexpr uint16_t JsonValue::Field::_capnpPrivate::pointerCount; +// Value::Field +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t Value::Field::_capnpPrivate::dataWordSize; +constexpr uint16_t Value::Field::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE -constexpr ::capnp::Kind JsonValue::Field::_capnpPrivate::kind; -constexpr ::capnp::_::RawSchema const* JsonValue::Field::_capnpPrivate::schema; +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind Value::Field::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* Value::Field::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE -// JsonValue::Call -constexpr uint16_t JsonValue::Call::_capnpPrivate::dataWordSize; -constexpr uint16_t JsonValue::Call::_capnpPrivate::pointerCount; +// Value::Call +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t Value::Call::_capnpPrivate::dataWordSize; +constexpr uint16_t Value::Call::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE -constexpr ::capnp::Kind JsonValue::Call::_capnpPrivate::kind; -constexpr ::capnp::_::RawSchema const* JsonValue::Call::_capnpPrivate::schema; +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind Value::Call::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* Value::Call::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#endif // !CAPNP_LITE + +// FlattenOptions +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t FlattenOptions::_capnpPrivate::dataWordSize; +constexpr uint16_t FlattenOptions::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind FlattenOptions::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* FlattenOptions::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#endif // !CAPNP_LITE + +// DiscriminatorOptions +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t DiscriminatorOptions::_capnpPrivate::dataWordSize; +constexpr uint16_t DiscriminatorOptions::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind DiscriminatorOptions::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* DiscriminatorOptions::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE } // namespace +} // namespace diff --git a/c++/src/capnp/compat/json.capnp.h b/c++/src/capnp/compat/json.capnp.h index a8877e540b..553347264f 100644 --- a/c++/src/capnp/compat/json.capnp.h +++ b/c++/src/capnp/compat/json.capnp.h @@ -1,33 +1,46 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: json.capnp -#ifndef CAPNP_INCLUDED_8ef99297a43a5e34_ -#define CAPNP_INCLUDED_8ef99297a43a5e34_ +#pragma once #include +#include #if !CAPNP_LITE #include #endif // !CAPNP_LITE -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { -CAPNP_DECLARE_SCHEMA(8825ffaa852cda72); -CAPNP_DECLARE_SCHEMA(c27855d853a937cc); -CAPNP_DECLARE_SCHEMA(9bbf84153dd4bb60); +CAPNP_DECLARE_SCHEMA(a3fa7845f919dd83); +CAPNP_DECLARE_SCHEMA(e31026e735d69ddf); +CAPNP_DECLARE_SCHEMA(a0d9f6eca1c93d48); +CAPNP_DECLARE_SCHEMA(fa5b1fd61c2e7c3d); +CAPNP_DECLARE_SCHEMA(82d3e852af0336bf); +CAPNP_DECLARE_SCHEMA(c4df13257bc2ea61); +CAPNP_DECLARE_SCHEMA(cfa794e8d19a0162); +CAPNP_DECLARE_SCHEMA(c2f8c20c293e5319); +CAPNP_DECLARE_SCHEMA(d7d879450a253e4b); +CAPNP_DECLARE_SCHEMA(f061e22f0ae5c7b5); +CAPNP_DECLARE_SCHEMA(a0a054dea32fd98c); } // namespace schemas } // namespace capnp namespace capnp { +namespace json { -struct JsonValue { - JsonValue() = delete; +struct Value { + Value() = delete; class Reader; class Builder; @@ -40,19 +53,20 @@ struct JsonValue { ARRAY, OBJECT, CALL, + RAW, }; struct Field; struct Call; struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(8825ffaa852cda72, 2, 1) + CAPNP_DECLARE_STRUCT_HEADER(a3fa7845f919dd83, 2, 1) #if !CAPNP_LITE static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } #endif // !CAPNP_LITE }; }; -struct JsonValue::Field { +struct Value::Field { Field() = delete; class Reader; @@ -60,14 +74,14 @@ struct JsonValue::Field { class Pipeline; struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(c27855d853a937cc, 0, 2) + CAPNP_DECLARE_STRUCT_HEADER(e31026e735d69ddf, 0, 2) #if !CAPNP_LITE static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } #endif // !CAPNP_LITE }; }; -struct JsonValue::Call { +struct Value::Call { Call() = delete; class Reader; @@ -75,7 +89,37 @@ struct JsonValue::Call { class Pipeline; struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(9bbf84153dd4bb60, 0, 2) + CAPNP_DECLARE_STRUCT_HEADER(a0d9f6eca1c93d48, 0, 2) + #if !CAPNP_LITE + static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } + #endif // !CAPNP_LITE + }; +}; + +struct FlattenOptions { + FlattenOptions() = delete; + + class Reader; + class Builder; + class Pipeline; + + struct _capnpPrivate { + CAPNP_DECLARE_STRUCT_HEADER(c4df13257bc2ea61, 0, 1) + #if !CAPNP_LITE + static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } + #endif // !CAPNP_LITE + }; +}; + +struct DiscriminatorOptions { + DiscriminatorOptions() = delete; + + class Reader; + class Builder; + class Pipeline; + + struct _capnpPrivate { + CAPNP_DECLARE_STRUCT_HEADER(c2f8c20c293e5319, 0, 2) #if !CAPNP_LITE static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } #endif // !CAPNP_LITE @@ -84,9 +128,9 @@ struct JsonValue::Call { // ======================================================================================= -class JsonValue::Reader { +class Value::Reader { public: - typedef JsonValue Reads; + typedef Value Reads; Reader() = default; inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} @@ -117,15 +161,19 @@ class JsonValue::Reader { inline bool isArray() const; inline bool hasArray() const; - inline ::capnp::List< ::capnp::JsonValue>::Reader getArray() const; + inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader getArray() const; inline bool isObject() const; inline bool hasObject() const; - inline ::capnp::List< ::capnp::JsonValue::Field>::Reader getObject() const; + inline ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Reader getObject() const; inline bool isCall() const; inline bool hasCall() const; - inline ::capnp::JsonValue::Call::Reader getCall() const; + inline ::capnp::json::Value::Call::Reader getCall() const; + + inline bool isRaw() const; + inline bool hasRaw() const; + inline ::capnp::Text::Reader getRaw() const; private: ::capnp::_::StructReader _reader; @@ -139,9 +187,9 @@ class JsonValue::Reader { friend class ::capnp::Orphanage; }; -class JsonValue::Builder { +class Value::Builder { public: - typedef JsonValue Builds; + typedef Value Builds; Builder() = delete; // Deleted to discourage incorrect usage. // You can explicitly initialize to nullptr instead. @@ -178,27 +226,35 @@ class JsonValue::Builder { inline bool isArray(); inline bool hasArray(); - inline ::capnp::List< ::capnp::JsonValue>::Builder getArray(); - inline void setArray( ::capnp::List< ::capnp::JsonValue>::Reader value); - inline ::capnp::List< ::capnp::JsonValue>::Builder initArray(unsigned int size); - inline void adoptArray(::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>> disownArray(); + inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder getArray(); + inline void setArray( ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder initArray(unsigned int size); + inline void adoptArray(::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>> disownArray(); inline bool isObject(); inline bool hasObject(); - inline ::capnp::List< ::capnp::JsonValue::Field>::Builder getObject(); - inline void setObject( ::capnp::List< ::capnp::JsonValue::Field>::Reader value); - inline ::capnp::List< ::capnp::JsonValue::Field>::Builder initObject(unsigned int size); - inline void adoptObject(::capnp::Orphan< ::capnp::List< ::capnp::JsonValue::Field>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue::Field>> disownObject(); + inline ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Builder getObject(); + inline void setObject( ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Builder initObject(unsigned int size); + inline void adoptObject(::capnp::Orphan< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>> disownObject(); inline bool isCall(); inline bool hasCall(); - inline ::capnp::JsonValue::Call::Builder getCall(); - inline void setCall( ::capnp::JsonValue::Call::Reader value); - inline ::capnp::JsonValue::Call::Builder initCall(); - inline void adoptCall(::capnp::Orphan< ::capnp::JsonValue::Call>&& value); - inline ::capnp::Orphan< ::capnp::JsonValue::Call> disownCall(); + inline ::capnp::json::Value::Call::Builder getCall(); + inline void setCall( ::capnp::json::Value::Call::Reader value); + inline ::capnp::json::Value::Call::Builder initCall(); + inline void adoptCall(::capnp::Orphan< ::capnp::json::Value::Call>&& value); + inline ::capnp::Orphan< ::capnp::json::Value::Call> disownCall(); + + inline bool isRaw(); + inline bool hasRaw(); + inline ::capnp::Text::Builder getRaw(); + inline void setRaw( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initRaw(unsigned int size); + inline void adoptRaw(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownRaw(); private: ::capnp::_::StructBuilder _builder; @@ -210,9 +266,9 @@ class JsonValue::Builder { }; #if !CAPNP_LITE -class JsonValue::Pipeline { +class Value::Pipeline { public: - typedef JsonValue Pipelines; + typedef Value Pipelines; inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) @@ -226,7 +282,7 @@ class JsonValue::Pipeline { }; #endif // !CAPNP_LITE -class JsonValue::Field::Reader { +class Value::Field::Reader { public: typedef Field Reads; @@ -247,7 +303,7 @@ class JsonValue::Field::Reader { inline ::capnp::Text::Reader getName() const; inline bool hasValue() const; - inline ::capnp::JsonValue::Reader getValue() const; + inline ::capnp::json::Value::Reader getValue() const; private: ::capnp::_::StructReader _reader; @@ -261,7 +317,7 @@ class JsonValue::Field::Reader { friend class ::capnp::Orphanage; }; -class JsonValue::Field::Builder { +class Value::Field::Builder { public: typedef Field Builds; @@ -285,11 +341,11 @@ class JsonValue::Field::Builder { inline ::capnp::Orphan< ::capnp::Text> disownName(); inline bool hasValue(); - inline ::capnp::JsonValue::Builder getValue(); - inline void setValue( ::capnp::JsonValue::Reader value); - inline ::capnp::JsonValue::Builder initValue(); - inline void adoptValue(::capnp::Orphan< ::capnp::JsonValue>&& value); - inline ::capnp::Orphan< ::capnp::JsonValue> disownValue(); + inline ::capnp::json::Value::Builder getValue(); + inline void setValue( ::capnp::json::Value::Reader value); + inline ::capnp::json::Value::Builder initValue(); + inline void adoptValue(::capnp::Orphan< ::capnp::json::Value>&& value); + inline ::capnp::Orphan< ::capnp::json::Value> disownValue(); private: ::capnp::_::StructBuilder _builder; @@ -301,7 +357,7 @@ class JsonValue::Field::Builder { }; #if !CAPNP_LITE -class JsonValue::Field::Pipeline { +class Value::Field::Pipeline { public: typedef Field Pipelines; @@ -309,7 +365,7 @@ class JsonValue::Field::Pipeline { inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) : _typeless(kj::mv(typeless)) {} - inline ::capnp::JsonValue::Pipeline getValue(); + inline ::capnp::json::Value::Pipeline getValue(); private: ::capnp::AnyPointer::Pipeline _typeless; friend class ::capnp::PipelineHook; @@ -318,7 +374,7 @@ class JsonValue::Field::Pipeline { }; #endif // !CAPNP_LITE -class JsonValue::Call::Reader { +class Value::Call::Reader { public: typedef Call Reads; @@ -339,7 +395,7 @@ class JsonValue::Call::Reader { inline ::capnp::Text::Reader getFunction() const; inline bool hasParams() const; - inline ::capnp::List< ::capnp::JsonValue>::Reader getParams() const; + inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader getParams() const; private: ::capnp::_::StructReader _reader; @@ -353,7 +409,7 @@ class JsonValue::Call::Reader { friend class ::capnp::Orphanage; }; -class JsonValue::Call::Builder { +class Value::Call::Builder { public: typedef Call Builds; @@ -377,11 +433,11 @@ class JsonValue::Call::Builder { inline ::capnp::Orphan< ::capnp::Text> disownFunction(); inline bool hasParams(); - inline ::capnp::List< ::capnp::JsonValue>::Builder getParams(); - inline void setParams( ::capnp::List< ::capnp::JsonValue>::Reader value); - inline ::capnp::List< ::capnp::JsonValue>::Builder initParams(unsigned int size); - inline void adoptParams(::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>> disownParams(); + inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder getParams(); + inline void setParams( ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder initParams(unsigned int size); + inline void adoptParams(::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>> disownParams(); private: ::capnp::_::StructBuilder _builder; @@ -393,7 +449,7 @@ class JsonValue::Call::Builder { }; #if !CAPNP_LITE -class JsonValue::Call::Pipeline { +class Value::Call::Pipeline { public: typedef Call Pipelines; @@ -409,452 +465,784 @@ class JsonValue::Call::Pipeline { }; #endif // !CAPNP_LITE +class FlattenOptions::Reader { +public: + typedef FlattenOptions Reads; + + Reader() = default; + inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} + + inline ::capnp::MessageSize totalSize() const { + return _reader.totalSize().asPublic(); + } + +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { + return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); + } +#endif // !CAPNP_LITE + + inline bool hasPrefix() const; + inline ::capnp::Text::Reader getPrefix() const; + +private: + ::capnp::_::StructReader _reader; + template + friend struct ::capnp::ToDynamic_; + template + friend struct ::capnp::_::PointerHelpers; + template + friend struct ::capnp::List; + friend class ::capnp::MessageBuilder; + friend class ::capnp::Orphanage; +}; + +class FlattenOptions::Builder { +public: + typedef FlattenOptions Builds; + + Builder() = delete; // Deleted to discourage incorrect usage. + // You can explicitly initialize to nullptr instead. + inline Builder(decltype(nullptr)) {} + inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} + inline operator Reader() const { return Reader(_builder.asReader()); } + inline Reader asReader() const { return *this; } + + inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { return asReader().toString(); } +#endif // !CAPNP_LITE + + inline bool hasPrefix(); + inline ::capnp::Text::Builder getPrefix(); + inline void setPrefix( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initPrefix(unsigned int size); + inline void adoptPrefix(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownPrefix(); + +private: + ::capnp::_::StructBuilder _builder; + template + friend struct ::capnp::ToDynamic_; + friend class ::capnp::Orphanage; + template + friend struct ::capnp::_::PointerHelpers; +}; + +#if !CAPNP_LITE +class FlattenOptions::Pipeline { +public: + typedef FlattenOptions Pipelines; + + inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} + inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) + : _typeless(kj::mv(typeless)) {} + +private: + ::capnp::AnyPointer::Pipeline _typeless; + friend class ::capnp::PipelineHook; + template + friend struct ::capnp::ToDynamic_; +}; +#endif // !CAPNP_LITE + +class DiscriminatorOptions::Reader { +public: + typedef DiscriminatorOptions Reads; + + Reader() = default; + inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} + + inline ::capnp::MessageSize totalSize() const { + return _reader.totalSize().asPublic(); + } + +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { + return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); + } +#endif // !CAPNP_LITE + + inline bool hasName() const; + inline ::capnp::Text::Reader getName() const; + + inline bool hasValueName() const; + inline ::capnp::Text::Reader getValueName() const; + +private: + ::capnp::_::StructReader _reader; + template + friend struct ::capnp::ToDynamic_; + template + friend struct ::capnp::_::PointerHelpers; + template + friend struct ::capnp::List; + friend class ::capnp::MessageBuilder; + friend class ::capnp::Orphanage; +}; + +class DiscriminatorOptions::Builder { +public: + typedef DiscriminatorOptions Builds; + + Builder() = delete; // Deleted to discourage incorrect usage. + // You can explicitly initialize to nullptr instead. + inline Builder(decltype(nullptr)) {} + inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} + inline operator Reader() const { return Reader(_builder.asReader()); } + inline Reader asReader() const { return *this; } + + inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { return asReader().toString(); } +#endif // !CAPNP_LITE + + inline bool hasName(); + inline ::capnp::Text::Builder getName(); + inline void setName( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initName(unsigned int size); + inline void adoptName(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownName(); + + inline bool hasValueName(); + inline ::capnp::Text::Builder getValueName(); + inline void setValueName( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initValueName(unsigned int size); + inline void adoptValueName(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownValueName(); + +private: + ::capnp::_::StructBuilder _builder; + template + friend struct ::capnp::ToDynamic_; + friend class ::capnp::Orphanage; + template + friend struct ::capnp::_::PointerHelpers; +}; + +#if !CAPNP_LITE +class DiscriminatorOptions::Pipeline { +public: + typedef DiscriminatorOptions Pipelines; + + inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} + inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) + : _typeless(kj::mv(typeless)) {} + +private: + ::capnp::AnyPointer::Pipeline _typeless; + friend class ::capnp::PipelineHook; + template + friend struct ::capnp::ToDynamic_; +}; +#endif // !CAPNP_LITE + // ======================================================================================= -inline ::capnp::JsonValue::Which JsonValue::Reader::which() const { +inline ::capnp::json::Value::Which Value::Reader::which() const { return _reader.getDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS); } -inline ::capnp::JsonValue::Which JsonValue::Builder::which() { +inline ::capnp::json::Value::Which Value::Builder::which() { return _builder.getDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS); } -inline bool JsonValue::Reader::isNull() const { - return which() == JsonValue::NULL_; +inline bool Value::Reader::isNull() const { + return which() == Value::NULL_; } -inline bool JsonValue::Builder::isNull() { - return which() == JsonValue::NULL_; +inline bool Value::Builder::isNull() { + return which() == Value::NULL_; } -inline ::capnp::Void JsonValue::Reader::getNull() const { - KJ_IREQUIRE((which() == JsonValue::NULL_), +inline ::capnp::Void Value::Reader::getNull() const { + KJ_IREQUIRE((which() == Value::NULL_), "Must check which() before get()ing a union member."); return _reader.getDataField< ::capnp::Void>( ::capnp::bounded<0>() * ::capnp::ELEMENTS); } -inline ::capnp::Void JsonValue::Builder::getNull() { - KJ_IREQUIRE((which() == JsonValue::NULL_), +inline ::capnp::Void Value::Builder::getNull() { + KJ_IREQUIRE((which() == Value::NULL_), "Must check which() before get()ing a union member."); return _builder.getDataField< ::capnp::Void>( ::capnp::bounded<0>() * ::capnp::ELEMENTS); } -inline void JsonValue::Builder::setNull( ::capnp::Void value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::NULL_); +inline void Value::Builder::setNull( ::capnp::Void value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::NULL_); _builder.setDataField< ::capnp::Void>( ::capnp::bounded<0>() * ::capnp::ELEMENTS, value); } -inline bool JsonValue::Reader::isBoolean() const { - return which() == JsonValue::BOOLEAN; +inline bool Value::Reader::isBoolean() const { + return which() == Value::BOOLEAN; } -inline bool JsonValue::Builder::isBoolean() { - return which() == JsonValue::BOOLEAN; +inline bool Value::Builder::isBoolean() { + return which() == Value::BOOLEAN; } -inline bool JsonValue::Reader::getBoolean() const { - KJ_IREQUIRE((which() == JsonValue::BOOLEAN), +inline bool Value::Reader::getBoolean() const { + KJ_IREQUIRE((which() == Value::BOOLEAN), "Must check which() before get()ing a union member."); return _reader.getDataField( ::capnp::bounded<16>() * ::capnp::ELEMENTS); } -inline bool JsonValue::Builder::getBoolean() { - KJ_IREQUIRE((which() == JsonValue::BOOLEAN), +inline bool Value::Builder::getBoolean() { + KJ_IREQUIRE((which() == Value::BOOLEAN), "Must check which() before get()ing a union member."); return _builder.getDataField( ::capnp::bounded<16>() * ::capnp::ELEMENTS); } -inline void JsonValue::Builder::setBoolean(bool value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::BOOLEAN); +inline void Value::Builder::setBoolean(bool value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::BOOLEAN); _builder.setDataField( ::capnp::bounded<16>() * ::capnp::ELEMENTS, value); } -inline bool JsonValue::Reader::isNumber() const { - return which() == JsonValue::NUMBER; +inline bool Value::Reader::isNumber() const { + return which() == Value::NUMBER; } -inline bool JsonValue::Builder::isNumber() { - return which() == JsonValue::NUMBER; +inline bool Value::Builder::isNumber() { + return which() == Value::NUMBER; } -inline double JsonValue::Reader::getNumber() const { - KJ_IREQUIRE((which() == JsonValue::NUMBER), +inline double Value::Reader::getNumber() const { + KJ_IREQUIRE((which() == Value::NUMBER), "Must check which() before get()ing a union member."); return _reader.getDataField( ::capnp::bounded<1>() * ::capnp::ELEMENTS); } -inline double JsonValue::Builder::getNumber() { - KJ_IREQUIRE((which() == JsonValue::NUMBER), +inline double Value::Builder::getNumber() { + KJ_IREQUIRE((which() == Value::NUMBER), "Must check which() before get()ing a union member."); return _builder.getDataField( ::capnp::bounded<1>() * ::capnp::ELEMENTS); } -inline void JsonValue::Builder::setNumber(double value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::NUMBER); +inline void Value::Builder::setNumber(double value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::NUMBER); _builder.setDataField( ::capnp::bounded<1>() * ::capnp::ELEMENTS, value); } -inline bool JsonValue::Reader::isString() const { - return which() == JsonValue::STRING; +inline bool Value::Reader::isString() const { + return which() == Value::STRING; } -inline bool JsonValue::Builder::isString() { - return which() == JsonValue::STRING; +inline bool Value::Builder::isString() { + return which() == Value::STRING; } -inline bool JsonValue::Reader::hasString() const { - if (which() != JsonValue::STRING) return false; +inline bool Value::Reader::hasString() const { + if (which() != Value::STRING) return false; return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Builder::hasString() { - if (which() != JsonValue::STRING) return false; +inline bool Value::Builder::hasString() { + if (which() != Value::STRING) return false; return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::Text::Reader JsonValue::Reader::getString() const { - KJ_IREQUIRE((which() == JsonValue::STRING), +inline ::capnp::Text::Reader Value::Reader::getString() const { + KJ_IREQUIRE((which() == Value::STRING), "Must check which() before get()ing a union member."); return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::Text::Builder JsonValue::Builder::getString() { - KJ_IREQUIRE((which() == JsonValue::STRING), +inline ::capnp::Text::Builder Value::Builder::getString() { + KJ_IREQUIRE((which() == Value::STRING), "Must check which() before get()ing a union member."); return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Builder::setString( ::capnp::Text::Reader value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::STRING); +inline void Value::Builder::setString( ::capnp::Text::Reader value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::STRING); ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::Text::Builder JsonValue::Builder::initString(unsigned int size) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::STRING); +inline ::capnp::Text::Builder Value::Builder::initString(unsigned int size) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::STRING); return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } -inline void JsonValue::Builder::adoptString( +inline void Value::Builder::adoptString( ::capnp::Orphan< ::capnp::Text>&& value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::STRING); + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::STRING); ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::Text> JsonValue::Builder::disownString() { - KJ_IREQUIRE((which() == JsonValue::STRING), +inline ::capnp::Orphan< ::capnp::Text> Value::Builder::disownString() { + KJ_IREQUIRE((which() == Value::STRING), "Must check which() before get()ing a union member."); return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline bool JsonValue::Reader::isArray() const { - return which() == JsonValue::ARRAY; +inline bool Value::Reader::isArray() const { + return which() == Value::ARRAY; } -inline bool JsonValue::Builder::isArray() { - return which() == JsonValue::ARRAY; +inline bool Value::Builder::isArray() { + return which() == Value::ARRAY; } -inline bool JsonValue::Reader::hasArray() const { - if (which() != JsonValue::ARRAY) return false; +inline bool Value::Reader::hasArray() const { + if (which() != Value::ARRAY) return false; return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Builder::hasArray() { - if (which() != JsonValue::ARRAY) return false; +inline bool Value::Builder::hasArray() { + if (which() != Value::ARRAY) return false; return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::JsonValue>::Reader JsonValue::Reader::getArray() const { - KJ_IREQUIRE((which() == JsonValue::ARRAY), +inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader Value::Reader::getArray() const { + KJ_IREQUIRE((which() == Value::ARRAY), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::JsonValue>::Builder JsonValue::Builder::getArray() { - KJ_IREQUIRE((which() == JsonValue::ARRAY), +inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder Value::Builder::getArray() { + KJ_IREQUIRE((which() == Value::ARRAY), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Builder::setArray( ::capnp::List< ::capnp::JsonValue>::Reader value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::ARRAY); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::set(_builder.getPointerField( +inline void Value::Builder::setArray( ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::ARRAY); + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::JsonValue>::Builder JsonValue::Builder::initArray(unsigned int size) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::ARRAY); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder Value::Builder::initArray(unsigned int size) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::ARRAY); + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } -inline void JsonValue::Builder::adoptArray( - ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>>&& value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::ARRAY); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::adopt(_builder.getPointerField( +inline void Value::Builder::adoptArray( + ::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>&& value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::ARRAY); + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>> JsonValue::Builder::disownArray() { - KJ_IREQUIRE((which() == JsonValue::ARRAY), +inline ::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>> Value::Builder::disownArray() { + KJ_IREQUIRE((which() == Value::ARRAY), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline bool JsonValue::Reader::isObject() const { - return which() == JsonValue::OBJECT; +inline bool Value::Reader::isObject() const { + return which() == Value::OBJECT; } -inline bool JsonValue::Builder::isObject() { - return which() == JsonValue::OBJECT; +inline bool Value::Builder::isObject() { + return which() == Value::OBJECT; } -inline bool JsonValue::Reader::hasObject() const { - if (which() != JsonValue::OBJECT) return false; +inline bool Value::Reader::hasObject() const { + if (which() != Value::OBJECT) return false; return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Builder::hasObject() { - if (which() != JsonValue::OBJECT) return false; +inline bool Value::Builder::hasObject() { + if (which() != Value::OBJECT) return false; return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::JsonValue::Field>::Reader JsonValue::Reader::getObject() const { - KJ_IREQUIRE((which() == JsonValue::OBJECT), +inline ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Reader Value::Reader::getObject() const { + KJ_IREQUIRE((which() == Value::OBJECT), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue::Field>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::JsonValue::Field>::Builder JsonValue::Builder::getObject() { - KJ_IREQUIRE((which() == JsonValue::OBJECT), +inline ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Builder Value::Builder::getObject() { + KJ_IREQUIRE((which() == Value::OBJECT), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue::Field>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Builder::setObject( ::capnp::List< ::capnp::JsonValue::Field>::Reader value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::OBJECT); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue::Field>>::set(_builder.getPointerField( +inline void Value::Builder::setObject( ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Reader value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::OBJECT); + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::JsonValue::Field>::Builder JsonValue::Builder::initObject(unsigned int size) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::OBJECT); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue::Field>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>::Builder Value::Builder::initObject(unsigned int size) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::OBJECT); + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } -inline void JsonValue::Builder::adoptObject( - ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue::Field>>&& value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::OBJECT); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue::Field>>::adopt(_builder.getPointerField( +inline void Value::Builder::adoptObject( + ::capnp::Orphan< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>&& value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::OBJECT); + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue::Field>> JsonValue::Builder::disownObject() { - KJ_IREQUIRE((which() == JsonValue::OBJECT), +inline ::capnp::Orphan< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>> Value::Builder::disownObject() { + KJ_IREQUIRE((which() == Value::OBJECT), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue::Field>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value::Field, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline bool JsonValue::Reader::isCall() const { - return which() == JsonValue::CALL; +inline bool Value::Reader::isCall() const { + return which() == Value::CALL; } -inline bool JsonValue::Builder::isCall() { - return which() == JsonValue::CALL; +inline bool Value::Builder::isCall() { + return which() == Value::CALL; } -inline bool JsonValue::Reader::hasCall() const { - if (which() != JsonValue::CALL) return false; +inline bool Value::Reader::hasCall() const { + if (which() != Value::CALL) return false; return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Builder::hasCall() { - if (which() != JsonValue::CALL) return false; +inline bool Value::Builder::hasCall() { + if (which() != Value::CALL) return false; return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::JsonValue::Call::Reader JsonValue::Reader::getCall() const { - KJ_IREQUIRE((which() == JsonValue::CALL), +inline ::capnp::json::Value::Call::Reader Value::Reader::getCall() const { + KJ_IREQUIRE((which() == Value::CALL), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::JsonValue::Call>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::json::Value::Call>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::JsonValue::Call::Builder JsonValue::Builder::getCall() { - KJ_IREQUIRE((which() == JsonValue::CALL), +inline ::capnp::json::Value::Call::Builder Value::Builder::getCall() { + KJ_IREQUIRE((which() == Value::CALL), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::JsonValue::Call>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::json::Value::Call>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Builder::setCall( ::capnp::JsonValue::Call::Reader value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::CALL); - ::capnp::_::PointerHelpers< ::capnp::JsonValue::Call>::set(_builder.getPointerField( +inline void Value::Builder::setCall( ::capnp::json::Value::Call::Reader value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::CALL); + ::capnp::_::PointerHelpers< ::capnp::json::Value::Call>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::JsonValue::Call::Builder JsonValue::Builder::initCall() { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::CALL); - return ::capnp::_::PointerHelpers< ::capnp::JsonValue::Call>::init(_builder.getPointerField( +inline ::capnp::json::Value::Call::Builder Value::Builder::initCall() { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::CALL); + return ::capnp::_::PointerHelpers< ::capnp::json::Value::Call>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline void Value::Builder::adoptCall( + ::capnp::Orphan< ::capnp::json::Value::Call>&& value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::CALL); + ::capnp::_::PointerHelpers< ::capnp::json::Value::Call>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::json::Value::Call> Value::Builder::disownCall() { + KJ_IREQUIRE((which() == Value::CALL), + "Must check which() before get()ing a union member."); + return ::capnp::_::PointerHelpers< ::capnp::json::Value::Call>::disown(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} + +inline bool Value::Reader::isRaw() const { + return which() == Value::RAW; +} +inline bool Value::Builder::isRaw() { + return which() == Value::RAW; +} +inline bool Value::Reader::hasRaw() const { + if (which() != Value::RAW) return false; + return !_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline bool Value::Builder::hasRaw() { + if (which() != Value::RAW) return false; + return !_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader Value::Reader::getRaw() const { + KJ_IREQUIRE((which() == Value::RAW), + "Must check which() before get()ing a union member."); + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Builder::adoptCall( - ::capnp::Orphan< ::capnp::JsonValue::Call>&& value) { - _builder.setDataField( - ::capnp::bounded<0>() * ::capnp::ELEMENTS, JsonValue::CALL); - ::capnp::_::PointerHelpers< ::capnp::JsonValue::Call>::adopt(_builder.getPointerField( +inline ::capnp::Text::Builder Value::Builder::getRaw() { + KJ_IREQUIRE((which() == Value::RAW), + "Must check which() before get()ing a union member."); + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline void Value::Builder::setRaw( ::capnp::Text::Reader value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::RAW); + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder Value::Builder::initRaw(unsigned int size) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::RAW); + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), size); +} +inline void Value::Builder::adoptRaw( + ::capnp::Orphan< ::capnp::Text>&& value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Value::RAW); + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::JsonValue::Call> JsonValue::Builder::disownCall() { - KJ_IREQUIRE((which() == JsonValue::CALL), +inline ::capnp::Orphan< ::capnp::Text> Value::Builder::disownRaw() { + KJ_IREQUIRE((which() == Value::RAW), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::JsonValue::Call>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline bool JsonValue::Field::Reader::hasName() const { +inline bool Value::Field::Reader::hasName() const { return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Field::Builder::hasName() { +inline bool Value::Field::Builder::hasName() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::Text::Reader JsonValue::Field::Reader::getName() const { +inline ::capnp::Text::Reader Value::Field::Reader::getName() const { return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::Text::Builder JsonValue::Field::Builder::getName() { +inline ::capnp::Text::Builder Value::Field::Builder::getName() { return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Field::Builder::setName( ::capnp::Text::Reader value) { +inline void Value::Field::Builder::setName( ::capnp::Text::Reader value) { ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::Text::Builder JsonValue::Field::Builder::initName(unsigned int size) { +inline ::capnp::Text::Builder Value::Field::Builder::initName(unsigned int size) { return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } -inline void JsonValue::Field::Builder::adoptName( +inline void Value::Field::Builder::adoptName( ::capnp::Orphan< ::capnp::Text>&& value) { ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::Text> JsonValue::Field::Builder::disownName() { +inline ::capnp::Orphan< ::capnp::Text> Value::Field::Builder::disownName() { return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline bool JsonValue::Field::Reader::hasValue() const { +inline bool Value::Field::Reader::hasValue() const { return !_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Field::Builder::hasValue() { +inline bool Value::Field::Builder::hasValue() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::JsonValue::Reader JsonValue::Field::Reader::getValue() const { - return ::capnp::_::PointerHelpers< ::capnp::JsonValue>::get(_reader.getPointerField( +inline ::capnp::json::Value::Reader Value::Field::Reader::getValue() const { + return ::capnp::_::PointerHelpers< ::capnp::json::Value>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::JsonValue::Builder JsonValue::Field::Builder::getValue() { - return ::capnp::_::PointerHelpers< ::capnp::JsonValue>::get(_builder.getPointerField( +inline ::capnp::json::Value::Builder Value::Field::Builder::getValue() { + return ::capnp::_::PointerHelpers< ::capnp::json::Value>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } #if !CAPNP_LITE -inline ::capnp::JsonValue::Pipeline JsonValue::Field::Pipeline::getValue() { - return ::capnp::JsonValue::Pipeline(_typeless.getPointerField(1)); +inline ::capnp::json::Value::Pipeline Value::Field::Pipeline::getValue() { + return ::capnp::json::Value::Pipeline(_typeless.getPointerField(1)); } #endif // !CAPNP_LITE -inline void JsonValue::Field::Builder::setValue( ::capnp::JsonValue::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::JsonValue>::set(_builder.getPointerField( +inline void Value::Field::Builder::setValue( ::capnp::json::Value::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::json::Value>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::JsonValue::Builder JsonValue::Field::Builder::initValue() { - return ::capnp::_::PointerHelpers< ::capnp::JsonValue>::init(_builder.getPointerField( +inline ::capnp::json::Value::Builder Value::Field::Builder::initValue() { + return ::capnp::_::PointerHelpers< ::capnp::json::Value>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void JsonValue::Field::Builder::adoptValue( - ::capnp::Orphan< ::capnp::JsonValue>&& value) { - ::capnp::_::PointerHelpers< ::capnp::JsonValue>::adopt(_builder.getPointerField( +inline void Value::Field::Builder::adoptValue( + ::capnp::Orphan< ::capnp::json::Value>&& value) { + ::capnp::_::PointerHelpers< ::capnp::json::Value>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::JsonValue> JsonValue::Field::Builder::disownValue() { - return ::capnp::_::PointerHelpers< ::capnp::JsonValue>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::json::Value> Value::Field::Builder::disownValue() { + return ::capnp::_::PointerHelpers< ::capnp::json::Value>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline bool JsonValue::Call::Reader::hasFunction() const { +inline bool Value::Call::Reader::hasFunction() const { return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Call::Builder::hasFunction() { +inline bool Value::Call::Builder::hasFunction() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::Text::Reader JsonValue::Call::Reader::getFunction() const { +inline ::capnp::Text::Reader Value::Call::Reader::getFunction() const { return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::Text::Builder JsonValue::Call::Builder::getFunction() { +inline ::capnp::Text::Builder Value::Call::Builder::getFunction() { return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void JsonValue::Call::Builder::setFunction( ::capnp::Text::Reader value) { +inline void Value::Call::Builder::setFunction( ::capnp::Text::Reader value) { ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::Text::Builder JsonValue::Call::Builder::initFunction(unsigned int size) { +inline ::capnp::Text::Builder Value::Call::Builder::initFunction(unsigned int size) { return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } -inline void JsonValue::Call::Builder::adoptFunction( +inline void Value::Call::Builder::adoptFunction( ::capnp::Orphan< ::capnp::Text>&& value) { ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::Text> JsonValue::Call::Builder::disownFunction() { +inline ::capnp::Orphan< ::capnp::Text> Value::Call::Builder::disownFunction() { return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline bool JsonValue::Call::Reader::hasParams() const { +inline bool Value::Call::Reader::hasParams() const { return !_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline bool JsonValue::Call::Builder::hasParams() { +inline bool Value::Call::Builder::hasParams() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::JsonValue>::Reader JsonValue::Call::Reader::getParams() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader Value::Call::Reader::getParams() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::JsonValue>::Builder JsonValue::Call::Builder::getParams() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder Value::Call::Builder::getParams() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void JsonValue::Call::Builder::setParams( ::capnp::List< ::capnp::JsonValue>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::set(_builder.getPointerField( +inline void Value::Call::Builder::setParams( ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::JsonValue>::Builder JsonValue::Call::Builder::initParams(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>::Builder Value::Call::Builder::initParams(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } -inline void JsonValue::Call::Builder::adoptParams( - ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::adopt(_builder.getPointerField( +inline void Value::Call::Builder::adoptParams( + ::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::JsonValue>> JsonValue::Call::Builder::disownParams() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::JsonValue>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>> Value::Call::Builder::disownParams() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::json::Value, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} + +inline bool FlattenOptions::Reader::hasPrefix() const { + return !_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline bool FlattenOptions::Builder::hasPrefix() { + return !_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader FlattenOptions::Reader::getPrefix() const { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), + ::capnp::schemas::bp_c4df13257bc2ea61 + 34); +} +inline ::capnp::Text::Builder FlattenOptions::Builder::getPrefix() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), + ::capnp::schemas::bp_c4df13257bc2ea61 + 34); +} +inline void FlattenOptions::Builder::setPrefix( ::capnp::Text::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder FlattenOptions::Builder::initPrefix(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), size); +} +inline void FlattenOptions::Builder::adoptPrefix( + ::capnp::Orphan< ::capnp::Text>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> FlattenOptions::Builder::disownPrefix() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} + +inline bool DiscriminatorOptions::Reader::hasName() const { + return !_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline bool DiscriminatorOptions::Builder::hasName() { + return !_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader DiscriminatorOptions::Reader::getName() const { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline ::capnp::Text::Builder DiscriminatorOptions::Builder::getName() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline void DiscriminatorOptions::Builder::setName( ::capnp::Text::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder DiscriminatorOptions::Builder::initName(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), size); +} +inline void DiscriminatorOptions::Builder::adoptName( + ::capnp::Orphan< ::capnp::Text>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> DiscriminatorOptions::Builder::disownName() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} + +inline bool DiscriminatorOptions::Reader::hasValueName() const { + return !_reader.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); +} +inline bool DiscriminatorOptions::Builder::hasValueName() { + return !_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader DiscriminatorOptions::Reader::getValueName() const { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} +inline ::capnp::Text::Builder DiscriminatorOptions::Builder::getValueName() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} +inline void DiscriminatorOptions::Builder::setValueName( ::capnp::Text::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder DiscriminatorOptions::Builder::initValueName(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), size); +} +inline void DiscriminatorOptions::Builder::adoptValueName( + ::capnp::Orphan< ::capnp::Text>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> DiscriminatorOptions::Builder::disownValueName() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } } // namespace +} // namespace + +CAPNP_END_HEADER -#endif // CAPNP_INCLUDED_8ef99297a43a5e34_ diff --git a/c++/src/capnp/compat/json.h b/c++/src/capnp/compat/json.h index 7fa815e099..8ce477ed37 100644 --- a/c++/src/capnp/compat/json.h +++ b/c++/src/capnp/compat/json.h @@ -19,15 +19,21 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_COMPAT_JSON_H_ -#define CAPNP_COMPAT_JSON_H_ +#pragma once #include #include #include +CAPNP_BEGIN_HEADER + namespace capnp { +typedef json::Value JsonValue; +// For backwards-compatibility. +// +// TODO(cleanup): Consider replacing all uses of JsonValue with json::Value? + class JsonCodec { // Flexible class for encoding Cap'n Proto types as JSON, and decoding JSON back to Cap'n Proto. // @@ -50,15 +56,13 @@ class JsonCodec { // - 64-bit integers are encoded as strings, since JSON "numbers" are double-precision floating // points which cannot store a 64-bit integer without losing data. // - NaNs and infinite floating point numbers are not allowed by the JSON spec, and so are encoded - // as null. This matches the behavior of `JSON.stringify` in at least Firefox and Chrome. + // as strings. // - Data is encoded as an array of numbers in the range [0,255]. You probably want to register // a handler that does something better, like maybe base64 encoding, but there are a zillion // different ways people do this. // - Encoding/decoding capabilities and AnyPointers requires registering a Handler, since there's // no obvious default behavior. - // - When decoding, unrecognized field names are ignored. Note: This means that JSON is NOT a - // good format for receiving input from a human. Consider `capnp eval` or the SchemaParser - // library for human input. + // - When decoding, fields with unknown names are ignored by default to allow schema evolution. public: JsonCodec(); @@ -75,8 +79,19 @@ class JsonCodec { // Set maximum nesting depth when decoding JSON to prevent highly nested input from overflowing // the call stack. The default is 64. + void setHasMode(HasMode mode); + // Normally, primitive field values are always included even if they are equal to the default + // value (HasMode::NON_NULL -- only null pointers are omitted). You can use + // setHasMode(HasMode::NON_DEFAULT) to specify that default-valued primitive fields should be + // omitted as well. + + void setRejectUnknownFields(bool enable); + // Choose whether decoding JSON with unknown fields should produce an error. You may trade + // allowing schema evolution against a guarantee that all data is preserved when decoding JSON + // by toggling this option. The default is to ignore unknown fields. + template - kj::String encode(T&& value); + kj::String encode(T&& value) const; // Encode any Cap'n Proto value to JSON, including primitives and // Dynamic{Enum,Struct,List,Capability}, but not DynamicValue (see below). @@ -126,7 +141,7 @@ class JsonCodec { // Translate JsonValue <-> text. template - void encode(T&& value, JsonValue::Builder output); + void encode(T&& value, JsonValue::Builder output) const; void encode(DynamicValue::Reader input, Type type, JsonValue::Builder output) const; void decode(JsonValue::Reader input, DynamicStruct::Builder output) const; template @@ -193,27 +208,68 @@ class JsonCodec { void addFieldHandler(StructSchema::Field field, Handler& handler); // Matches only the specific field. T can be a dynamic type. T must match the field's type. + void handleByAnnotation(Schema schema); + template void handleByAnnotation(); + // Inspects the given type (as specified by type parameter or dynamic schema) and all its + // dependencies looking for JSON annotations (see json.capnp), building and registering Handlers + // based on these annotations. + // + // If you'd like to use annotations to control JSON, you must call these functions before you + // start using the codec. They are not loaded "on demand" because that would require mutex + // locking. + + // --------------------------------------------------------------------------- + // Hack to support string literal parameters + + template + auto decode(const char (&input)[size], Params&&... params) const + -> decltype(decode(kj::arrayPtr(input, size), kj::fwd(params)...)) { + return decode(kj::arrayPtr(input, size - 1), kj::fwd(params)...); + } + template + auto decodeRaw(const char (&input)[size], Params&&... params) const + -> decltype(decodeRaw(kj::arrayPtr(input, size), kj::fwd(params)...)) { + return decodeRaw(kj::arrayPtr(input, size - 1), kj::fwd(params)...); + } + private: class HandlerBase; + class AnnotatedHandler; + class AnnotatedEnumHandler; + class Base64Handler; + class HexHandler; + class JsonValueHandler; struct Impl; kj::Own impl; void encodeField(StructSchema::Field field, DynamicValue::Reader input, JsonValue::Builder output) const; - void decodeArray(List::Reader input, DynamicList::Builder output) const; - void decodeObject(List::Reader input, DynamicStruct::Builder output) const; + Orphan decodeArray(List::Reader input, ListSchema type, Orphanage orphanage) const; + void decodeObject(JsonValue::Reader input, StructSchema type, Orphanage orphanage, DynamicStruct::Builder output) const; + void decodeField(StructSchema::Field fieldSchema, JsonValue::Reader fieldValue, + Orphanage orphanage, DynamicStruct::Builder output) const; void addTypeHandlerImpl(Type type, HandlerBase& handler); void addFieldHandlerImpl(StructSchema::Field field, Type type, HandlerBase& handler); + + AnnotatedHandler& loadAnnotatedHandler( + StructSchema schema, + kj::Maybe discriminator, + kj::Maybe unionDeclName, + kj::Vector& dependencies); }; // ======================================================================================= // inline implementation details +template +struct EncodeImpl; + template -kj::String JsonCodec::encode(T&& value) { +kj::String JsonCodec::encode(T&& value) const { + Type type = Type::from(value); typedef FromAny> Base; - return encode(DynamicValue::Reader(ReaderFor(kj::fwd(value))), Type::from()); + return encode(DynamicValue::Reader(ReaderFor(kj::fwd(value))), type); } template @@ -247,11 +303,17 @@ inline DynamicEnum JsonCodec::decode(kj::ArrayPtr input, EnumSchema // ----------------------------------------------------------------------------- template -void JsonCodec::encode(T&& value, JsonValue::Builder output) { +void JsonCodec::encode(T&& value, JsonValue::Builder output) const { typedef FromAny> Base; encode(DynamicValue::Reader(ReaderFor(kj::fwd(value))), Type::from(), output); } +template <> +inline void JsonCodec::encode( + DynamicStruct::Reader&& value, JsonValue::Builder output) const { + encode(DynamicValue::Reader(value), value.getSchema(), output); +} + template inline Orphan JsonCodec::decode(JsonValue::Reader input, Orphanage orphanage) const { return decode(input, Type::from(), orphanage).template releaseAs(); @@ -457,6 +519,11 @@ template <> void JsonCodec::addTypeHandler(Handler& handler) // TODO(someday): Implement support for registering handlers that cover thinsg like "all structs" // or "all lists". Currently you can only target a specific struct or list type. +template +void JsonCodec::handleByAnnotation() { + return handleByAnnotation(Schema::from()); +} + } // namespace capnp -#endif // CAPNP_COMPAT_JSON_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compat/std-iterator.h b/c++/src/capnp/compat/std-iterator.h new file mode 100644 index 0000000000..039d5371f0 --- /dev/null +++ b/c++/src/capnp/compat/std-iterator.h @@ -0,0 +1,47 @@ +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +// This exposes IndexingIterator as something compatible with std::iterator so that things like +// std::copy work with List::begin/List::end. + +// Make sure that if this header is before list.h by the user it includes it to make +// IndexingIterator visible to avoid brittle header problems. +#include "../list.h" +#include + +CAPNP_BEGIN_HEADER + +namespace std { + +template +struct iterator_traits> { + using iterator_category = std::random_access_iterator_tag; + using value_type = Element; + using difference_type = int; + using pointer = Element*; + using reference = Element&; +}; + +} // namespace std + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compat/websocket-rpc-test.c++ b/c++/src/capnp/compat/websocket-rpc-test.c++ new file mode 100644 index 0000000000..359573a8dc --- /dev/null +++ b/c++/src/capnp/compat/websocket-rpc-test.c++ @@ -0,0 +1,113 @@ +// Copyright (c) 2021 Ian Denhardt and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "websocket-rpc.h" +#include + +#include + +KJ_TEST("WebSocketMessageStream") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newWebSocketPipe(); + + auto msgStreamA = capnp::WebSocketMessageStream(*pipe.ends[0]); + auto msgStreamB = capnp::WebSocketMessageStream(*pipe.ends[1]); + + // Make a message, fill it with some stuff + capnp::MallocMessageBuilder originalMsg; + auto object = originalMsg.initRoot().initStructList(10); + object[0].setTextField("Test"); + object[1].initStructField().setTextField("A string"); + object[2].setTextField("Another field"); + object[3].setInt64Field(42); + auto originalSegments = originalMsg.getSegmentsForOutput(); + + // Send the message across the websocket, make sure it comes out unharmed. + auto writePromise = msgStreamA.writeMessage(nullptr, originalSegments); + msgStreamB.tryReadMessage(nullptr) + .then([&](auto maybeResult) -> kj::Promise { + KJ_IF_MAYBE(result, maybeResult) { + KJ_ASSERT(result->fds.size() == 0); + KJ_ASSERT(result->reader->getSegment(originalSegments.size()) == nullptr); + for(size_t i = 0; i < originalSegments.size(); i++) { + auto oldSegment = originalSegments[i]; + auto newSegment = result->reader->getSegment(i); + + KJ_ASSERT(oldSegment.size() == newSegment.size()); + KJ_ASSERT(memcmp( + &oldSegment[0], + &newSegment[0], + oldSegment.size() * sizeof(capnp::word) + ) == 0); + } + return kj::READY_NOW; + } else { + KJ_FAIL_ASSERT("Reading first message failed"); + } + }).wait(waitScope); + writePromise.wait(waitScope); + + // Close the websocket, and make sure the other end gets nullptr when reading. + auto endPromise = msgStreamA.end(); + msgStreamB.tryReadMessage(nullptr).then([](auto maybe) -> kj::Promise { + KJ_IF_MAYBE(segments, maybe) { + KJ_FAIL_ASSERT("Should have gotten nullptr after websocket was closed"); + } + return kj::READY_NOW; + }).wait(waitScope); + endPromise.wait(waitScope); +} + +KJ_TEST("WebSocketMessageStreamByteCount") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe1 = kj::newWebSocketPipe(); + auto pipe2 = kj::newWebSocketPipe(); + + auto msgStreamA = capnp::WebSocketMessageStream(*pipe1.ends[0]); + auto msgStreamB = capnp::WebSocketMessageStream(*pipe2.ends[1]); + + auto pumpTask = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + + capnp::MallocMessageBuilder originalMsg; + auto object = originalMsg.initRoot().initStructList(10); + object[0].setTextField("Test"); + object[1].initStructField().setTextField("A string"); + object[2].setTextField("Another field"); + object[3].setInt64Field(42); + auto originalSegments = originalMsg.getSegmentsForOutput(); + + auto writePromise = msgStreamA.writeMessage(nullptr, originalSegments); + msgStreamB.tryReadMessage(nullptr).wait(waitScope); + writePromise.wait(waitScope); + + auto endPromise = msgStreamA.end(); + msgStreamB.tryReadMessage(nullptr).wait(waitScope); + pumpTask.wait(waitScope); + endPromise.wait(waitScope); + KJ_EXPECT(pipe1.ends[0]->sentByteCount() == 2585); + KJ_EXPECT(pipe1.ends[1]->receivedByteCount() == 2585); + KJ_EXPECT(pipe2.ends[0]->sentByteCount() == 2585); + KJ_EXPECT(pipe2.ends[1]->receivedByteCount() == 2585); +} diff --git a/c++/src/capnp/compat/websocket-rpc.c++ b/c++/src/capnp/compat/websocket-rpc.c++ new file mode 100644 index 0000000000..1db2ebc02c --- /dev/null +++ b/c++/src/capnp/compat/websocket-rpc.c++ @@ -0,0 +1,128 @@ +// Copyright (c) 2021 Ian Denhardt and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include +#include +#include + +namespace capnp { + +WebSocketMessageStream::WebSocketMessageStream(kj::WebSocket& socket) + : socket(socket) + {}; + +kj::Promise> WebSocketMessageStream::tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options, kj::ArrayPtr scratchSpace) { + return socket.receive(options.traversalLimitInWords * sizeof(word)) + .then([options](auto msg) -> kj::Promise> { + KJ_SWITCH_ONEOF(msg) { + KJ_CASE_ONEOF(closeMsg, kj::WebSocket::Close) { + return kj::Maybe(); + } + KJ_CASE_ONEOF(str, kj::String) { + KJ_FAIL_REQUIRE( + "Unexpected websocket text message; expected only binary messages."); + break; + } + KJ_CASE_ONEOF(bytes, kj::Array) { + kj::Own reader; + size_t sizeInWords = bytes.size() / sizeof(word); + if (reinterpret_cast(bytes.begin()) % alignof(word) == 0) { + reader = kj::heap( + kj::arrayPtr( + reinterpret_cast(bytes.begin()), + sizeInWords + ), + options).attach(kj::mv(bytes)); + } else { + // The array is misaligned, so we need to copy it. + auto words = kj::heapArray(sizeInWords); + + // Note: can't just use bytes.size(), since the the target buffer may + // be shorter due to integer division. + memcpy(words.begin(), bytes.begin(), sizeInWords * sizeof(word)); + reader = kj::heap( + kj::arrayPtr(words.begin(), sizeInWords), + options).attach(kj::mv(words)); + } + return kj::Maybe(MessageReaderAndFds { + kj::mv(reader), + nullptr + }); + } + } + KJ_UNREACHABLE; + }); +} + +kj::Promise WebSocketMessageStream::writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) { + // TODO(perf): Right now the WebSocket interface only supports send() for + // contiguous arrays, so we need to copy the whole message into a new buffer + // in order to send it, whereas ideally we could just write each segment + // (and the segment table) in sequence. Perhaps we should extend the WebSocket + // interface to be able to send an ArrayPtr> as one binary + // message, and then use that to avoid an extra copy here. + + auto stream = kj::heap( + computeSerializedSizeInWords(segments) * sizeof(word)); + capnp::writeMessage(*stream, segments); + auto arrayPtr = stream->getArray(); + return socket.send(arrayPtr).attach(kj::mv(stream)); +} + +kj::Promise WebSocketMessageStream::writeMessages( + kj::ArrayPtr>> messages) { + // TODO(perf): Extend WebSocket interface with a way to write multiple messages at once. + + if(messages.size() == 0) { + return kj::READY_NOW; + } + return writeMessage(nullptr, messages[0]) + .then([this, messages = messages.slice(1, messages.size())]() mutable -> kj::Promise { + return writeMessages(messages); + }); +} + +kj::Maybe WebSocketMessageStream::getSendBufferSize() { + return nullptr; +} + +kj::Promise WebSocketMessageStream::end() { + return socket.close( + 1005, // most generic code, indicates "No Status Received." + // Since the MessageStream API doesn't tell us why + // we're closing the connection, this is the best + // we can do. This is consistent with what browser + // implementations do if no status is provided, see: + // + // * https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close + // * https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent + + "Capnp connection closed" // Similarly not much information to go on here, + // but this at least lets us trace this back to + // capnp. + ); +}; + +}; diff --git a/c++/src/capnp/compat/websocket-rpc.h b/c++/src/capnp/compat/websocket-rpc.h new file mode 100644 index 0000000000..80c8ae2537 --- /dev/null +++ b/c++/src/capnp/compat/websocket-rpc.h @@ -0,0 +1,57 @@ +// Copyright (c) 2021 Ian Denhardt and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include + +CAPNP_BEGIN_HEADER + +namespace capnp { + +class WebSocketMessageStream final : public MessageStream { + // An implementation of MessageStream that sends messages over a websocket. + // + // Each capnproto message is sent in a single binary websocket frame. +public: + WebSocketMessageStream(kj::WebSocket& socket); + + // Implements MessageStream + kj::Promise> tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr) override; + kj::Promise writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) override + KJ_WARN_UNUSED_RESULT; + kj::Promise writeMessages( + kj::ArrayPtr>> messages) override + KJ_WARN_UNUSED_RESULT; + kj::Maybe getSendBufferSize() override; + kj::Promise end() override; +private: + kj::WebSocket& socket; +}; + +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/capnp-test.ekam-rule b/c++/src/capnp/compiler/capnp-test.ekam-rule index 254401f2ed..9f3c04e41d 100755 --- a/c++/src/capnp/compiler/capnp-test.ekam-rule +++ b/c++/src/capnp/compiler/capnp-test.ekam-rule @@ -42,6 +42,8 @@ echo findProvider file:capnp/testdata/pretty.txt; read JUNK echo findProvider file:capnp/testdata/short.txt; read JUNK echo findProvider file:capnp/testdata/errors.capnp.nobuild; read JUNK echo findProvider file:capnp/testdata/errors.txt; read JUNK +echo findProvider file:capnp/testdata/errors2.capnp.nobuild; read JUNK +echo findProvider file:capnp/testdata/errors2.txt; read JUNK # Register our interest in the schema files. echo findProvider file:capnp/c++.capnp diff --git a/c++/src/capnp/compiler/capnp-test.sh b/c++/src/capnp/compiler/capnp-test.sh index 0a5703b95f..c435a3d3a8 100755 --- a/c++/src/capnp/compiler/capnp-test.sh +++ b/c++/src/capnp/compiler/capnp-test.sh @@ -38,7 +38,54 @@ else fi SCHEMA=`dirname "$0"`/../test.capnp +JSON_SCHEMA=`dirname "$0"`/../compat/json-test.capnp TESTDATA=`dirname "$0"`/../testdata +SRCDIR=`dirname "$0"`/../.. + +SUFFIX=${TESTDATA#*/src/} +PREFIX=${TESTDATA%${SUFFIX}} + +if [ "$PREFIX" = "" ]; then + PREFIX=. +fi + +# ======================================================================================== +# convert + +$CAPNP convert text:binary $SCHEMA TestAllTypes < $TESTDATA/short.txt | cmp $TESTDATA/binary - || fail encode +$CAPNP convert text:flat $SCHEMA TestAllTypes < $TESTDATA/short.txt | cmp $TESTDATA/flat - || fail encode flat +$CAPNP convert text:packed $SCHEMA TestAllTypes < $TESTDATA/short.txt | cmp $TESTDATA/packed - || fail encode packed +$CAPNP convert text:flat-packed $SCHEMA TestAllTypes < $TESTDATA/short.txt | cmp $TESTDATA/packedflat - || fail encode packedflat +$CAPNP convert text:binary $SCHEMA TestAllTypes < $TESTDATA/pretty.txt | cmp $TESTDATA/binary - || fail parse pretty + +$CAPNP convert binary:text $SCHEMA TestAllTypes < $TESTDATA/binary | cmp $TESTDATA/pretty.txt - || fail decode +$CAPNP convert flat:text $SCHEMA TestAllTypes < $TESTDATA/flat | cmp $TESTDATA/pretty.txt - || fail decode flat +$CAPNP convert packed:text $SCHEMA TestAllTypes < $TESTDATA/packed | cmp $TESTDATA/pretty.txt - || fail decode packed +$CAPNP convert flat-packed:text $SCHEMA TestAllTypes < $TESTDATA/packedflat | cmp $TESTDATA/pretty.txt - || fail decode packedflat +$CAPNP convert binary:text --short $SCHEMA TestAllTypes < $TESTDATA/binary | cmp $TESTDATA/short.txt - || fail decode short + +$CAPNP convert binary:text $SCHEMA TestAllTypes < $TESTDATA/segmented | cmp $TESTDATA/pretty.txt - || fail decode segmented +$CAPNP convert packed:text $SCHEMA TestAllTypes < $TESTDATA/segmented-packed | cmp $TESTDATA/pretty.txt - || fail decode segmented-packed + +$CAPNP convert binary:packed < $TESTDATA/binary | cmp $TESTDATA/packed - || fail binary to packed +$CAPNP convert packed:binary < $TESTDATA/packed | cmp $TESTDATA/binary - || fail packed to binary + +$CAPNP convert binary:json $SCHEMA TestAllTypes < $TESTDATA/binary | cmp $TESTDATA/pretty.json - || fail binary to json +$CAPNP convert binary:json --short $SCHEMA TestAllTypes < $TESTDATA/binary | cmp $TESTDATA/short.json - || fail binary to short json + +$CAPNP convert json:binary $SCHEMA TestAllTypes < $TESTDATA/pretty.json | cmp $TESTDATA/binary - || fail json to binary +$CAPNP convert json:binary $SCHEMA TestAllTypes < $TESTDATA/short.json | cmp $TESTDATA/binary - || fail short json to binary + +$CAPNP convert json:binary $JSON_SCHEMA TestJsonAnnotations -I"$SRCDIR" < $TESTDATA/annotated.json | cmp $TESTDATA/annotated-json.binary - || fail annotated json to binary +$CAPNP convert binary:json $JSON_SCHEMA TestJsonAnnotations -I"$SRCDIR" < $TESTDATA/annotated-json.binary | cmp $TESTDATA/annotated.json - || fail annotated binary to json + +[ "$(echo '(foo = (text = "abc"))' | $CAPNP convert text:text "$SRCDIR/capnp/test.capnp" BrandedAlias)" = '(foo = (text = "abc"), uv = void)' ] || fail branded alias +[ "$(echo '(foo = (text = "abc"))' | $CAPNP convert text:text "$SRCDIR/capnp/test.capnp" BrandedAlias.Inner)" = '(foo = (text = "abc"))' ] || fail branded alias +[ "$(echo '(foo = (text = "abc"))' | $CAPNP convert text:text "$SRCDIR/capnp/test.capnp" 'TestGenerics(BoxedText, Text)')" = '(foo = (text = "abc"), uv = void)' ] || fail branded alias +[ "$(echo '(baz = (text = "abc"))' | $CAPNP convert text:text "$SRCDIR/capnp/test.capnp" 'TestGenerics(TestAllTypes, List(Int32)).Inner2(BoxedText)')" = '(baz = (text = "abc"))' ] || fail branded alias + +# ======================================================================================== +# DEPRECATED encode/decode $CAPNP encode $SCHEMA TestAllTypes < $TESTDATA/short.txt | cmp $TESTDATA/binary - || fail encode $CAPNP encode --flat $SCHEMA TestAllTypes < $TESTDATA/short.txt | cmp $TESTDATA/flat - || fail encode flat @@ -55,6 +102,9 @@ $CAPNP decode --short $SCHEMA TestAllTypes < $TESTDATA/binary | cmp $TESTDATA/sh $CAPNP decode $SCHEMA TestAllTypes < $TESTDATA/segmented | cmp $TESTDATA/pretty.txt - || fail decode segmented $CAPNP decode --packed $SCHEMA TestAllTypes < $TESTDATA/segmented-packed | cmp $TESTDATA/pretty.txt - || fail decode segmented-packed +# ======================================================================================== +# eval + test_eval() { test "x`$CAPNP eval $SCHEMA $1 | tr -d '\r'`" = "x$2" || fail eval "$1 == $2" } @@ -67,5 +117,13 @@ test_eval globalPrintableStruct '(someText = "foo")' test_eval TestConstants.enumConst corge test_eval 'TestListDefaults.lists.int32ListList[2][0]' 12341234 -$CAPNP compile -ofoo $TESTDATA/errors.capnp.nobuild 2>&1 | sed -e "s,^.*/errors[.]capnp[.]nobuild,file,g" | tr -d '\r' | - cmp $TESTDATA/errors.txt - || fail error output +test "x`$CAPNP eval $SCHEMA -ojson globalPrintableStruct | tr -d '\r'`" = "x{\"someText\": \"foo\"}" || fail eval json "globalPrintableStruct == {someText = \"foo\"}" + +$CAPNP eval $TESTDATA/no-file-id.capnp.nobuild foo >/dev/null || fail eval "file without file ID can be parsed" +test "x`$CAPNP eval $TESTDATA/no-file-id.capnp.nobuild foo | tr -d '\r'`" = 'x"bar"' || fail eval "file without file ID parsed correctly" + +$CAPNP compile --no-standard-import --src-prefix="$PREFIX" -ofoo $TESTDATA/errors.capnp.nobuild 2>&1 | sed -e "s,^.*errors[.]capnp[.]nobuild:,file:,g" | tr -d '\r' | + diff -u $TESTDATA/errors.txt - || fail error output + +$CAPNP compile --no-standard-import --src-prefix="$PREFIX" -ofoo $TESTDATA/errors2.capnp.nobuild 2>&1 | sed -e "s,^.*errors2[.]capnp[.]nobuild:,file:,g" | tr -d '\r' | + diff -u $TESTDATA/errors2.txt - || fail error2 output diff --git a/c++/src/capnp/compiler/capnp.c++ b/c++/src/capnp/compiler/capnp.c++ index 728c3f2da2..17f4fe3b80 100644 --- a/c++/src/capnp/compiler/capnp.c++ +++ b/c++/src/capnp/compiler/capnp.c++ @@ -19,6 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#if _WIN32 +#include +#endif + #include "lexer.h" #include "parser.h" #include "compiler.h" @@ -38,11 +46,17 @@ #include #include #include +#include +#include #include #include +#include #if _WIN32 #include +#include +#include +#undef CONST #else #include #endif @@ -63,10 +77,10 @@ static const char VERSION_STRING[] = "Cap'n Proto version " VERSION; class CompilerMain final: public GlobalErrorReporter { public: explicit CompilerMain(kj::ProcessContext& context) - : context(context), loader(*this) {} + : context(context), disk(kj::newDiskFilesystem()), loader(*this) {} kj::MainFunc getMain() { - if (context.getProgramName().endsWith("capnpc")) { + if (context.getProgramName().endsWith("capnpc") || context.getProgramName().endsWith("capnpc.exe")) { kj::MainBuilder builder(context, VERSION_STRING, "Compiles Cap'n Proto schema files and generates corresponding source code in one or " "more languages."); @@ -82,10 +96,12 @@ public: "Generate source code from schema files.") .addSubCommand("id", KJ_BIND_METHOD(*this, getGenIdMain), "Generate a new unique ID.") + .addSubCommand("convert", KJ_BIND_METHOD(*this, getConvertMain), + "Convert messages between binary, text, JSON, etc.") .addSubCommand("decode", KJ_BIND_METHOD(*this, getDecodeMain), - "Decode binary Cap'n Proto message to text.") + "DEPRECATED (use `convert`)") .addSubCommand("encode", KJ_BIND_METHOD(*this, getEncodeMain), - "Encode text Cap'n Proto message to binary.") + "DEPRECATED (use `convert`)") .addSubCommand("eval", KJ_BIND_METHOD(*this, getEvalMain), "Evaluate a const from a schema file."); addGlobalOptions(builder); @@ -109,6 +125,47 @@ public: .build(); } + kj::MainFunc getConvertMain() { + // Only parse the schemas we actually need for decoding. + compileEagerness = Compiler::NODE; + + // Drop annotations since we don't need them. This avoids importing files like c++.capnp. + annotationFlag = Compiler::DROP_ANNOTATIONS; + + kj::MainBuilder builder(context, VERSION_STRING, + "Converts messages between formats. Reads a stream of messages from stdin in format " + " and writes them to stdout in format . Valid formats are:\n" + " binary standard binary format\n" + " packed packed binary format (deflates zeroes)\n" + " flat binary single segment, no segment table (rare)\n" + " flat-packed flat and packed\n" + " canonical canonicalized binary single segment, no segment table\n" + " text schema language struct literal format\n" + " json JSON format\n" + "When using \"text\" or \"json\" format, you must specify and " + "(but they are ignored and can be omitted for binary-to-binary conversions). " + " names names a struct type defined in , which is the root type " + "of the message(s)."); + addGlobalOptions(builder); + builder.addOption({"short"}, KJ_BIND_METHOD(*this, printShort), + "Write text or JSON output in short (non-pretty) format. Each message will " + "be printed on one line, without using whitespace to improve readability.") + .addOptionWithArg({"segment-size"}, KJ_BIND_METHOD(*this, setSegmentSize), "", + "For binary output, sets the preferred segment size on the MallocMessageBuilder to " + "words and turns off heuristic growth. This flag is mainly useful " + "for testing. Without it, each message will be written as a single " + "segment.") + .addOption({"quiet"}, KJ_BIND_METHOD(*this, setQuiet), + "Do not print warning messages about the input being in the wrong format. " + "Use this if you find the warnings are wrong (but also let us know so " + "we can improve them).") + .expectArg(":", KJ_BIND_METHOD(*this, setConversion)) + .expectOptionalArg("", KJ_BIND_METHOD(*this, addSource)) + .expectOptionalArg("", KJ_BIND_METHOD(*this, setRootType)) + .callAfterParsing(KJ_BIND_METHOD(*this, convert)); + return builder.build(); + } + kj::MainFunc getDecodeMain() { // Only parse the schemas we actually need for decoding. compileEagerness = Compiler::NODE; @@ -153,7 +210,7 @@ public: kj::MainBuilder builder(context, VERSION_STRING, "Encodes one or more textual Cap'n Proto messages to binary. The messages have root " "type defined in . Messages are read from standard input. Each " - "mesage is a parenthesized struct literal, like the format used to specify constants " + "message is a parenthesized struct literal, like the format used to specify constants " "and default values of struct type in the schema language. For example:\n" " (foo = 123, bar = \"hello\", baz = [true, false, true])\n" "The input may contain any number of such values; each will be encoded as a separate " @@ -189,6 +246,15 @@ public: // Drop annotations since we don't need them. This avoids importing files like c++.capnp. annotationFlag = Compiler::DROP_ANNOTATIONS; + // Default convert to text unless -o is given. + convertTo = Format::TEXT; + + // When using `capnp eval`, type IDs don't really matter, because `eval` won't actually use + // them for anything. When using Cap'n Proto an a config format -- the common use case for + // `capnp eval` -- the exercise of adding a file ID to every file is pointless busy work. So, + // we don't require it. + loader.setFileIdsRequired(false); + kj::MainBuilder builder(context, VERSION_STRING, "Prints (or encodes) the value of , which must be defined in . " " must refer to a const declaration, a field of a struct type (prints the default " @@ -206,20 +272,19 @@ public: "and --flat flags specify binary output, in which case the const must be of struct " "type."); addGlobalOptions(builder); - builder.addOption({'b', "binary"}, KJ_BIND_METHOD(*this, codeBinary), - "Write the output as binary instead of text, using standard Cap'n Proto " - "serialization. (This writes the message using capnp::writeMessage() " - "from .)") + builder.addOptionWithArg({'o', "output"}, KJ_BIND_METHOD(*this, setEvalOutputFormat), + "", "Encode the output in the given format. See `capnp help convert` " + "for a list of formats. Defaults to \"text\".") + .addOption({'b', "binary"}, KJ_BIND_METHOD(*this, codeBinary), + "same as -obinary") .addOption({"flat"}, KJ_BIND_METHOD(*this, codeFlat), - "Write the output as a flat single-segment binary message, with no framing.") + "same as -oflat") .addOption({'p', "packed"}, KJ_BIND_METHOD(*this, codePacked), - "Write the output as packed binary instead of text, using standard Cap'n " - "Proto packing, which deflates zero-valued bytes. (This writes the " - "message using capnp::writePackedMessage() from " - ".)") + "same as -opacked") .addOption({"short"}, KJ_BIND_METHOD(*this, printShort), - "Print in short (non-pretty) text format. The message will be printed on " - "one line, without using whitespace to improve readability.") + "If output format is text or JSON, write in short (non-pretty) format. The " + "message will be printed on one line, without using whitespace to improve " + "readability.") .expectArg("", KJ_BIND_METHOD(*this, addSource)) .expectArg("", KJ_BIND_METHOD(*this, evalConst)); return builder.build(); @@ -258,8 +323,12 @@ public: // shared options kj::MainBuilder::Validity addImportPath(kj::StringPtr path) { - loader.addImportPath(kj::heapString(path)); - return true; + KJ_IF_MAYBE(dir, getSourceDirectory(path, false)) { + loader.addImportPath(*dir); + return true; + } else { + return "no such directory"; + } } kj::MainBuilder::Validity noStandardImport() { @@ -268,34 +337,35 @@ public: } kj::MainBuilder::Validity addSource(kj::StringPtr file) { - // Strip redundant "./" prefixes to make src-prefix matching more lenient. - while (file.startsWith("./")) { - file = file.slice(2); - - // Remove redundant slashes as well (e.g. ".////foo" -> "foo"). - while (file.startsWith("/")) { - file = file.slice(1); - } - } - if (!compilerConstructed) { compiler = compilerSpace.construct(annotationFlag); compilerConstructed = true; } if (addStandardImportPaths) { - loader.addImportPath(kj::heapString("/usr/local/include")); - loader.addImportPath(kj::heapString("/usr/include")); + static constexpr kj::StringPtr STANDARD_IMPORT_PATHS[] = { + "/usr/local/include"_kj, + "/usr/include"_kj, #ifdef CAPNP_INCLUDE_DIR - loader.addImportPath(kj::heapString(CAPNP_INCLUDE_DIR)); + KJ_CONCAT(CAPNP_INCLUDE_DIR, _kj), #endif + }; + for (auto path: STANDARD_IMPORT_PATHS) { + KJ_IF_MAYBE(dir, getSourceDirectory(path, false)) { + loader.addImportPath(*dir); + } else { + // ignore standard path that doesn't exist + } + } + addStandardImportPaths = false; } - KJ_IF_MAYBE(module, loadModule(file)) { - uint64_t id = compiler->add(*module); - compiler->eagerlyCompile(id, compileEagerness); - sourceFiles.add(SourceFile { id, module->getSourceName(), &*module }); + auto dirPathPair = interpretSourceFile(file); + KJ_IF_MAYBE(module, loader.loadModule(dirPathPair.dir, dirPathPair.path)) { + auto compiled = compiler->add(*module); + compiler->eagerlyCompile(compiled.getId(), compileEagerness); + sourceFiles.add(SourceFile { compiled.getId(), compiled, module->getSourceName(), &*module }); } else { return "no such file"; } @@ -303,20 +373,6 @@ public: return true; } -private: - kj::Maybe loadModule(kj::StringPtr file) { - size_t longestPrefix = 0; - - for (auto& prefix: sourcePrefixes) { - if (file.startsWith(prefix)) { - longestPrefix = kj::max(longestPrefix, prefix.size()); - } - } - - kj::StringPtr canonicalName = file.slice(longestPrefix); - return loader.loadModule(file, canonicalName); - } - public: // ===================================================================================== // "id" command @@ -334,10 +390,14 @@ public: kj::StringPtr dir = spec.slice(*split + 1); auto plugin = spec.slice(0, *split); - KJ_IF_MAYBE(split2, dir.findFirst(':')) { - // Grr, there are two colons. Might this be a Windows path? Let's do some heuristics. - if (*split == 1 && (dir.startsWith("/") || dir.startsWith("\\"))) { - // So, the first ':' was the second char, and was followed by '/' or '\', e.g.: + if (*split == 1 && (dir.startsWith("/") || dir.startsWith("\\"))) { + // The colon is the second character and is immediately followed by a slash or backslash. + // So, the user passed something like `-o c:/foo`. Is this a request to run the C plugin + // and output to `/foo`? Or are we on Windows, and is this a request to run the plugin + // `c:/foo`? + KJ_IF_MAYBE(split2, dir.findFirst(':')) { + // There are two colons. The first ':' was the second char, and was followed by '/' or + // '\', e.g.: // capnp compile -o c:/foo.exe:bar // // In this case we can conclude that the second colon is actually meant to be the @@ -359,18 +419,24 @@ public: // -> CONTRADICTION // // We therefore conclude that the *second* colon is in fact the plugin/location separator. - // - // Note that there is still an ambiguous case: - // capnp compile -o c:/foo - // - // In this unfortunate case, we have no way to tell if the user meant "use the 'c' plugin - // and output to /foo" or "use the plugin c:/foo and output to the default location". We - // prefer the former interpretation, because the latter is Windows-specific and such - // users can always explicitly specify the output location like: - // capnp compile -o c:/foo:. dir = dir.slice(*split2 + 1); plugin = spec.slice(0, *split2 + 2); +#if _WIN32 + } else { + // The user wrote something like: + // + // capnp compile -o c:/foo/bar + // + // What does this mean? It depends on what system we're on. On a Unix system, the above + // clearly is a request to run the `capnpc-c` plugin (perhaps to output C code) and write + // to the directory /foo/bar. But on Windows, absolute paths do not start with '/', and + // the above is actually a request to run the plugin `c:/foo/bar`, outputting to the + // current directory. + + outputs.add(OutputDirective { spec.asArray(), nullptr }); + return true; +#endif } } @@ -378,7 +444,7 @@ public: if (stat(dir.cStr(), &stats) < 0 || !S_ISDIR(stats.st_mode)) { return "output location is inaccessible or is not a directory"; } - outputs.add(OutputDirective { plugin, dir }); + outputs.add(OutputDirective { plugin, disk->getCurrentPath().evalNative(dir) }); } else { outputs.add(OutputDirective { spec.asArray(), nullptr }); } @@ -387,22 +453,11 @@ public: } kj::MainBuilder::Validity addSourcePrefix(kj::StringPtr prefix) { - // Strip redundant "./" prefixes to make src-prefix matching more lenient. - while (prefix.startsWith("./")) { - prefix = prefix.slice(2); - } - - if (prefix == "" || prefix == ".") { - // Irrelevant prefix. - return true; - } - - if (prefix.endsWith("/")) { - sourcePrefixes.add(kj::heapString(prefix)); + if (getSourceDirectory(prefix, true) == nullptr) { + return "no such directory"; } else { - sourcePrefixes.add(kj::str(prefix, '/')); + return true; } - return true; } kj::MainBuilder::Validity generateOutput() { @@ -433,6 +488,8 @@ public: nodes.setWithCaveats(i, schemas[i].getProto()); } + request.adoptSourceInfo(compiler->getAllSourceInfo(message.getOrphanage())); + auto requestedFiles = request.initRequestedFiles(sourceFiles.size()); for (size_t i = 0; i < sourceFiles.size(); i++) { auto requestedFile = requestedFiles[i]; @@ -492,17 +549,30 @@ public: KJ_SYSCALL(dup2(pipeFds[0], STDIN_FILENO)); KJ_SYSCALL(close(pipeFds[0])); - if (output.dir != nullptr) { - KJ_SYSCALL(chdir(output.dir.cStr()), output.dir); + KJ_IF_MAYBE(d, output.dir) { +#if _WIN32 + KJ_SYSCALL(SetCurrentDirectoryW(d->forWin32Api(true).begin()), d->toWin32String(true)); +#else + auto wd = d->toString(true); + KJ_SYSCALL(chdir(wd.cStr()), wd); + KJ_SYSCALL(setenv("PWD", wd.cStr(), true)); +#endif } +#if _WIN32 + // MSVCRT's spawn*() don't correctly escape arguments, which is necessary on Windows + // since the underlying system call takes a single command line string rather than + // an arg list. We do the escaping ourselves by wrapping the name in quotes. We know + // that exeName itself can't contain quotes (since filenames aren't allowed to contain + // quotes on Windows), so we don't have to account for those. + KJ_ASSERT(exeName.findFirst('\"') == nullptr, + "Windows filenames can't contain quotes", exeName); + auto escapedExeName = kj::str("\"", exeName, "\""); +#endif + if (shouldSearchPath) { #if _WIN32 - // MSVCRT's spawn*() don't correctly escape arguments, which is necessary on Windows - // since the underlying system call takes a single command line string rather than - // an arg list. Instead of trying to do the escaping ourselves, we just pass "plugin" - // for argv[0]. - child = _spawnlp(_P_NOWAIT, exeName.cStr(), "plugin", nullptr); + child = _spawnlp(_P_NOWAIT, exeName.cStr(), escapedExeName.cStr(), nullptr); #else execlp(exeName.cStr(), exeName.cStr(), nullptr); #endif @@ -518,11 +588,7 @@ public: } #if _WIN32 - // MSVCRT's spawn*() don't correctly escape arguments, which is necessary on Windows - // since the underlying system call takes a single command line string rather than - // an arg list. Instead of trying to do the escaping ourselves, we just pass "plugin" - // for argv[0]. - child = _spawnl(_P_NOWAIT, exeName.cStr(), "plugin", nullptr); + child = _spawnl(_P_NOWAIT, exeName.cStr(), escapedExeName.cStr(), nullptr); #else execl(exeName.cStr(), exeName.cStr(), nullptr); #endif @@ -585,139 +651,340 @@ public: } // ===================================================================================== - // "decode" command + // "convert" command - kj::MainBuilder::Validity codeBinary() { - if (packed) return "cannot be used with --packed"; - if (flat) return "cannot be used with --flat"; - binary = true; - return true; - } - kj::MainBuilder::Validity codeFlat() { - if (binary) return "cannot be used with --binary"; - flat = true; - return true; - } - kj::MainBuilder::Validity codePacked() { - if (binary) return "cannot be used with --binary"; - packed = true; - return true; - } - kj::MainBuilder::Validity printShort() { - pretty = false; - return true; - } - kj::MainBuilder::Validity setQuiet() { - quiet = true; - return true; +private: + enum class Format { + BINARY, + PACKED, + FLAT, + FLAT_PACKED, + CANONICAL, + TEXT, + JSON + }; + + kj::Maybe parseFormatName(kj::StringPtr name) { + if (name == "binary" ) return Format::BINARY; + if (name == "packed" ) return Format::PACKED; + if (name == "flat" ) return Format::FLAT; + if (name == "flat-packed") return Format::FLAT_PACKED; + if (name == "canonical" ) return Format::CANONICAL; + if (name == "text" ) return Format::TEXT; + if (name == "json" ) return Format::JSON; + + return nullptr; } - kj::MainBuilder::Validity setSegmentSize(kj::StringPtr size) { - if (flat) return "cannot be used with --flat"; - char* end; - segmentSize = strtol(size.cStr(), &end, 0); - if (size.size() == 0 || *end != '\0') { - return "not an integer"; + + kj::StringPtr toString(Format format) { + switch (format) { + case Format::BINARY : return "binary"; + case Format::PACKED : return "packed"; + case Format::FLAT : return "flat"; + case Format::FLAT_PACKED: return "flat-packed"; + case Format::CANONICAL : return "canonical"; + case Format::TEXT : return "text"; + case Format::JSON : return "json"; } - return true; + KJ_UNREACHABLE; } - kj::MainBuilder::Validity setRootType(kj::StringPtr type) { - KJ_ASSERT(sourceFiles.size() == 1); - - KJ_IF_MAYBE(schema, resolveName(sourceFiles[0].id, type)) { - if (schema->getProto().which() != schema::Node::STRUCT) { - return "not a struct type"; + Format formatFromDeprecatedFlags(Format defaultFormat) { + // For deprecated commands "decode" and "encode". + if (flat) { + if (packed) { + return Format::FLAT_PACKED; + } else { + return Format::FLAT; } - rootType = schema->asStruct(); - return true; + } if (packed) { + return Format::PACKED; + } else if (binary) { + return Format::BINARY; } else { - return "no such type"; + return defaultFormat; } } -private: - kj::Maybe resolveName(uint64_t scopeId, kj::StringPtr name) { - while (name.size() > 0) { - kj::String temp; - kj::StringPtr part; - KJ_IF_MAYBE(dotpos, name.findFirst('.')) { - temp = kj::heapString(name.slice(0, *dotpos)); - part = temp; - name = name.slice(*dotpos + 1); + kj::MainBuilder::Validity verifyRequirements(Format format) { + if ((format == Format::TEXT || format == Format::JSON) && rootType == StructSchema()) { + return kj::str("format requires schema: ", toString(format)); + } else { + return true; + } + } + +public: + kj::MainBuilder::Validity setConversion(kj::StringPtr conversion) { + KJ_IF_MAYBE(colon, conversion.findFirst(':')) { + auto from = kj::str(conversion.slice(0, *colon)); + auto to = conversion.slice(*colon + 1); + + KJ_IF_MAYBE(f, parseFormatName(from)) { + convertFrom = *f; } else { - part = name; - name = nullptr; + return kj::str("unknown format: ", from); } - KJ_IF_MAYBE(childId, compiler->lookup(scopeId, part)) { - scopeId = *childId; + KJ_IF_MAYBE(t, parseFormatName(to)) { + convertTo = *t; } else { - return nullptr; + return kj::str("unknown format: ", to); } + + if (convertFrom == Format::JSON || convertTo == Format::JSON) { + // We need annotations to process JSON. + // TODO(someday): Find a way that we can process annotations from json.capnp without + // requiring other annotation-only imports like c++.capnp + annotationFlag = Compiler::COMPILE_ANNOTATIONS; + } + + return true; + } else { + return "invalid conversion, format is: :"; } - return compiler->getLoader().get(scopeId); } -public: - kj::MainBuilder::Validity decode() { + kj::MainBuilder::Validity convert() { + { + auto result = verifyRequirements(convertFrom); + if (result.getError() != nullptr) return result; + } + { + auto result = verifyRequirements(convertTo); + if (result.getError() != nullptr) return result; + } + kj::FdInputStream rawInput(STDIN_FILENO); kj::BufferedInputStreamWrapper input(rawInput); + kj::FdOutputStream output(STDOUT_FILENO); + if (!quiet) { - auto result = checkPlausibility(input.getReadBuffer()); + auto result = checkPlausibility(convertFrom, input.getReadBuffer()); if (result.getError() != nullptr) { return kj::mv(result); } } - if (flat) { - // Read in the whole input to decode as one segment. - kj::Array words; - - { - kj::Vector allBytes; - for (;;) { - auto buffer = input.tryGetReadBuffer(); - if (buffer.size() == 0) break; - allBytes.addAll(buffer); - input.skip(buffer.size()); - } + while (input.tryGetReadBuffer().size() > 0) { + readOneAndConvert(input, output); + } - if (packed) { - words = kj::heapArray(computeUnpackedSizeInWords(allBytes)); - kj::ArrayInputStream input(allBytes); - capnp::_::PackedInputStream unpacker(input); - unpacker.read(words.asBytes().begin(), words.asBytes().size()); - word dummy; - KJ_ASSERT(unpacker.tryRead(&dummy, sizeof(dummy), sizeof(dummy)) == 0); - } else { - // Technically we don't know if the bytes are aligned so we'd better copy them to a new - // array. Note that if we have a non-whole number of words we chop off the straggler - // bytes. This is fine because if those bytes are actually part of the message we will - // hit an error later and if they are not then who cares? - words = kj::heapArray(allBytes.size() / sizeof(word)); - memcpy(words.begin(), allBytes.begin(), words.size() * sizeof(word)); - } + context.exit(); + KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT; + } + +private: + kj::Vector readAll(kj::BufferedInputStreamWrapper& input) { + kj::Vector allBytes; + for (;;) { + auto buffer = input.tryGetReadBuffer(); + if (buffer.size() == 0) break; + allBytes.addAll(buffer); + input.skip(buffer.size()); + } + return allBytes; + } + + kj::String readOneText(kj::BufferedInputStreamWrapper& input) { + // Consume and return one parentheses-delimited message from the input. + // + // Accounts for nested parentheses, comments, and string literals. + + enum { + NORMAL, + COMMENT, + QUOTE, + QUOTE_ESCAPE, + DQUOTE, + DQUOTE_ESCAPE + } state = NORMAL; + uint depth = 0; + bool sawClose = false; + + kj::Vector chars; + + for (;;) { + auto buffer = input.tryGetReadBuffer(); + + if (buffer == nullptr) { + // EOF + chars.add('\0'); + return kj::String(chars.releaseAsArray()); } - kj::ArrayPtr segments = words; - decodeInner(arrayPtr(&segments, 1)); - } else { - while (input.tryGetReadBuffer().size() > 0) { - if (packed) { - decodeInner(input); - } else { - decodeInner(input); + for (auto i: kj::indices(buffer)) { + char c = buffer[i]; + switch (state) { + case NORMAL: + switch (c) { + case '#': state = COMMENT; break; + case '(': + if (depth == 0 && sawClose) { + // We already got one complete message. This is the start of the next message. + // Stop here. + chars.addAll(buffer.slice(0, i)); + chars.add('\0'); + input.skip(i); + return kj::String(chars.releaseAsArray()); + } + ++depth; + break; + case ')': + if (depth > 0) { + if (--depth == 0) { + sawClose = true; + } + } + break; + default: break; + } + break; + case COMMENT: + switch (c) { + case '\n': state = NORMAL; break; + default: break; + } + break; + case QUOTE: + switch (c) { + case '\'': state = NORMAL; break; + case '\\': state = QUOTE_ESCAPE; break; + default: break; + } + break; + case QUOTE_ESCAPE: + break; + case DQUOTE: + switch (c) { + case '\"': state = NORMAL; break; + case '\\': state = DQUOTE_ESCAPE; break; + default: break; + } + break; + case DQUOTE_ESCAPE: + break; } } + + chars.addAll(buffer); + input.skip(buffer.size()); } + } - context.exit(); - KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT; + kj::String readOneJson(kj::BufferedInputStreamWrapper& input) { + // Consume and return one brace-delimited message from the input. + // + // Accounts for nested braces, string literals, and comments starting with # or //. Technically + // JSON does not permit comments but this code is lenient in case we change things later. + + enum { + NORMAL, + SLASH, + COMMENT, + QUOTE, + QUOTE_ESCAPE, + DQUOTE, + DQUOTE_ESCAPE + } state = NORMAL; + uint depth = 0; + bool sawClose = false; + + kj::Vector chars; + + for (;;) { + auto buffer = input.tryGetReadBuffer(); + + if (buffer == nullptr) { + // EOF + chars.add('\0'); + return kj::String(chars.releaseAsArray()); + } + + for (auto i: kj::indices(buffer)) { + char c = buffer[i]; + switch (state) { + case SLASH: + if (c == '/') { + state = COMMENT; + break; + } + KJ_FALLTHROUGH; + case NORMAL: + switch (c) { + case '#': state = COMMENT; break; + case '/': state = SLASH; break; + case '{': + if (depth == 0 && sawClose) { + // We already got one complete message. This is the start of the next message. + // Stop here. + chars.addAll(buffer.slice(0, i)); + chars.add('\0'); + input.skip(i); + return kj::String(chars.releaseAsArray()); + } + ++depth; + break; + case '}': + if (depth > 0) { + if (--depth == 0) { + sawClose = true; + } + } + break; + default: break; + } + break; + case COMMENT: + switch (c) { + case '\n': state = NORMAL; break; + default: break; + } + break; + case QUOTE: + switch (c) { + case '\'': state = NORMAL; break; + case '\\': state = QUOTE_ESCAPE; break; + default: break; + } + break; + case QUOTE_ESCAPE: + break; + case DQUOTE: + switch (c) { + case '\"': state = NORMAL; break; + case '\\': state = DQUOTE_ESCAPE; break; + default: break; + } + break; + case DQUOTE_ESCAPE: + break; + } + } + + chars.addAll(buffer); + input.skip(buffer.size()); + } } -private: - struct ParseErrorCatcher: public kj::ExceptionCallback { + class ParseErrorCatcher: public kj::ExceptionCallback { + public: + ParseErrorCatcher(kj::ProcessContext& context): context(context) {} + ~ParseErrorCatcher() noexcept(false) { + if (!unwindDetector.isUnwinding()) { + KJ_IF_MAYBE(e, exception) { + context.error(kj::str( + "*** ERROR CONVERTING PREVIOUS MESSAGE ***\n" + "The following error occurred while converting the message above.\n" + "This probably means the input data is invalid/corrupted.\n", + "Exception description: ", e->getDescription(), "\n" + "Code location: ", e->getFile(), ":", e->getLine(), "\n" + "*** END ERROR ***")); + } + } + } + void onRecoverableException(kj::Exception&& e) { // Only capture the first exception, on the assumption that later exceptions are probably // just cascading problems. @@ -726,45 +993,273 @@ private: } } + private: + kj::ProcessContext& context; kj::Maybe exception; + kj::UnwindDetector unwindDetector; }; - template - void decodeInner(Input&& input) { + void readOneAndConvert(kj::BufferedInputStreamWrapper& input, kj::OutputStream& output) { // Since this is a debug tool, lift the usual security limits. Worse case is the process // crashes or has to be killed. ReaderOptions options; options.nestingLimit = kj::maxValue; options.traversalLimitInWords = kj::maxValue; - MessageReaderType reader(input, options); - kj::String text; - kj::Maybe exception; + ParseErrorCatcher parseErrorCatcher(context); - { - ParseErrorCatcher catcher; - auto root = reader.template getRoot(rootType); - if (pretty) { - text = kj::str(prettyPrint(root), '\n'); - } else { - text = kj::str(root, '\n'); + switch (convertFrom) { + case Format::BINARY: { + capnp::InputStreamMessageReader message(input, options); + return writeConversion(message.getRoot(), output); + } + case Format::PACKED: { + capnp::PackedMessageReader message(input, options); + return writeConversion(message.getRoot(), output); + } + case Format::FLAT: + case Format::CANONICAL: { + auto allBytes = readAll(input); + + // Technically we don't know if the bytes are aligned so we'd better copy them to a new + // array. Note that if we have a non-whole number of words we chop off the straggler + // bytes. This is fine because if those bytes are actually part of the message we will + // hit an error later and if they are not then who cares? + auto words = kj::heapArray(allBytes.size() / sizeof(word)); + memcpy(words.begin(), allBytes.begin(), words.size() * sizeof(word)); + + kj::ArrayPtr segments[1] = { words }; + SegmentArrayMessageReader message(segments, options); + if (convertFrom == Format::CANONICAL) { + KJ_REQUIRE(message.isCanonical()); + } + return writeConversion(message.getRoot(), output); + } + case Format::FLAT_PACKED: { + auto allBytes = readAll(input); + + auto words = kj::heapArray(computeUnpackedSizeInWords(allBytes)); + kj::ArrayInputStream input(allBytes); + capnp::_::PackedInputStream unpacker(input); + unpacker.read(words.asBytes().begin(), words.asBytes().size()); + word dummy; + KJ_ASSERT(unpacker.tryRead(&dummy, sizeof(dummy), sizeof(dummy)) == 0); + + kj::ArrayPtr segments[1] = { words }; + SegmentArrayMessageReader message(segments, options); + return writeConversion(message.getRoot(), output); + } + case Format::TEXT: { + auto text = readOneText(input); + MallocMessageBuilder message; + TextCodec codec; + codec.setPrettyPrint(pretty); + auto root = message.initRoot(rootType); + codec.decode(text, root); + return writeConversion(root.asReader(), output); + } + case Format::JSON: { + auto text = readOneJson(input); + MallocMessageBuilder message; + JsonCodec codec; + codec.setPrettyPrint(pretty); + codec.handleByAnnotation(rootType); + auto root = message.initRoot(rootType); + codec.decode(text, root); + return writeConversion(root.asReader(), output); + } + } + + KJ_UNREACHABLE; + } + + void writeConversion(AnyStruct::Reader reader, kj::OutputStream& output) { + switch (convertTo) { + case Format::BINARY: { + MallocMessageBuilder message( + segmentSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : segmentSize, + segmentSize == 0 ? SUGGESTED_ALLOCATION_STRATEGY : AllocationStrategy::FIXED_SIZE); + message.setRoot(reader); + capnp::writeMessage(output, message); + return; + } + case Format::PACKED: { + MallocMessageBuilder message( + segmentSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : segmentSize, + segmentSize == 0 ? SUGGESTED_ALLOCATION_STRATEGY : AllocationStrategy::FIXED_SIZE); + message.setRoot(reader); + capnp::writePackedMessage(output, message); + return; + } + case Format::FLAT: { + auto words = kj::heapArray(reader.totalSize().wordCount + 1); + memset(words.begin(), 0, words.asBytes().size()); + copyToUnchecked(reader, words); + output.write(words.begin(), words.asBytes().size()); + return; + } + case Format::FLAT_PACKED: { + auto words = kj::heapArray(reader.totalSize().wordCount + 1); + memset(words.begin(), 0, words.asBytes().size()); + copyToUnchecked(reader, words); + kj::BufferedOutputStreamWrapper buffered(output); + capnp::_::PackedOutputStream packed(buffered); + packed.write(words.begin(), words.asBytes().size()); + return; + } + case Format::CANONICAL: { + auto words = reader.canonicalize(); + output.write(words.begin(), words.asBytes().size()); + return; } - exception = kj::mv(catcher.exception); + case Format::TEXT: { + TextCodec codec; + codec.setPrettyPrint(pretty); + auto text = codec.encode(reader.as(rootType)); + output.write({text.asBytes(), kj::StringPtr("\n").asBytes()}); + return; + } + case Format::JSON: { + JsonCodec codec; + codec.setPrettyPrint(pretty); + codec.handleByAnnotation(rootType); + auto text = codec.encode(reader.as(rootType)); + output.write({text.asBytes(), kj::StringPtr("\n").asBytes()}); + return; + } + } + + KJ_UNREACHABLE; + } + +public: + + // ===================================================================================== + // "decode" command + + kj::MainBuilder::Validity codeBinary() { + if (packed) return "cannot be used with --packed"; + if (flat) return "cannot be used with --flat"; + binary = true; + return true; + } + kj::MainBuilder::Validity codeFlat() { + if (binary) return "cannot be used with --binary"; + flat = true; + return true; + } + kj::MainBuilder::Validity codePacked() { + if (binary) return "cannot be used with --binary"; + packed = true; + return true; + } + kj::MainBuilder::Validity printShort() { + pretty = false; + return true; + } + kj::MainBuilder::Validity setQuiet() { + quiet = true; + return true; + } + kj::MainBuilder::Validity setSegmentSize(kj::StringPtr size) { + if (flat) return "cannot be used with --flat"; + char* end; + segmentSize = strtol(size.cStr(), &end, 0); + if (size.size() == 0 || *end != '\0') { + return "not an integer"; } + return true; + } + + kj::MainBuilder::Validity setRootType(kj::StringPtr input) { + KJ_ASSERT(sourceFiles.size() == 1); + + class CliArgumentErrorReporter: public ErrorReporter { + public: + void addError(uint32_t startByte, uint32_t endByte, kj::StringPtr message) override { + if (startByte < endByte) { + error = kj::str(startByte + 1, "-", endByte, ": ", message); + } else if (startByte > 0) { + error = kj::str(startByte + 1, ": ", message); + } else { + error = kj::str(message); + } + } + + bool hadErrors() override { + return error != nullptr; + } + + kj::MainBuilder::Validity getValidity() { + KJ_IF_MAYBE(e, error) { + return kj::mv(*e); + } else { + return true; + } + } - kj::FdOutputStream(STDOUT_FILENO).write(text.begin(), text.size()); + private: + kj::Maybe error; + }; - KJ_IF_MAYBE(e, exception) { - context.error(kj::str( - "*** ERROR DECODING PREVIOUS MESSAGE ***\n" - "The following error occurred while decoding the message above.\n" - "This probably means the input data is invalid/corrupted.\n", - "Exception description: ", e->getDescription(), "\n" - "Code location: ", e->getFile(), ":", e->getLine(), "\n" - "*** END ERROR ***")); + CliArgumentErrorReporter errorReporter; + + capnp::MallocMessageBuilder tokenArena; + auto lexedTokens = tokenArena.initRoot(); + lex(input, lexedTokens, errorReporter); + + CapnpParser parser(tokenArena.getOrphanage(), errorReporter); + auto tokens = lexedTokens.asReader().getTokens(); + CapnpParser::ParserInput parserInput(tokens.begin(), tokens.end()); + + bool success = false; + + if (parserInput.getPosition() == tokens.end()) { + // Empty argument? + errorReporter.addError(0, 0, "Couldn't parse type name."); + } else { + KJ_IF_MAYBE(expression, parser.getParsers().expression(parserInput)) { + // The input is expected to contain a *single* expression. + if (parserInput.getPosition() == tokens.end()) { + // Hooray, now parse it. + KJ_IF_MAYBE(compiledType, + sourceFiles[0].compiled.evalType(expression->getReader(), errorReporter)) { + KJ_IF_MAYBE(type, compiledType->getSchema()) { + if (type->isStruct()) { + rootType = type->asStruct(); + success = true; + } else { + errorReporter.addError(0, 0, "Type is not a struct."); + } + } else { + // Apparently named a file scope. + errorReporter.addError(0, 0, "Type is not a struct."); + } + } + } else { + errorReporter.addErrorOn(parserInput.current(), "Couldn't parse type name."); + } + } else { + auto best = parserInput.getBest(); + if (best == tokens.end()) { + errorReporter.addError(input.size(), input.size(), "Couldn't parse type name."); + } else { + errorReporter.addErrorOn(*best, "Couldn't parse type name."); + } + } } + + KJ_ASSERT(success || errorReporter.hadErrors()); + return errorReporter.getValidity(); + } + + kj::MainBuilder::Validity decode() { + convertTo = Format::TEXT; + convertFrom = formatFromDeprecatedFlags(Format::BINARY); + return convert(); } +private: enum Plausibility { IMPOSSIBLE, IMPLAUSIBLE, @@ -798,8 +1293,13 @@ private: return IMPOSSIBLE; } if ((prefix[3] & 0x80) != 0) { - // Offset is negative (invalid). - return IMPOSSIBLE; + if (prefix[0] == 0xff && prefix[1] == 0xff && prefix[2] == 0xff && prefix[3] == 0xff && + prefix[4] == 0 && prefix[5] == 0 && prefix[6] == 0 && prefix[7] == 0) { + // This is an empty struct with offset of -1. That's valid. + } else { + // Offset is negative (invalid). + return IMPOSSIBLE; + } } if ((prefix[3] & 0xe0) != 0) { // Offset is over a gigabyte (implausible). @@ -933,359 +1433,228 @@ private: }); } - kj::MainBuilder::Validity checkPlausibility(kj::ArrayPtr prefix) { - if (flat && packed) { - switch (isPlausiblyPackedFlat(prefix)) { - case PLAUSIBLE: - break; - case IMPOSSIBLE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - return "The input is not in --packed --flat format. It looks like it is in --packed " - "format. Try removing --flat."; - } else if (plausibleOrWrongType(isPlausiblyFlat(prefix))) { - return "The input is not in --packed --flat format. It looks like it is in --flat " - "format. Try removing --packed."; - } else if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - return "The input is not in --packed --flat format. It looks like it is in regular " - "binary format. Try removing the --packed and --flat flags."; - } else { - return "The input is not a Cap'n Proto message."; + Plausibility isPlausiblyText(kj::ArrayPtr prefix) { + enum { PREAMBLE, COMMENT, BODY } state = PREAMBLE; + + for (char c: prefix.asChars()) { + switch (state) { + case PREAMBLE: + // Before opening parenthesis. + switch (c) { + case '(': state = BODY; continue; + case '#': state = COMMENT; continue; + case ' ': + case '\n': + case '\r': + case '\t': + case '\v': + // whitespace + break; + default: + // Not whitespace, not comment, not open parenthesis. Impossible! + return IMPOSSIBLE; } - case IMPLAUSIBLE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --packed --flat format. It looks like\n" - "it may be in --packed format. I'll try to parse it in --packed --flat format\n" - "as you requested, but if it doesn't work, try removing --flat. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyFlat(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --packed --flat format. It looks like\n" - "it may be in --flat format. I'll try to parse it in --packed --flat format as\n" - "you requested, but if it doesn't work, try removing --packed. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --packed --flat format. It looks like\n" - "it may be in regular binary format. I'll try to parse it in --packed --flat\n" - "format as you requested, but if it doesn't work, try removing --packed and\n" - "--flat. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); - } else { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be a Cap'n Proto message in any known\n" - "binary format. I'll try to parse it anyway, but if it doesn't work, please\n" - "check your input. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); + break; + case COMMENT: + switch (c) { + case '\n': state = PREAMBLE; continue; + default: break; } break; - case WRONG_TYPE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be the type that you specified. I'll try\n" - "to parse it anyway, but if it doesn't look right, please verify that you\n" - "have the right type. This could also be because the input is not in --flat\n" - "format; indeed, it looks like this input may be in regular --packed format,\n" - "so you might want to try removing --flat. Use --quiet to suppress this\n" - "warning.\n" - "*** END WARNING ***\n"); - } else { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be the type that you specified. I'll try\n" - "to parse it anyway, but if it doesn't look right, please verify that you\n" - "have the right type. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); + case BODY: + switch (c) { + case '\"': + case '\'': + // String literal. Let's stop here before things get complicated. + return PLAUSIBLE; + default: + break; } break; } - } else if (flat) { - switch (isPlausiblyFlat(prefix)) { - case PLAUSIBLE: - break; - case IMPOSSIBLE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - return "The input is not in --flat format. It looks like it is in --packed format. " - "Try that instead."; - } else if (plausibleOrWrongType(isPlausiblyPackedFlat(prefix))) { - return "The input is not in --flat format. It looks like it is in --packed --flat " - "format. Try adding --packed."; - } else if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - return "The input is not in --flat format. It looks like it is in regular binary " - "format. Try removing the --flat flag."; - } else { - return "The input is not a Cap'n Proto message."; - } - case IMPLAUSIBLE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --flat format. It looks like it may\n" - "be in --packed format. I'll try to parse it in --flat format as you\n" - "requested, but if it doesn't work, try --packed instead. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyPackedFlat(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --flat format. It looks like it may\n" - "be in --packed --flat format. I'll try to parse it in --flat format as you\n" - "requested, but if it doesn't work, try adding --packed. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --flat format. It looks like it may\n" - "be in regular binary format. I'll try to parse it in --flat format as you\n" - "requested, but if it doesn't work, try removing --flat. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be a Cap'n Proto message in any known\n" - "binary format. I'll try to parse it anyway, but if it doesn't work, please\n" - "check your input. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); + + if ((static_cast(c) < 0x20 && c != '\n' && c != '\r' && c != '\t' && c != '\v') + || c == 0x7f) { + // Unprintable character. + return IMPOSSIBLE; + } + } + + return PLAUSIBLE; + } + + Plausibility isPlausiblyJson(kj::ArrayPtr prefix) { + enum { PREAMBLE, COMMENT, BODY } state = PREAMBLE; + + for (char c: prefix.asChars()) { + switch (state) { + case PREAMBLE: + // Before opening parenthesis. + switch (c) { + case '{': state = BODY; continue; + case '#': state = COMMENT; continue; + case '/': state = COMMENT; continue; + case ' ': + case '\n': + case '\r': + case '\t': + case '\v': + // whitespace + break; + default: + // Not whitespace, not comment, not open brace. Impossible! + return IMPOSSIBLE; } break; - case WRONG_TYPE: - if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be the type that you specified. I'll try\n" - "to parse it anyway, but if it doesn't look right, please verify that you\n" - "have the right type. This could also be because the input is not in --flat\n" - "format; indeed, it looks like this input may be in regular binary format,\n" - "so you might want to try removing --flat. Use --quiet to suppress this\n" - "warning.\n" - "*** END WARNING ***\n"); - } else { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be the type that you specified. I'll try\n" - "to parse it anyway, but if it doesn't look right, please verify that you\n" - "have the right type. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); + case COMMENT: + switch (c) { + case '\n': state = PREAMBLE; continue; + default: break; } break; - } - } else if (packed) { - switch (isPlausiblyPacked(prefix)) { - case PLAUSIBLE: - break; - case IMPOSSIBLE: - if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - return "The input is not in --packed format. It looks like it is in regular binary " - "format. Try removing the --packed flag."; - } else if (plausibleOrWrongType(isPlausiblyPackedFlat(prefix))) { - return "The input is not in --packed format, nor does it look like it is in regular " - "binary format. It looks like it could be in --packed --flat format, although " - "that is unusual so I could be wrong."; - } else if (plausibleOrWrongType(isPlausiblyFlat(prefix))) { - return "The input is not in --packed format, nor does it look like it is in regular " - "binary format. It looks like it could be in --flat format, although that " - "is unusual so I could be wrong."; - } else { - return "The input is not a Cap'n Proto message."; - } - case IMPLAUSIBLE: - if (plausibleOrWrongType(isPlausiblyPackedFlat(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --packed format. It looks like it may\n" - "be in --packed --flat format. I'll try to parse it in --packed format as you\n" - "requested, but if it doesn't work, try adding --flat. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyBinary(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --packed format. It looks like it\n" - "may be in regular binary format. I'll try to parse it in --packed format as\n" - "you requested, but if it doesn't work, try removing --packed. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyFlat(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in --packed format, nor does it look\n" - "like it's in regular binary format. It looks like it could be in --flat\n" - "format, although that is unusual so I could be wrong. I'll try to parse\n" - "it in --flat format as you requested, but if it doesn't work, you might\n" - "want to try --flat, or the data may not be Cap'n Proto at all. Use\n" - "--quiet to suppress this warning.\n" - "*** END WARNING ***\n"); - } else { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be a Cap'n Proto message in any known\n" - "binary format. I'll try to parse it anyway, but if it doesn't work, please\n" - "check your input. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); + case BODY: + switch (c) { + case '\"': + // String literal. Let's stop here before things get complicated. + return PLAUSIBLE; + default: + break; } break; - case WRONG_TYPE: + } + + if ((c > 0 && c < ' ' && c != '\n' && c != '\r' && c != '\t' && c != '\v') || c == 0x7f) { + // Unprintable character. + return IMPOSSIBLE; + } + } + + return PLAUSIBLE; + } + + Plausibility isPlausibly(Format format, kj::ArrayPtr prefix) { + switch (format) { + case Format::BINARY : return isPlausiblyBinary (prefix); + case Format::PACKED : return isPlausiblyPacked (prefix); + case Format::FLAT : return isPlausiblyFlat (prefix); + case Format::FLAT_PACKED: return isPlausiblyPackedFlat(prefix); + case Format::CANONICAL : return isPlausiblyFlat (prefix); + case Format::TEXT : return isPlausiblyText (prefix); + case Format::JSON : return isPlausiblyJson (prefix); + } + KJ_UNREACHABLE; + } + + kj::Maybe guessFormat(kj::ArrayPtr prefix) { + Format candidates[] = { + Format::BINARY, + Format::TEXT, + Format::PACKED, + Format::JSON, + Format::FLAT, + Format::FLAT_PACKED + }; + + for (Format candidate: candidates) { + if (plausibleOrWrongType(isPlausibly(candidate, prefix))) { + return candidate; + } + } + + return nullptr; + } + + kj::MainBuilder::Validity checkPlausibility(Format format, kj::ArrayPtr prefix) { + switch (isPlausibly(format, prefix)) { + case PLAUSIBLE: + return true; + + case IMPOSSIBLE: + KJ_IF_MAYBE(guess, guessFormat(prefix)) { + return kj::str( + "The input is not in \"", toString(format), "\" format. It looks like it is in \"", + toString(*guess), "\" format. Try that instead."); + } else { + return kj::str( + "The input is not in \"", toString(format), "\" format."); + } + + case IMPLAUSIBLE: + KJ_IF_MAYBE(guess, guessFormat(prefix)) { + context.warning(kj::str( + "*** WARNING ***\n" + "The input data does not appear to be in \"", toString(format), "\" format. It\n" + "looks like it may be in \"", toString(*guess), "\" format. I'll try to parse\n" + "it in \"", toString(format), "\" format as you requested, but if it doesn't work,\n" + "try \"", toString(format), "\" instead. Use --quiet to suppress this warning.\n" + "*** END WARNING ***\n")); + } else { + context.warning(kj::str( + "*** WARNING ***\n" + "The input data does not appear to be in \"", toString(format), "\" format, nor\n" + "in any other known format. I'll try to parse it in \"", toString(format), "\"\n" + "format anyway, as you requested. Use --quiet to suppress this warning.\n" + "*** END WARNING ***\n")); + } + return true; + + case WRONG_TYPE: + if (format == Format::FLAT && plausibleOrWrongType(isPlausiblyBinary(prefix))) { context.warning( "*** WARNING ***\n" - "The input data does not appear to be the type that you specified. I'll try\n" + "The input data does not appear to match the schema that you specified. I'll try\n" "to parse it anyway, but if it doesn't look right, please verify that you\n" - "have the right type. Use --quiet to suppress this warning.\n" + "have the right schema. This could also be because the input is not in \"flat\"\n" + "format; indeed, it looks like this input may be in regular binary format,\n" + "so you might want to try \"binary\" instead. Use --quiet to suppress this\n" + "warning.\n" "*** END WARNING ***\n"); - break; - } - } else { - switch (isPlausiblyBinary(prefix)) { - case PLAUSIBLE: - break; - case IMPOSSIBLE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - return "The input is not in regular binary format. It looks like it is in --packed " - "format. Try adding the --packed flag."; - } else if (plausibleOrWrongType(isPlausiblyFlat(prefix))) { - return "The input is not in regular binary format, nor does it look like it is in " - "--packed format. It looks like it could be in --flat format, although that " - "is unusual so I could be wrong."; - } else if (plausibleOrWrongType(isPlausiblyPackedFlat(prefix))) { - return "The input is not in regular binary format, nor does it look like it is in " - "--packed format. It looks like it could be in --packed --flat format, " - "although that is unusual so I could be wrong."; - } else { - return "The input is not a Cap'n Proto message."; - } - case IMPLAUSIBLE: - if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in regular binary format. It looks like\n" - "it may be in --packed format. I'll try to parse it in regular format as you\n" - "requested, but if it doesn't work, try adding --packed. Use --quiet to\n" - "suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyPacked(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in regular binary format. It looks like\n" - "it may be in --packed --flat format. I'll try to parse it in regular format as\n" - "you requested, but if it doesn't work, try adding --packed --flat. Use --quiet\n" - "to suppress this warning.\n" - "*** END WARNING ***\n"); - } else if (plausibleOrWrongType(isPlausiblyFlat(prefix))) { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be in regular binary format, nor does it\n" - "look like it's in --packed format. It looks like it could be in --flat\n" - "format, although that is unusual so I could be wrong. I'll try to parse\n" - "it in regular format as you requested, but if it doesn't work, you might\n" - "want to try --flat, or the data may not be Cap'n Proto at all. Use\n" - "--quiet to suppress this warning.\n" - "*** END WARNING ***\n"); - } else { - context.warning( - "*** WARNING ***\n" - "The input data does not appear to be a Cap'n Proto message in any known\n" - "binary format. I'll try to parse it anyway, but if it doesn't work, please\n" - "check your input. Use --quiet to suppress this warning.\n" - "*** END WARNING ***\n"); - } - break; - case WRONG_TYPE: + } else if (format == Format::FLAT_PACKED && + plausibleOrWrongType(isPlausiblyPacked(prefix))) { + context.warning( + "*** WARNING ***\n" + "The input data does not appear to match the schema that you specified. I'll try\n" + "to parse it anyway, but if it doesn't look right, please verify that you\n" + "have the right schema. This could also be because the input is not in \"flat-packed\"\n" + "format; indeed, it looks like this input may be in regular packed format,\n" + "so you might want to try \"packed\" instead. Use --quiet to suppress this\n" + "warning.\n" + "*** END WARNING ***\n"); + } else { context.warning( "*** WARNING ***\n" "The input data does not appear to be the type that you specified. I'll try\n" "to parse it anyway, but if it doesn't look right, please verify that you\n" "have the right type. Use --quiet to suppress this warning.\n" "*** END WARNING ***\n"); - break; - } + } + return true; } - return true; + KJ_UNREACHABLE; } public: // ----------------------------------------------------------------- kj::MainBuilder::Validity encode() { - kj::Vector allText; - - { - kj::FdInputStream rawInput(STDIN_FILENO); - kj::BufferedInputStreamWrapper input(rawInput); - - for (;;) { - auto buf = input.tryGetReadBuffer(); - if (buf.size() == 0) break; - allText.addAll(buf.asChars()); - input.skip(buf.size()); - } - } - - EncoderErrorReporter errorReporter(*this, allText); - MallocMessageBuilder arena; - - // Lex the input. - auto lexedTokens = arena.initRoot(); - lex(allText, lexedTokens, errorReporter); - - // Set up the parser. - CapnpParser parser(arena.getOrphanage(), errorReporter); - auto tokens = lexedTokens.asReader().getTokens(); - CapnpParser::ParserInput parserInput(tokens.begin(), tokens.end()); - - // Set up stuff for the ValueTranslator. - ValueResolverGlue resolver(compiler->getLoader(), errorReporter); - - // Set up output stream. - kj::FdOutputStream rawOutput(STDOUT_FILENO); - kj::BufferedOutputStreamWrapper output(rawOutput); - - while (parserInput.getPosition() != tokens.end()) { - KJ_IF_MAYBE(expression, parser.getParsers().expression(parserInput)) { - MallocMessageBuilder item( - segmentSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : segmentSize, - segmentSize == 0 ? SUGGESTED_ALLOCATION_STRATEGY : AllocationStrategy::FIXED_SIZE); - ValueTranslator translator(resolver, errorReporter, item.getOrphanage()); + convertFrom = Format::TEXT; + convertTo = formatFromDeprecatedFlags(Format::BINARY); + return convert(); + } - KJ_IF_MAYBE(value, translator.compileValue(expression->getReader(), rootType)) { - if (segmentSize == 0) { - writeFlat(value->getReader().as(), output); - } else { - item.adoptRoot(value->releaseAs()); - if (packed) { - writePackedMessage(output, item); - } else { - writeMessage(output, item); - } - } - } else { - // Errors were reported, so we'll exit with a failure status later. - } - } else { - auto best = parserInput.getBest(); - if (best == tokens.end()) { - context.exitError("Premature EOF."); - } else { - errorReporter.addErrorOn(*best, "Parse error."); - context.exit(); - } - } + kj::MainBuilder::Validity setEvalOutputFormat(kj::StringPtr format) { + KJ_IF_MAYBE(f, parseFormatName(format)) { + convertTo = *f; + return true; + } else { + return kj::str("unknown format: ", format); } - - output.flush(); - context.exit(); - KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT; } kj::MainBuilder::Validity evalConst(kj::StringPtr name) { + convertTo = formatFromDeprecatedFlags(convertTo); + KJ_ASSERT(sourceFiles.size() == 1); auto parser = kj::parse::sequence( @@ -1406,15 +1775,15 @@ public: } // OK, we have a value. Print it. - if (binary || packed || flat) { + if (convertTo != Format::TEXT) { if (value.getType() != DynamicValue::STRUCT) { return "not a struct; binary output is only available on structs"; } - kj::FdOutputStream rawOutput(STDOUT_FILENO); - kj::BufferedOutputStreamWrapper output(rawOutput); - writeFlat(value.as(), output); - output.flush(); + auto structValue = value.as(); + rootType = structValue.getSchema(); + kj::FdOutputStream output(STDOUT_FILENO); + writeConversion(structValue, output); context.exit(); } else { if (pretty && value.getType() == DynamicValue::STRUCT) { @@ -1429,79 +1798,14 @@ public: KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT; } -private: - void writeFlat(DynamicStruct::Reader value, kj::BufferedOutputStream& output) { - // Always copy the message to a flat array so that the output is predictable (one segment, - // in canonical order). - size_t size = value.totalSize().wordCount + 1; - kj::Array space = kj::heapArray(size); - memset(space.begin(), 0, size * sizeof(word)); - FlatMessageBuilder flatMessage(space); - flatMessage.setRoot(value); - flatMessage.requireFilled(); - - if (flat && packed) { - capnp::_::PackedOutputStream packer(output); - packer.write(space.begin(), space.size() * sizeof(word)); - } else if (flat) { - output.write(space.begin(), space.size() * sizeof(word)); - } else if (packed) { - writePackedMessage(output, flatMessage); - } else { - writeMessage(output, flatMessage); - } - } - - class EncoderErrorReporter final: public ErrorReporter { - public: - EncoderErrorReporter(GlobalErrorReporter& globalReporter, - kj::ArrayPtr content) - : globalReporter(globalReporter), lineBreaks(content) {} - - void addError(uint32_t startByte, uint32_t endByte, kj::StringPtr message) override { - globalReporter.addError("", lineBreaks.toSourcePos(startByte), - lineBreaks.toSourcePos(endByte), message); - } - - bool hadErrors() override { - return globalReporter.hadErrors(); - } - - private: - GlobalErrorReporter& globalReporter; - LineBreakTable lineBreaks; - }; - - class ValueResolverGlue final: public ValueTranslator::Resolver { - public: - ValueResolverGlue(const SchemaLoader& loader, ErrorReporter& errorReporter) - : loader(loader), errorReporter(errorReporter) {} - - kj::Maybe resolveType(uint64_t id) { - // Don't use tryGet() here because we shouldn't even be here if there were compile errors. - return loader.get(id); - } - - kj::Maybe resolveConstant(Expression::Reader name) override { - errorReporter.addErrorOn(name, kj::str("External constants not allowed in encode input.")); - return nullptr; - } - - kj::Maybe> readEmbed(LocatedText::Reader filename) override { - errorReporter.addErrorOn(filename, kj::str("External embeds not allowed in encode input.")); - return nullptr; - } - - private: - const SchemaLoader& loader; - ErrorReporter& errorReporter; - }; - public: // ===================================================================================== - void addError(kj::StringPtr file, SourcePos start, SourcePos end, + void addError(const kj::ReadableDirectory& directory, kj::PathPtr path, + SourcePos start, SourcePos end, kj::StringPtr message) override { + auto file = getDisplayName(directory, path); + kj::String wholeMessage; if (end.line == start.line) { if (end.column == start.column) { @@ -1526,6 +1830,7 @@ public: private: kj::ProcessContext& context; + kj::Own disk; ModuleLoader loader; kj::SpaceFor compilerSpace; bool compilerConstructed = false; @@ -1539,9 +1844,29 @@ private: // of those schemas, plus the parent nodes of any dependencies. This is what most code generators // require to function. - kj::Vector sourcePrefixes; + struct SourceDirectory { + kj::Own dir; + bool isSourcePrefix; + }; + + kj::HashMap sourceDirectories; + // For each import path and source prefix, tracks the directory object we opened for it. + // + // Use via getSourceDirectory(). + + kj::HashMap dirPrefixes; + // For each open directory object, maps to a path prefix to add when displaying this path in + // error messages. This keeps track of the original directory name as given by the user, before + // canonicalization. + // + // Use via getDisplayName(). + bool addStandardImportPaths = true; + Format convertFrom = Format::BINARY; + Format convertTo = Format::BINARY; + // For the "convert" command. + bool binary = false; bool flat = false; bool packed = false; @@ -1553,6 +1878,7 @@ private: struct SourceFile { uint64_t id; + Compiler::ModuleScope compiled; kj::StringPtr name; Module* module; }; @@ -1561,11 +1887,117 @@ private: struct OutputDirective { kj::ArrayPtr name; - kj::StringPtr dir; + kj::Maybe dir; + + KJ_DISALLOW_COPY(OutputDirective); + OutputDirective(OutputDirective&&) = default; + OutputDirective(kj::ArrayPtr name, kj::Maybe dir) + : name(name), dir(kj::mv(dir)) {} }; kj::Vector outputs; bool hadErrors_ = false; + + kj::Maybe getSourceDirectory( + kj::StringPtr pathStr, bool isSourcePrefix) { + auto cwd = disk->getCurrentPath(); + auto path = cwd.evalNative(pathStr); + + if (path.size() == 0) return disk->getRoot(); + + KJ_IF_MAYBE(sdir, sourceDirectories.find(path)) { + sdir->isSourcePrefix = sdir->isSourcePrefix || isSourcePrefix; + return *sdir->dir; + } + + if (path == cwd) { + // Slight hack if the working directory is explicitly specified: + // - We want to avoid opening a new copy of the working directory, as tryOpenSubdir() would + // do. + // - If isSourcePrefix is true, we need to add it to sourceDirectories to track that. + // Otherwise we don't need to add it at all. + // - We do not need to add it to dirPrefixes since the cwd is already handled in + // getDisplayName(). + auto& result = disk->getCurrent(); + if (isSourcePrefix) { + kj::Own fakeOwn(&result, kj::NullDisposer::instance); + sourceDirectories.insert(kj::mv(path), { kj::mv(fakeOwn), isSourcePrefix }); + } + return result; + } + + KJ_IF_MAYBE(dir, disk->getRoot().tryOpenSubdir(path)) { + auto& result = *dir->get(); + sourceDirectories.insert(kj::mv(path), { kj::mv(*dir), isSourcePrefix }); +#if _WIN32 + kj::String prefix = pathStr.endsWith("/") || pathStr.endsWith("\\") + ? kj::str(pathStr) : kj::str(pathStr, '\\'); +#else + kj::String prefix = pathStr.endsWith("/") ? kj::str(pathStr) : kj::str(pathStr, '/'); +#endif + dirPrefixes.insert(&result, kj::mv(prefix)); + return result; + } else { + return nullptr; + } + } + + struct DirPathPair { + const kj::ReadableDirectory& dir; + kj::Path path; + }; + + DirPathPair interpretSourceFile(kj::StringPtr pathStr) { + auto cwd = disk->getCurrentPath(); + auto path = cwd.evalNative(pathStr); + + KJ_REQUIRE(path.size() > 0); + for (size_t i = path.size() - 1; i > 0; i--) { + auto prefix = path.slice(0, i); + auto remainder = path.slice(i, path.size()); + + KJ_IF_MAYBE(sdir, sourceDirectories.find(prefix)) { + if (sdir->isSourcePrefix) { + return { *sdir->dir, remainder.clone() }; + } + } + } + + // No source prefix matched. Fall back to heuristic: try stripping the current directory, + // otherwise don't strip anything. + if (path.startsWith(cwd)) { + return { disk->getCurrent(), path.slice(cwd.size(), path.size()).clone() }; + } else { + // Hmm, no src-prefix matched and the file isn't even in the current directory. This might + // be OK if we aren't generating any output anyway, but otherwise the results will almost + // certainly not be what the user wanted. Let's print a warning, unless the output directives + // are ones which we know do not produce output files. This is a hack. + for (auto& output: outputs) { + auto name = kj::str(output.name); + if (name != "-" && name != "capnp") { + context.warning(kj::str(pathStr, + ": File is not in the current directory and does not match any prefix defined with " + "--src-prefix. Please pass an appropriate --src-prefix so I can figure out where to " + "write the output for this file.")); + break; + } + } + + return { disk->getRoot(), kj::mv(path) }; + } + } + + kj::String getDisplayName(const kj::ReadableDirectory& dir, kj::PathPtr path) { + KJ_IF_MAYBE(prefix, dirPrefixes.find(&dir)) { + return kj::str(*prefix, path.toNativeString()); + } else if (&dir == &disk->getRoot()) { + return path.toNativeString(true); + } else if (&dir == &disk->getCurrent()) { + return path.toNativeString(false); + } else { + KJ_FAIL_ASSERT("unrecognized directory"); + } + } }; } // namespace compiler diff --git a/c++/src/capnp/compiler/capnpc-c++.c++ b/c++/src/capnp/compiler/capnpc-c++.c++ index a925d4459f..a60d4770e3 100644 --- a/c++/src/capnp/compiler/capnpc-c++.c++ +++ b/c++/src/capnp/compiler/capnpc-c++.c++ @@ -21,6 +21,10 @@ // This program is a code generator plugin for `capnp compile` which generates C++ code. +#if _WIN32 +#include +#endif + #include #include "../serialize.h" #include @@ -28,19 +32,24 @@ #include #include #include +#include #include "../schema-loader.h" #include "../dynamic.h" -#include #include #include #include #include #include #include -#include -#include -#include -#include +#include + +#if _WIN32 +#include +#include +#undef CONST +#else +#include +#endif #if HAVE_CONFIG_H #include "config.h" @@ -55,6 +64,7 @@ namespace { static constexpr uint64_t NAMESPACE_ANNOTATION_ID = 0xb9c6f99ebf805f2cull; static constexpr uint64_t NAME_ANNOTATION_ID = 0xf264a779fef191ceull; +static constexpr uint64_t ALLOW_CANCELLATION_ANNOTATION_ID = 0xac7096ff8cfc9dceull; bool hasDiscriminantValue(const schema::Field::Reader& reader) { return reader.getDiscriminantValue() != schema::Field::NO_DISCRIMINANT; @@ -295,6 +305,50 @@ kj::String KJ_STRINGIFY(const CppTypeName& typeName) { } } +CppTypeName whichKind(Type type) { + // Make a CppTypeName representing the capnp::Kind value for the given schema type. This makes + // CppTypeName conflate types and values, but this is all just a hack for MSVC's benefit. Its + // primary use is as a non-type template parameter to `capnp::List` -- normally the Kind K + // is deduced via SFINAE, but MSVC just can't do it in certain cases, such as when a nested type + // of `capnp::List` is the return type of a function, and the element type T is a template + // instantiation. + + switch (type.which()) { + case schema::Type::VOID: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + + case schema::Type::BOOL: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::INT8: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::INT16: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::INT32: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::INT64: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::UINT8: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::UINT16: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::UINT32: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::UINT64: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::FLOAT32: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + case schema::Type::FLOAT64: return CppTypeName::makePrimitive(" ::capnp::Kind::PRIMITIVE"); + + case schema::Type::TEXT: return CppTypeName::makePrimitive(" ::capnp::Kind::BLOB"); + case schema::Type::DATA: return CppTypeName::makePrimitive(" ::capnp::Kind::BLOB"); + + case schema::Type::ENUM: return CppTypeName::makePrimitive(" ::capnp::Kind::ENUM"); + case schema::Type::STRUCT: return CppTypeName::makePrimitive(" ::capnp::Kind::STRUCT"); + case schema::Type::INTERFACE: return CppTypeName::makePrimitive(" ::capnp::Kind::INTERFACE"); + + case schema::Type::LIST: return CppTypeName::makePrimitive(" ::capnp::Kind::LIST"); + case schema::Type::ANY_POINTER: { + switch (type.whichAnyPointerKind()) { + case schema::Type::AnyPointer::Unconstrained::CAPABILITY: + return CppTypeName::makePrimitive(" ::capnp::Kind::INTERFACE"); + default: + return CppTypeName::makePrimitive(" ::capnp::Kind::OTHER"); + } + } + } + + KJ_UNREACHABLE; +} + // ======================================================================================= class CapnpcCppMain { @@ -377,7 +431,7 @@ private: #if 0 // Figure out exactly how many params are not bound to AnyPointer. - // TODO(msvc): In a few obscure cases, MSVC does not like empty template pramater lists, + // TODO(msvc): In a few obscure cases, MSVC does not like empty template parameter lists, // even if all parameters have defaults. So, we give in and explicitly list all // parameters in our generated code for now. Try again later. uint paramCount = 0; @@ -467,8 +521,10 @@ private: case schema::Type::LIST: { CppTypeName result = CppTypeName::makeNamespace("capnp"); - auto params = kj::heapArrayBuilder(1); - params.add(typeName(type.asList().getElementType(), method)); + auto params = kj::heapArrayBuilder(2); + auto list = type.asList(); + params.add(typeName(list.getElementType(), method)); + params.add(whichKind(list.getElementType())); result.addMemberTemplate("List", params.finish()); return result; } @@ -493,9 +549,12 @@ private: return CppTypeName::makePrimitive(" ::capnp::AnyStruct"); case schema::Type::AnyPointer::Unconstrained::LIST: return CppTypeName::makePrimitive(" ::capnp::AnyList"); - case schema::Type::AnyPointer::Unconstrained::CAPABILITY: - hasInterfaces = true; // Probably need to #inculde . - return CppTypeName::makePrimitive(" ::capnp::Capability"); + case schema::Type::AnyPointer::Unconstrained::CAPABILITY: { + hasInterfaces = true; // Probably need to #include . + auto result = CppTypeName::makePrimitive(" ::capnp::Capability"); + result.setHasInterfaces(); + return result; + } } KJ_UNREACHABLE; } @@ -682,12 +741,12 @@ private: kj::StringTree dependencies; size_t dependencyCount; // TODO(msvc): `dependencyCount` is the number of individual dependency definitions in - // `dependencies`. It's a hack to allow makeGenericDefinitions to hard-code the size of the - // `_capnpPrivate::brandDependencies` array into the definition of - // `_capnpPrivate::specificBrand::dependencyCount`. This is necessary because MSVC cannot deduce - // the size of `brandDependencies` if it is nested under a class template. It's probably this - // demoralizingly deferred bug: - // https://connect.microsoft.com/VisualStudio/feedback/details/759407/can-not-get-size-of-static-array-defined-in-class-template + // `dependencies`. It's a hack to allow makeGenericDefinitions to hard-code the size of the + // `_capnpPrivate::brandDependencies` array into the definition of + // `_capnpPrivate::specificBrand::dependencyCount`. This is necessary because MSVC cannot + // deduce the size of `brandDependencies` if it is nested under a class template. It's + // probably this demoralizingly deferred bug: + // https://connect.microsoft.com/VisualStudio/feedback/details/759407/can-not-get-size-of-static-array-defined-in-class-template }; BrandInitializerText makeBrandInitializers( @@ -784,7 +843,14 @@ private: } kj::Maybe makeBrandDepInitializer(Schema type) { - return makeBrandDepInitializer(type, cppFullName(type, nullptr)); + // Be careful not to invoke cppFullName() if it would just be thrown away, as doing so will + // add the type's declaring file to `usedImports`. In particular, this causes `stream.capnp.h` + // to be #included unnecessarily. + if (type.isBranded()) { + return makeBrandDepInitializer(type, cppFullName(type, nullptr)); + } else { + return nullptr; + } } kj::Maybe makeBrandDepInitializer( @@ -1732,7 +1798,7 @@ private: " ::capnp::bounded<", offset, ">() * ::capnp::POINTERS), kj::mv(value));\n" "}\n", COND(type.hasDisambiguatedTemplate(), - "#ifndef _MSC_VER\n" + "#if !defined(_MSC_VER) || defined(__clang__)\n" "// Excluded under MSVC because bugs may make it unable to compile this method.\n"), templateContext.allDecls(), "inline ::capnp::Orphan<", type, "> ", scope, "Builder::disown", titleCase, "() {\n", @@ -1740,7 +1806,7 @@ private: " return ::capnp::_::PointerHelpers<", type, ">::disown(_builder.getPointerField(\n" " ::capnp::bounded<", offset, ">() * ::capnp::POINTERS));\n" "}\n", - COND(type.hasDisambiguatedTemplate(), "#endif // !_MSC_VER\n"), + COND(type.hasDisambiguatedTemplate(), "#endif // !_MSC_VER || __clang__\n"), COND(shouldExcludeInLiteMode, "#endif // !CAPNP_LITE\n"), "\n") }; @@ -1980,11 +2046,15 @@ private: kj::StringTree defineText = kj::strTree( "// ", fullName, "\n", + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", templates, "constexpr uint16_t ", fullName, "::_capnpPrivate::dataWordSize;\n", - templates, "constexpr uint16_t ", fullName, "::_capnpPrivate::pointerCount;\n" + templates, "constexpr uint16_t ", fullName, "::_capnpPrivate::pointerCount;\n", + "#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", "#if !CAPNP_LITE\n", + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", templates, "constexpr ::capnp::Kind ", fullName, "::_capnpPrivate::kind;\n", - templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n"); + templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n", + "#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n"); if (templateContext.isGeneric()) { auto brandInitializers = makeBrandInitializers(templateContext, schema); @@ -2094,6 +2164,8 @@ private: auto paramProto = paramSchema.getProto(); auto resultProto = resultSchema.getProto(); + bool isStreaming = method.isStreaming(); + auto implicitParamsReader = proto.getImplicitParameters(); auto implicitParamsBuilder = kj::heapArrayBuilder(implicitParamsReader.size()); for (auto param: implicitParamsReader) { @@ -2130,7 +2202,10 @@ private: } CppTypeName resultType; CppTypeName genericResultType; - if (resultProto.getScopeId() == 0) { + if (isStreaming) { + // We don't use resultType or genericResultType in this case. We want to avoid computing them + // at all so that we don't end up marking stream.capnp.h in usedImports. + } else if (resultProto.getScopeId() == 0) { resultType = interfaceTypeName; if (implicitParams.size() == 0) { resultType.addMemberType(kj::str(titleCase, "Results")); @@ -2148,7 +2223,7 @@ private: kj::String shortParamType = paramProto.getScopeId() == 0 ? kj::str(titleCase, "Params") : kj::str(genericParamType); - kj::String shortResultType = resultProto.getScopeId() == 0 ? + kj::String shortResultType = resultProto.getScopeId() == 0 || isStreaming ? kj::str(titleCase, "Results") : kj::str(genericResultType); auto interfaceProto = method.getContainingInterface().getProto(); @@ -2166,22 +2241,43 @@ private: // the `CAPNP_AUTO_IF_MSVC()` hackery in the return type declarations below. We're depending on // the fact that that this function has an inline implementation for the deduction to work. + bool noPromisePipelining = !resultSchema.mayContainCapabilities(); + auto requestMethodImpl = kj::strTree( templateContext.allDecls(), implicitParamsTemplateDecl, templateContext.isGeneric() ? "CAPNP_AUTO_IF_MSVC(" : "", - "::capnp::Request<", paramType, ", ", resultType, ">", + isStreaming ? kj::strTree("::capnp::StreamingRequest<", paramType, ">") + : kj::strTree("::capnp::Request<", paramType, ", ", resultType, ">"), templateContext.isGeneric() ? ")\n" : "\n", - interfaceName, "::Client::", name, "Request(::kj::Maybe< ::capnp::MessageSize> sizeHint) {\n" - " return newCall<", paramType, ", ", resultType, ">(\n" - " 0x", interfaceIdHex, "ull, ", methodId, ", sizeHint);\n" + interfaceName, "::Client::", name, "Request(::kj::Maybe< ::capnp::MessageSize> sizeHint) {\n", + isStreaming + ? kj::strTree(" return newStreamingCall<", paramType, ">(\n") + : kj::strTree(" return newCall<", paramType, ", ", resultType, ">(\n"), + " 0x", interfaceIdHex, "ull, ", methodId, ", sizeHint, {", noPromisePipelining, "});\n" "}\n"); + bool allowCancellation = false; + if (annotationValue(proto, ALLOW_CANCELLATION_ANNOTATION_ID) != nullptr) { + allowCancellation = true; + } else if (annotationValue(interfaceProto, ALLOW_CANCELLATION_ANNOTATION_ID) != nullptr) { + allowCancellation = true; + } else { + schema::Node::Reader node = interfaceProto; + while (!node.isFile()) { + node = schemaLoader.get(node.getScopeId()).getProto(); + } + if (annotationValue(node, ALLOW_CANCELLATION_ANNOTATION_ID) != nullptr) { + allowCancellation = true; + } + } + return MethodText { kj::strTree( implicitParamsTemplateDecl.size() == 0 ? "" : " ", implicitParamsTemplateDecl, templateContext.isGeneric() ? " CAPNP_AUTO_IF_MSVC(" : " ", - "::capnp::Request<", paramType, ", ", resultType, ">", + isStreaming ? kj::strTree("::capnp::StreamingRequest<", paramType, ">") + : kj::strTree("::capnp::Request<", paramType, ", ", resultType, ">"), templateContext.isGeneric() ? ")" : "", " ", name, "Request(\n" " ::kj::Maybe< ::capnp::MessageSize> sizeHint = nullptr);\n"), @@ -2191,8 +2287,11 @@ private: " typedef ", genericParamType, " ", titleCase, "Params;\n"), resultProto.getScopeId() != 0 ? kj::strTree() : kj::strTree( " typedef ", genericResultType, " ", titleCase, "Results;\n"), - " typedef ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> ", - titleCase, "Context;\n" + isStreaming + ? kj::strTree(" typedef ::capnp::StreamingCallContext<", shortParamType, "> ") + : kj::strTree( + " typedef ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> "), + titleCase, "Context;\n" " virtual ::kj::Promise ", identifierName, "(", titleCase, "Context context);\n"), implicitParams.size() == 0 ? kj::strTree() : kj::mv(requestMethodImpl), @@ -2207,9 +2306,31 @@ private: "}\n"), kj::strTree( - " case ", methodId, ":\n" - " return ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n" - " ", genericParamType, ", ", genericResultType, ">(context));\n") + " case ", methodId, ":\n", + isStreaming + ? kj::strTree( + // For streaming calls, we need to add an evalNow() here so that exceptions thrown + // directly from the call can propagate to later calls. If we don't capture the + // exception properly then the caller will never find out that this is a streaming + // call (indicated by the boolean in the return value) so won't know to propagate + // the exception. + " return {\n" + " kj::evalNow([&]() {\n" + " return ", identifierName, "(::capnp::Capability::Server::internalGetTypedStreamingContext<\n" + " ", genericParamType, ">(context));\n" + " }),\n" + " true,\n" + " ", allowCancellation, "\n" + " };\n") + : kj::strTree( + // For non-streaming calls we let exceptions just flow through for a little more + // efficiency. + " return {\n" + " ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n" + " ", genericParamType, ", ", genericResultType, ">(context)),\n" + " false,\n" + " ", allowCancellation, "\n" + " };\n")) }; } @@ -2275,8 +2396,10 @@ private: kj::StringTree defineText = kj::strTree( "// ", fullName, "\n", "#if !CAPNP_LITE\n", + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n", templates, "constexpr ::capnp::Kind ", fullName, "::_capnpPrivate::kind;\n", - templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n"); + templates, "constexpr ::capnp::_::RawSchema const* ", fullName, "::_capnpPrivate::schema;\n" + "#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n"); if (templateContext.isGeneric()) { auto brandInitializers = makeBrandInitializers(templateContext, schema); @@ -2358,7 +2481,8 @@ private: "public:\n", " typedef ", name, " Serves;\n" "\n" - " ::kj::Promise dispatchCall(uint64_t interfaceId, uint16_t methodId,\n" + " ::capnp::Capability::Server::DispatchCallResult dispatchCall(\n" + " uint64_t interfaceId, uint16_t methodId,\n" " ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context)\n" " override;\n" "\n" @@ -2370,7 +2494,8 @@ private: " .template castAs<", typeName, ">();\n" " }\n" "\n" - " ::kj::Promise dispatchCallInternal(uint16_t methodId,\n" + " ::capnp::Capability::Server::DispatchCallResult dispatchCallInternal(\n" + " uint16_t methodId,\n" " ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context);\n" "};\n" "#endif // !CAPNP_LITE\n" @@ -2414,7 +2539,7 @@ private: "#if !CAPNP_LITE\n", KJ_MAP(m, methods) { return kj::mv(m.sourceDefs); }, templateContext.allDecls(), - "::kj::Promise ", fullName, "::Server::dispatchCall(\n" + "::capnp::Capability::Server::DispatchCallResult ", fullName, "::Server::dispatchCall(\n" " uint64_t interfaceId, uint16_t methodId,\n" " ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) {\n" " switch (interfaceId) {\n" @@ -2431,7 +2556,7 @@ private: " }\n" "}\n", templateContext.allDecls(), - "::kj::Promise ", fullName, "::Server::dispatchCallInternal(\n" + "::capnp::Capability::Server::DispatchCallResult ", fullName, "::Server::dispatchCallInternal(\n" " uint16_t methodId,\n" " ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) {\n" " switch (methodId) {\n", @@ -2469,31 +2594,29 @@ private: const char* linkage = scope.size() == 0 ? "extern " : "static "; switch (type.which()) { - case schema::Value::BOOL: - case schema::Value::INT8: - case schema::Value::INT16: - case schema::Value::INT32: - case schema::Value::INT64: - case schema::Value::UINT8: - case schema::Value::UINT16: - case schema::Value::UINT32: - case schema::Value::UINT64: - case schema::Value::ENUM: + case schema::Type::BOOL: + case schema::Type::INT8: + case schema::Type::INT16: + case schema::Type::INT32: + case schema::Type::INT64: + case schema::Type::UINT8: + case schema::Type::UINT16: + case schema::Type::UINT32: + case schema::Type::UINT64: + case schema::Type::ENUM: return ConstText { false, kj::strTree("static constexpr ", typeName_, ' ', upperCase, " = ", literalValue(schema.getType(), constProto.getValue()), ";\n"), scope.size() == 0 ? kj::strTree() : kj::strTree( - // TODO(msvc): MSVC doesn't like definitions of constexprs, but other compilers and - // the standard require them. - "#ifndef _MSC_VER\n" + "#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL\n" "constexpr ", typeName_, ' ', scope, upperCase, ";\n" "#endif\n") }; - case schema::Value::VOID: - case schema::Value::FLOAT32: - case schema::Value::FLOAT64: { + case schema::Type::VOID: + case schema::Type::FLOAT32: + case schema::Type::FLOAT64: { // TODO(msvc): MSVC doesn't like float- or class-typed constexprs. As soon as this is fixed, // treat VOID, FLOAT32, and FLOAT64 the same as the other primitives. kj::String value = literalValue(schema.getType(), constProto.getValue()).flatten(); @@ -2507,7 +2630,7 @@ private: }; } - case schema::Value::TEXT: { + case schema::Type::TEXT: { kj::String constType = kj::strTree( "::capnp::_::ConstText<", schema.as().size(), ">").flatten(); return ConstText { @@ -2518,7 +2641,7 @@ private: }; } - case schema::Value::DATA: { + case schema::Type::DATA: { kj::String constType = kj::strTree( "::capnp::_::ConstData<", schema.as().size(), ">").flatten(); return ConstText { @@ -2529,7 +2652,7 @@ private: }; } - case schema::Value::STRUCT: { + case schema::Type::STRUCT: { kj::String constType = kj::strTree( "::capnp::_::ConstStruct<", typeName_, ">").flatten(); return ConstText { @@ -2540,7 +2663,7 @@ private: }; } - case schema::Value::LIST: { + case schema::Type::LIST: { kj::String constType = kj::strTree( "::capnp::_::ConstList<", typeName(type.asList().getElementType(), nullptr), ">") .flatten(); @@ -2552,8 +2675,8 @@ private: }; } - case schema::Value::ANY_POINTER: - case schema::Value::INTERFACE: + case schema::Type::ANY_POINTER: + case schema::Type::INTERFACE: return ConstText { false, kj::strTree(), kj::strTree() }; } @@ -2672,6 +2795,8 @@ private: auto brandDeps = makeBrandDepInitializers( makeBrandDepMap(templateContext, schema.getGeneric())); + bool mayContainCapabilities = proto.isStruct() && schema.asStruct().mayContainCapabilities(); + auto schemaDef = kj::strTree( "static const ::capnp::_::AlignedData<", rawSchema.size(), "> b_", hexId, " = {\n" " {", kj::mv(schemaLiteral), " }\n" @@ -2704,7 +2829,7 @@ private: ", nullptr, nullptr, { &s_", hexId, ", nullptr, ", brandDeps.size() == 0 ? kj::strTree("nullptr, 0, 0") : kj::strTree( "bd_", hexId, ", 0, " "sizeof(bd_", hexId, ") / sizeof(bd_", hexId, "[0])"), - ", nullptr }\n" + ", nullptr }, ", mayContainCapabilities, "\n" "};\n" "#endif // !CAPNP_LITE\n"); @@ -2933,17 +3058,19 @@ private: "// Generated by Cap'n Proto compiler, DO NOT EDIT\n" "// source: ", baseName(displayName), "\n" "\n" - "#ifndef CAPNP_INCLUDED_", kj::hex(node.getId()), "_\n", - "#define CAPNP_INCLUDED_", kj::hex(node.getId()), "_\n" + "#pragma once\n" "\n" - "#include \n", + "#include \n" + "#include \n", // work-around macro conflict with VOID hasInterfaces ? kj::strTree( "#if !CAPNP_LITE\n" "#include \n" "#endif // !CAPNP_LITE\n" ) : kj::strTree(), "\n" - "#if CAPNP_VERSION != ", CAPNP_VERSION, "\n" + "#ifndef CAPNP_VERSION\n" + "#error \"CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?\"\n" + "#elif CAPNP_VERSION != ", CAPNP_VERSION, "\n" "#error \"Version mismatch between generated code and library headers. You must " "use the same version of the Cap'n Proto compiler and library.\"\n" "#endif\n" @@ -2956,6 +3083,8 @@ private: } }, "\n" + "CAPNP_BEGIN_HEADER\n" + "\n" "namespace capnp {\n" "namespace schemas {\n" "\n", @@ -2971,8 +3100,10 @@ private: KJ_MAP(n, nodeTexts) { return kj::mv(n.readerBuilderDefs); }, separator, "\n", KJ_MAP(n, nodeTexts) { return kj::mv(n.inlineMethodDefs); }, - KJ_MAP(n, namespaceParts) { return kj::strTree("} // namespace\n"); }, "\n", - "#endif // CAPNP_INCLUDED_", kj::hex(node.getId()), "_\n"), + KJ_MAP(n, namespaceParts) { return kj::strTree("} // namespace\n"); }, + "\n" + "CAPNP_END_HEADER\n" + "\n"), kj::strTree( "// Generated by Cap'n Proto compiler, DO NOT EDIT\n" @@ -2995,42 +3126,33 @@ private: // ----------------------------------------------------------------- - void makeDirectory(kj::StringPtr path) { - KJ_IF_MAYBE(slashpos, path.findLast('/')) { - // Make the parent dir. - makeDirectory(kj::str(path.slice(0, *slashpos))); - } - - if (kj::miniposix::mkdir(path.cStr(), 0777) < 0) { - int error = errno; - if (error != EEXIST) { - KJ_FAIL_SYSCALL("mkdir(path)", error, path); - } - } - } + kj::Own fs = kj::newDiskFilesystem(); void writeFile(kj::StringPtr filename, const kj::StringTree& text) { - if (!filename.startsWith("/")) { - KJ_IF_MAYBE(slashpos, filename.findLast('/')) { - // Make the parent dir. - makeDirectory(kj::str(filename.slice(0, *slashpos))); - } - } - - int fd; - KJ_SYSCALL(fd = open(filename.cStr(), O_CREAT | O_WRONLY | O_TRUNC, 0666), filename); - kj::FdOutputStream out((kj::AutoCloseFd(fd))); - - text.visit( - [&](kj::ArrayPtr text) { - out.write(text.begin(), text.size()); - }); + // We don't use replaceFile() here because atomic replacements are actually detrimental for + // build tools: + // - It's the responsibility of the build pipeline to ensure that no one else is concurrently + // reading the file when we write it, so atomicity brings no benefit. + // - Atomic replacements force disk syncs which could slow us down for no benefit at all. + // - Opening the existing file and overwriting it may allow the filesystem to reuse + // already-allocated blocks, or maybe even notice that no actual changes occurred. + // - In a power outage scenario, the user would obviously restart the build from scratch + // anyway. + // + // At one point, in a fit of over-engineering, we used writable mmap() here. That turned out + // to be a bad idea: writable mmap() is not implemented on some filesystems, especially shared + // folders in VirtualBox. Oh well. + + auto path = kj::Path::parse(filename); + auto file = fs->getCurrent().openFile(path, + kj::WriteMode::CREATE | kj::WriteMode::MODIFY | kj::WriteMode::CREATE_PARENT); + file->writeAll(text.flatten()); } kj::MainBuilder::Validity run() { ReaderOptions options; options.traversalLimitInWords = 1 << 30; // Don't limit. - StreamFdMessageReader reader(STDIN_FILENO, options); + StreamFdMessageReader reader(0, options); auto request = reader.getRoot(); auto capnpVersion = request.getCapnpVersion(); @@ -3058,8 +3180,7 @@ private: schemaLoader.load(node); } - kj::FdOutputStream rawOut(STDOUT_FILENO); - kj::BufferedOutputStreamWrapper out(rawOut); + schemaLoader.computeOptimizationHints(); for (auto requestedFile: request.getRequestedFiles()) { auto schema = schemaLoader.get(requestedFile.getId()); diff --git a/c++/src/capnp/compiler/capnpc-capnp.c++ b/c++/src/capnp/compiler/capnpc-capnp.c++ index df4dcd949b..3660c827df 100644 --- a/c++/src/capnp/compiler/capnpc-capnp.c++ +++ b/c++/src/capnp/compiler/capnpc-capnp.c++ @@ -22,6 +22,10 @@ // This program is a code generator plugin for `capnp compile` which writes the schema back to // stdout in roughly capnpc format. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include #include "../serialize.h" #include @@ -35,6 +39,7 @@ #include #include #include +#include #if HAVE_CONFIG_H #include "config.h" @@ -489,7 +494,9 @@ private: kj::StringTree genParamList(InterfaceSchema interface, StructSchema schema, schema::Brand::Reader brand, InterfaceSchema::Method method) { - if (schema.getProto().getScopeId() == 0) { + if (schema.getProto().getId() == typeId()) { + return kj::strTree("stream"); + } else if (schema.getProto().getScopeId() == 0) { // A named parameter list. return kj::strTree("(", kj::StringTree( KJ_MAP(field, schema.getFields()) { diff --git a/c++/src/capnp/compiler/compiler.c++ b/c++/src/capnp/compiler/compiler.c++ index c85725f0d4..03eda9fc3c 100644 --- a/c++/src/capnp/compiler/compiler.c++ +++ b/c++/src/capnp/compiler/compiler.c++ @@ -30,28 +30,29 @@ #include #include #include "node-translator.h" -#include "md5.h" namespace capnp { namespace compiler { +typedef std::unordered_map> SourceInfoMap; + class Compiler::Alias { public: Alias(CompiledModule& module, Node& parent, const Expression::Reader& targetName) : module(module), parent(parent), targetName(targetName) {} - kj::Maybe compile(); + kj::Maybe compile(); private: CompiledModule& module; Node& parent; Expression::Reader targetName; - kj::Maybe target; + kj::Maybe target; Orphan brandOrphan; bool initialized = false; }; -class Compiler::Node final: public NodeTranslator::Resolver { +class Compiler::Node final: public Resolver { // Passes through four states: // - Stub: On initial construction, the Node is just a placeholder object. Its ID has been // determined, and it is placed in its parent's member table as well as the compiler's @@ -81,14 +82,15 @@ public: void loadFinalSchema(const SchemaLoader& loader); void traverse(uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader); + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo); // Get the final schema for this node, and also possibly traverse the node's children and // dependencies to ensure that they are loaded, depending on the mode. void addError(kj::StringPtr error); // Report an error on this Node. - // implements NodeTranslator::Resolver ----------------------------- + // implements Resolver --------------------------------------------- kj::Maybe resolve(kj::StringPtr name) override; kj::Maybe resolveMember(kj::StringPtr name) override; ResolvedDecl resolveBuiltin(Declaration::Which which) override; @@ -177,6 +179,9 @@ private: kj::Array auxSchemas; // Schemas for all auxiliary nodes built by the NodeTranslator. + + kj::Array sourceInfo; + // All source info structs as built by the NodeTranslator. }; Content guardedContent; // Read using getContent() only! @@ -204,19 +209,24 @@ private: void traverseNodeDependencies(const schema::Node::Reader& schemaNode, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader); + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo); void traverseType(const schema::Type::Reader& type, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader); + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo); void traverseBrand(const schema::Brand::Reader& brand, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader); + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo); void traverseAnnotations(const List::Reader& annotations, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader); + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo); void traverseDependency(uint64_t depId, uint eagerness, std::unordered_map& seen, const SchemaLoader& finalLoader, + kj::Vector& sourceInfo, bool ignoreIfNotFound = false); // Helpers for traverse(). }; @@ -253,8 +263,10 @@ public: uint64_t add(Module& module); kj::Maybe lookup(uint64_t parent, kj::StringPtr childName); + kj::Maybe getSourceInfo(uint64_t id); Orphan> getFileImportTable(Module& module, Orphanage orphanage); + Orphan> getAllSourceInfo(Orphanage orphanage); void eagerlyCompile(uint64_t id, uint eagerness, const SchemaLoader& loader); CompiledModule& addInternal(Module& parsedModule); @@ -330,6 +342,10 @@ private: std::unordered_map nodesById; // Map of nodes by ID. + std::unordered_map sourceInfoById; + // Map of SourceInfos by ID, including SourceInfos for groups and param sturcts (which are not + // listed in nodesById). + std::map> builtinDecls; std::map builtinDeclsByKind; // Map of built-in declarations, like "Int32" and "List", which make up the global scope. @@ -340,7 +356,7 @@ private: // ======================================================================================= -kj::Maybe Compiler::Alias::compile() { +kj::Maybe Compiler::Alias::compile() { if (!initialized) { initialized = true; @@ -504,8 +520,7 @@ kj::Maybe Compiler::Node::getContent(Content::State mi } content.advanceState(Content::EXPANDED); - // no break - } + } KJ_FALLTHROUGH; case Content::EXPANDED: { if (minimumState <= Content::EXPANDED) break; @@ -568,20 +583,28 @@ kj::Maybe Compiler::Node::getContent(Content::State mi })); content.advanceState(Content::BOOTSTRAP); - // no break - } + } KJ_FALLTHROUGH; case Content::BOOTSTRAP: { if (minimumState <= Content::BOOTSTRAP) break; // Create the final schema. - auto nodeSet = content.translator->finish(); + NodeTranslator::NodeSet nodeSet; + if (content.bootstrapSchema == nullptr) { + // Must have failed in an earlier stage. + KJ_ASSERT(module->getErrorReporter().hadErrors()); + nodeSet = content.translator->getBootstrapNode(); + } else { + nodeSet = content.translator->finish( + module->getCompiler().getWorkspace().bootstrapLoader.getUnbound(id)); + } + content.finalSchema = nodeSet.node; content.auxSchemas = kj::mv(nodeSet.auxNodes); + content.sourceInfo = kj::mv(nodeSet.sourceInfo); content.advanceState(Content::FINISHED); - // no break - } + } KJ_FALLTHROUGH; case Content::FINISHED: break; @@ -645,7 +668,8 @@ void Compiler::Node::loadFinalSchema(const SchemaLoader& loader) { } void Compiler::Node::traverse(uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader) { + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo) { uint& slot = seen[this]; if ((slot & eagerness) == eagerness) { // We've already covered this node. @@ -662,24 +686,26 @@ void Compiler::Node::traverse(uint eagerness, std::unordered_map& s // them with the bits above DEPENDENCIES shifted over. uint newEagerness = (eagerness & ~(DEPENDENCIES - 1)) | (eagerness / DEPENDENCIES); - traverseNodeDependencies(*schema, newEagerness, seen, finalLoader); + traverseNodeDependencies(*schema, newEagerness, seen, finalLoader, sourceInfo); for (auto& aux: content->auxSchemas) { - traverseNodeDependencies(aux, newEagerness, seen, finalLoader); + traverseNodeDependencies(aux, newEagerness, seen, finalLoader, sourceInfo); } } } + + sourceInfo.addAll(content->sourceInfo); } if (eagerness & PARENTS) { KJ_IF_MAYBE(p, parent) { - p->traverse(eagerness, seen, finalLoader); + p->traverse(eagerness, seen, finalLoader, sourceInfo); } } if (eagerness & CHILDREN) { KJ_IF_MAYBE(content, getContent(Content::EXPANDED)) { for (auto& child: content->orderedNestedNodes) { - child->traverse(eagerness, seen, finalLoader); + child->traverse(eagerness, seen, finalLoader, sourceInfo); } // Also traverse `using` declarations. @@ -693,26 +719,27 @@ void Compiler::Node::traverse(uint eagerness, std::unordered_map& s void Compiler::Node::traverseNodeDependencies( const schema::Node::Reader& schemaNode, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader) { + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo) { switch (schemaNode.which()) { case schema::Node::STRUCT: for (auto field: schemaNode.getStruct().getFields()) { switch (field.which()) { case schema::Field::SLOT: - traverseType(field.getSlot().getType(), eagerness, seen, finalLoader); + traverseType(field.getSlot().getType(), eagerness, seen, finalLoader, sourceInfo); break; case schema::Field::GROUP: // Aux node will be scanned later. break; } - traverseAnnotations(field.getAnnotations(), eagerness, seen, finalLoader); + traverseAnnotations(field.getAnnotations(), eagerness, seen, finalLoader, sourceInfo); } break; case schema::Node::ENUM: for (auto enumerant: schemaNode.getEnum().getEnumerants()) { - traverseAnnotations(enumerant.getAnnotations(), eagerness, seen, finalLoader); + traverseAnnotations(enumerant.getAnnotations(), eagerness, seen, finalLoader, sourceInfo); } break; @@ -721,38 +748,41 @@ void Compiler::Node::traverseNodeDependencies( for (auto superclass: interface.getSuperclasses()) { uint64_t superclassId = superclass.getId(); if (superclassId != 0) { // if zero, we reported an error earlier - traverseDependency(superclassId, eagerness, seen, finalLoader); + traverseDependency(superclassId, eagerness, seen, finalLoader, sourceInfo); } - traverseBrand(superclass.getBrand(), eagerness, seen, finalLoader); + traverseBrand(superclass.getBrand(), eagerness, seen, finalLoader, sourceInfo); } for (auto method: interface.getMethods()) { - traverseDependency(method.getParamStructType(), eagerness, seen, finalLoader, true); - traverseBrand(method.getParamBrand(), eagerness, seen, finalLoader); - traverseDependency(method.getResultStructType(), eagerness, seen, finalLoader, true); - traverseBrand(method.getResultBrand(), eagerness, seen, finalLoader); - traverseAnnotations(method.getAnnotations(), eagerness, seen, finalLoader); + traverseDependency( + method.getParamStructType(), eagerness, seen, finalLoader, sourceInfo, true); + traverseBrand(method.getParamBrand(), eagerness, seen, finalLoader, sourceInfo); + traverseDependency( + method.getResultStructType(), eagerness, seen, finalLoader, sourceInfo, true); + traverseBrand(method.getResultBrand(), eagerness, seen, finalLoader, sourceInfo); + traverseAnnotations(method.getAnnotations(), eagerness, seen, finalLoader, sourceInfo); } break; } case schema::Node::CONST: - traverseType(schemaNode.getConst().getType(), eagerness, seen, finalLoader); + traverseType(schemaNode.getConst().getType(), eagerness, seen, finalLoader, sourceInfo); break; case schema::Node::ANNOTATION: - traverseType(schemaNode.getAnnotation().getType(), eagerness, seen, finalLoader); + traverseType(schemaNode.getAnnotation().getType(), eagerness, seen, finalLoader, sourceInfo); break; default: break; } - traverseAnnotations(schemaNode.getAnnotations(), eagerness, seen, finalLoader); + traverseAnnotations(schemaNode.getAnnotations(), eagerness, seen, finalLoader, sourceInfo); } void Compiler::Node::traverseType(const schema::Type::Reader& type, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader) { + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo) { uint64_t id = 0; schema::Brand::Reader brand; switch (type.which()) { @@ -769,20 +799,21 @@ void Compiler::Node::traverseType(const schema::Type::Reader& type, uint eagerne brand = type.getInterface().getBrand(); break; case schema::Type::LIST: - traverseType(type.getList().getElementType(), eagerness, seen, finalLoader); + traverseType(type.getList().getElementType(), eagerness, seen, finalLoader, sourceInfo); return; default: return; } - traverseDependency(id, eagerness, seen, finalLoader); - traverseBrand(brand, eagerness, seen, finalLoader); + traverseDependency(id, eagerness, seen, finalLoader, sourceInfo); + traverseBrand(brand, eagerness, seen, finalLoader, sourceInfo); } void Compiler::Node::traverseBrand( const schema::Brand::Reader& brand, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader) { + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo) { for (auto scope: brand.getScopes()) { switch (scope.which()) { case schema::Brand::Scope::BIND: @@ -791,7 +822,7 @@ void Compiler::Node::traverseBrand( case schema::Brand::Binding::UNBOUND: break; case schema::Brand::Binding::TYPE: - traverseType(binding.getType(), eagerness, seen, finalLoader); + traverseType(binding.getType(), eagerness, seen, finalLoader, sourceInfo); break; } } @@ -805,9 +836,10 @@ void Compiler::Node::traverseBrand( void Compiler::Node::traverseDependency(uint64_t depId, uint eagerness, std::unordered_map& seen, const SchemaLoader& finalLoader, + kj::Vector& sourceInfo, bool ignoreIfNotFound) { KJ_IF_MAYBE(node, module->getCompiler().findNode(depId)) { - node->traverse(eagerness, seen, finalLoader); + node->traverse(eagerness, seen, finalLoader, sourceInfo); } else if (!ignoreIfNotFound) { KJ_FAIL_ASSERT("Dependency ID not present in compiler?", depId); } @@ -816,10 +848,11 @@ void Compiler::Node::traverseDependency(uint64_t depId, uint eagerness, void Compiler::Node::traverseAnnotations(const List::Reader& annotations, uint eagerness, std::unordered_map& seen, - const SchemaLoader& finalLoader) { + const SchemaLoader& finalLoader, + kj::Vector& sourceInfo) { for (auto annotation: annotations) { KJ_IF_MAYBE(node, module->getCompiler().findNode(annotation.getId())) { - node->traverse(eagerness, seen, finalLoader); + node->traverse(eagerness, seen, finalLoader, sourceInfo); } } } @@ -829,7 +862,7 @@ void Compiler::Node::addError(kj::StringPtr error) { module->getErrorReporter().addError(startByte, endByte, error); } -kj::Maybe +kj::Maybe Compiler::Node::resolve(kj::StringPtr name) { // Check members. KJ_IF_MAYBE(member, resolveMember(name)) { @@ -859,7 +892,7 @@ Compiler::Node::resolve(kj::StringPtr name) { } } -kj::Maybe +kj::Maybe Compiler::Node::resolveMember(kj::StringPtr name) { if (isBuiltin) return nullptr; @@ -884,25 +917,25 @@ Compiler::Node::resolveMember(kj::StringPtr name) { return nullptr; } -NodeTranslator::Resolver::ResolvedDecl Compiler::Node::resolveBuiltin(Declaration::Which which) { +Resolver::ResolvedDecl Compiler::Node::resolveBuiltin(Declaration::Which which) { auto& b = module->getCompiler().getBuiltin(which); return { b.id, b.genericParamCount, 0, b.kind, &b, nullptr }; } -NodeTranslator::Resolver::ResolvedDecl Compiler::Node::resolveId(uint64_t id) { +Resolver::ResolvedDecl Compiler::Node::resolveId(uint64_t id) { auto& n = KJ_ASSERT_NONNULL(module->getCompiler().findNode(id)); uint64_t parentId = n.parent.map([](Node& n) { return n.id; }).orDefault(0); return { n.id, n.genericParamCount, parentId, n.kind, &n, nullptr }; } -kj::Maybe Compiler::Node::getParent() { +kj::Maybe Compiler::Node::getParent() { return parent.map([](Node& parent) { uint64_t scopeId = parent.parent.map([](Node& gp) { return gp.id; }).orDefault(0); return ResolvedDecl { parent.id, parent.genericParamCount, scopeId, parent.kind, &parent, nullptr }; }); } -NodeTranslator::Resolver::ResolvedDecl Compiler::Node::getTopScope() { +Resolver::ResolvedDecl Compiler::Node::getTopScope() { Node& node = module->getRootNode(); return ResolvedDecl { node.id, 0, 0, node.kind, &node, nullptr }; } @@ -930,7 +963,7 @@ kj::Maybe Compiler::Node::resolveFinalSchema(uint64_t id) } } -kj::Maybe +kj::Maybe Compiler::Node::resolveImport(kj::StringPtr name) { KJ_IF_MAYBE(m, module->importRelative(name)) { Node& root = m->getRootNode(); @@ -1024,6 +1057,25 @@ static void findImports(Expression::Reader exp, std::set& output) } } +static void findImports(Declaration::ParamList::Reader paramList, std::set& output) { + switch (paramList.which()) { + case Declaration::ParamList::NAMED_LIST: + for (auto param: paramList.getNamedList()) { + findImports(param.getType(), output); + for (auto ann: param.getAnnotations()) { + findImports(ann.getName(), output); + } + } + break; + case Declaration::ParamList::TYPE: + findImports(paramList.getType(), output); + break; + case Declaration::ParamList::STREAM: + output.insert("/capnp/stream.capnp"); + break; + } +} + static void findImports(Declaration::Reader decl, std::set& output) { switch (decl.which()) { case Declaration::USING: @@ -1043,30 +1095,9 @@ static void findImports(Declaration::Reader decl, std::set& outpu case Declaration::METHOD: { auto method = decl.getMethod(); - auto params = method.getParams(); - if (params.isNamedList()) { - for (auto param: params.getNamedList()) { - findImports(param.getType(), output); - for (auto ann: param.getAnnotations()) { - findImports(ann.getName(), output); - } - } - } else { - findImports(params.getType(), output); - } - + findImports(method.getParams(), output); if (method.getResults().isExplicit()) { - auto results = method.getResults().getExplicit(); - if (results.isNamedList()) { - for (auto param: results.getNamedList()) { - findImports(param.getType(), output); - for (auto ann: param.getAnnotations()) { - findImports(ann.getName(), output); - } - } - } else { - findImports(results.getType(), output); - } + findImports(method.getResults().getExplicit(), output); } break; } @@ -1203,16 +1234,12 @@ Compiler::Node& Compiler::Impl::getBuiltin(Declaration::Which which) { return *iter->second; } -uint64_t Compiler::Impl::add(Module& module) { - return addInternal(module).getRootNode().getId(); -} - kj::Maybe Compiler::Impl::lookup(uint64_t parent, kj::StringPtr childName) { // Looking up members does not use the workspace, so we don't need to lock it. KJ_IF_MAYBE(parentNode, findNode(parent)) { KJ_IF_MAYBE(child, parentNode->resolveMember(childName)) { - if (child->is()) { - return child->get().id; + if (child->is()) { + return child->get().id; } else { // An alias. We don't support looking up aliases with this method. return nullptr; @@ -1225,16 +1252,48 @@ kj::Maybe Compiler::Impl::lookup(uint64_t parent, kj::StringPtr childN } } +kj::Maybe Compiler::Impl::getSourceInfo(uint64_t id) { + auto iter = sourceInfoById.find(id); + if (iter == sourceInfoById.end()) { + return nullptr; + } else { + return iter->second; + } +} + Orphan> Compiler::Impl::getFileImportTable(Module& module, Orphanage orphanage) { return addInternal(module).getFileImportTable(orphanage); } +Orphan> Compiler::Impl::getAllSourceInfo(Orphanage orphanage) { + auto result = orphanage.newOrphan>(sourceInfoById.size()); + + auto builder = result.get(); + size_t i = 0; + for (auto& entry: sourceInfoById) { + builder.setWithCaveats(i++, entry.second); + } + + return result; +} + void Compiler::Impl::eagerlyCompile(uint64_t id, uint eagerness, const SchemaLoader& finalLoader) { KJ_IF_MAYBE(node, findNode(id)) { std::unordered_map seen; - node->traverse(eagerness, seen, finalLoader); + kj::Vector sourceInfos; + node->traverse(eagerness, seen, finalLoader, sourceInfos); + + // Copy the SourceInfo structures into permanent space so that they aren't invalidated when + // clearWorkspace() is called. + for (auto& sourceInfo: sourceInfos) { + auto words = nodeArena.allocateArray(sourceInfo.totalSize().wordCount + 1); + memset(words.begin(), 0, words.asBytes().size()); + copyToUnchecked(sourceInfo, words); + sourceInfoById.insert(std::make_pair(sourceInfo.getId(), + readMessageUnchecked(words.begin()))); + } } else { KJ_FAIL_REQUIRE("id did not come from this Compiler.", id); } @@ -1263,19 +1322,28 @@ Compiler::Compiler(AnnotationFlag annotationFlag) loader(*this) {} Compiler::~Compiler() noexcept(false) {} -uint64_t Compiler::add(Module& module) const { - return impl.lockExclusive()->get()->add(module); +Compiler::ModuleScope Compiler::add(Module& module) const { + Node& root = impl.lockExclusive()->get()->addInternal(module).getRootNode(); + return ModuleScope(*this, root.getId(), root); } kj::Maybe Compiler::lookup(uint64_t parent, kj::StringPtr childName) const { return impl.lockExclusive()->get()->lookup(parent, childName); } +kj::Maybe Compiler::getSourceInfo(uint64_t id) const { + return impl.lockExclusive()->get()->getSourceInfo(id); +} + Orphan> Compiler::getFileImportTable(Module& module, Orphanage orphanage) const { return impl.lockExclusive()->get()->getFileImportTable(module, orphanage); } +Orphan> Compiler::getAllSourceInfo(Orphanage orphanage) const { + return impl.lockExclusive()->get()->getAllSourceInfo(orphanage); +} + void Compiler::eagerlyCompile(uint64_t id, uint eagerness) const { impl.lockExclusive()->get()->eagerlyCompile(id, eagerness, loader); } @@ -1288,5 +1356,116 @@ void Compiler::load(const SchemaLoader& loader, uint64_t id) const { impl.lockExclusive()->get()->loadFinal(loader, id); } +// ----------------------------------------------------------------------------- + +class Compiler::ErrorIgnorer: public ErrorReporter { +public: + void addError(uint32_t startByte, uint32_t endByte, kj::StringPtr message) override {} + bool hadErrors() override { return false; } + + static ErrorIgnorer instance; +}; +Compiler::ErrorIgnorer Compiler::ErrorIgnorer::instance; + +kj::Maybe Compiler::CompiledType::getSchema() { + capnp::word scratch[32]; + memset(&scratch, 0, sizeof(scratch)); + capnp::MallocMessageBuilder message(scratch); + auto builder = message.getRoot(); + + { + auto lock = compiler.impl.lockShared(); + decl.get(lock).compileAsType(ErrorIgnorer::instance, builder); + } + + // No need to pass `scope` as second parameter since CompiledType always represents a type + // expression evaluated free-standing, not in any scope. + return compiler.loader.getType(builder.asReader()); +} + +Compiler::CompiledType Compiler::CompiledType::clone() { + kj::ExternalMutexGuarded newDecl; + { + auto lock = compiler.impl.lockExclusive(); + newDecl.set(lock, kj::cp(decl.get(lock))); + } + return CompiledType(compiler, kj::mv(newDecl)); +} + +kj::Maybe Compiler::CompiledType::getMember(kj::StringPtr name) { + kj::ExternalMutexGuarded newDecl; + bool found = false; + + { + auto lock = compiler.impl.lockShared(); + KJ_IF_MAYBE(member, decl.get(lock).getMember(name, {})) { + newDecl.set(lock, kj::mv(*member)); + found = true; + } + } + + if (found) { + return CompiledType(compiler, kj::mv(newDecl)); + } else { + return nullptr; + } +} + +kj::Maybe Compiler::CompiledType::applyBrand( + kj::Array arguments) { + kj::ExternalMutexGuarded newDecl; + bool found = false; + + { + auto lock = compiler.impl.lockShared(); + auto args = KJ_MAP(arg, arguments) { return kj::mv(arg.decl.get(lock)); }; + KJ_IF_MAYBE(member, decl.get(lock).applyParams(kj::mv(args), {})) { + newDecl.set(lock, kj::mv(*member)); + found = true; + } + } + + if (found) { + return CompiledType(compiler, kj::mv(newDecl)); + } else { + return nullptr; + } +} + +Compiler::CompiledType Compiler::ModuleScope::getRoot() { + kj::ExternalMutexGuarded newDecl; + + { + auto lock = compiler.impl.lockExclusive(); + auto brandScope = kj::refcounted(ErrorIgnorer::instance, node.getId(), 0, node); + Resolver::ResolvedDecl decl { node.getId(), 0, 0, node.getKind(), &node, nullptr }; + newDecl.set(lock, BrandedDecl(kj::mv(decl), kj::mv(brandScope), {})); + } + + return CompiledType(compiler, kj::mv(newDecl)); +} + +kj::Maybe Compiler::ModuleScope::evalType( + Expression::Reader expression, ErrorReporter& errorReporter) { + kj::ExternalMutexGuarded newDecl; + bool found = false; + + { + auto lock = compiler.impl.lockExclusive(); + auto brandScope = kj::refcounted(errorReporter, node.getId(), 0, node); + KJ_IF_MAYBE(result, brandScope->compileDeclExpression( + expression, node, ImplicitParams::none())) { + newDecl.set(lock, kj::mv(*result)); + found = true; + }; + } + + if (found) { + return CompiledType(compiler, kj::mv(newDecl)); + } else { + return nullptr; + } +} + } // namespace compiler } // namespace capnp diff --git a/c++/src/capnp/compiler/compiler.h b/c++/src/capnp/compiler/compiler.h index b957405ae4..375bf59d43 100644 --- a/c++/src/capnp/compiler/compiler.h +++ b/c++/src/capnp/compiler/compiler.h @@ -19,17 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_COMPILER_COMPILER_H_ -#define CAPNP_COMPILER_COMPILER_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include #include #include "error-reporter.h" +#include "generics.h" + +CAPNP_BEGIN_HEADER namespace capnp { namespace compiler { @@ -57,6 +55,8 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { // // This class is thread-safe, hence all its methods are const. + class Node; + public: enum AnnotationFlag { COMPILE_ANNOTATIONS, @@ -76,13 +76,68 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { explicit Compiler(AnnotationFlag annotationFlag = COMPILE_ANNOTATIONS); ~Compiler() noexcept(false); - KJ_DISALLOW_COPY(Compiler); + KJ_DISALLOW_COPY_AND_MOVE(Compiler); + + class CompiledType { + // Represents a compiled type expression, from which you can traverse to nested types, apply + // generics, etc. + + public: + CompiledType clone(); + // Make another CompiledType pointing to the same type. + + kj::Maybe getSchema(); + // Evaluate to a type schema. Returns null if this "type" cannot actually be used as a field + // type, e.g. because it's the pseudo-type representing a file's top-level scope. + + kj::Maybe getMember(kj::StringPtr name); + // Look up a nested declaration. Returns null if there is no such member, or if the member is + // not a type. + + kj::Maybe applyBrand(kj::Array arguments); + // If this is a generic type, specializes apply a brand to it. Returns null if this is + // not a generic type or too many arguments were specified. + + private: + const Compiler& compiler; + kj::ExternalMutexGuarded decl; + + CompiledType(const Compiler& compiler, kj::ExternalMutexGuarded decl) + : compiler(compiler), decl(kj::mv(decl)) {} + + friend class Compiler; + }; - uint64_t add(Module& module) const; - // Add a module to the Compiler, returning the module's file ID. The ID can then be looked up in - // the `SchemaLoader` returned by `getLoader()`. However, the SchemaLoader may behave as if the - // schema node doesn't exist if any compilation errors occur (reported via the module's - // ErrorReporter). The module is parsed at the time `add()` is called, but not fully compiled -- + class ModuleScope { + // Result of compiling a module. + + public: + uint64_t getId() { return id; } + + CompiledType getRoot(); + // Get a CompiledType representing the root, which can be used to programmatically look up + // declarations. + + kj::Maybe evalType(Expression::Reader expression, ErrorReporter& errorReporter); + // Evaluate some type expression within the scope of this module. + // + // Returns null if errors prevented evaluation; the errors will have been reported to + // `errorReporter`. + + private: + const Compiler& compiler; + uint64_t id; + Node& node; + + ModuleScope(const Compiler& compiler, uint64_t id, Node& node) + : compiler(compiler), id(id), node(node) {} + + friend class Compiler; + }; + + ModuleScope add(Module& module) const; + // Add a module to the Compiler, returning a CompiledType representing the top-level scope of + // the module. The module is parsed at the time `add()` is called, but not fully compiled -- // individual schema nodes are compiled lazily. If you want to force eager compilation, // see `eagerlyCompile()`, below. @@ -90,11 +145,20 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { // Given the type ID of a schema node, find the ID of a node nested within it. Throws an // exception if the parent ID is not recognized; returns null if the parent has no child of the // given name. Neither the parent nor the child schema node is actually compiled. + // + // TODO(cleanup): This interface does not handle generics correctly. Use the + // ModuleScope/CompiledType interface instead. + + kj::Maybe getSourceInfo(uint64_t id) const; + // Get the SourceInfo for the given type ID, if available. Orphan> getFileImportTable(Module& module, Orphanage orphanage) const; // Build the import table for the CodeGeneratorRequest for the given module. + Orphan> getAllSourceInfo(Orphanage orphanage) const; + // Gets the SourceInfo structs for all nodes parsed by the compiler. + enum Eagerness: uint32_t { // Flags specifying how eager to be about compilation. These are intended to be bitwise OR'd. // Used with the method `eagerlyCompile()`. @@ -133,11 +197,11 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { // dependencies. PARENTS = 1 << 1, - // Eagerly compile all lexical parents of the requested node. Only meaningful in conjuction + // Eagerly compile all lexical parents of the requested node. Only meaningful in conjunction // with NODE. CHILDREN = 1 << 2, - // Eagerly compile all of the node's lexically nested nodes. Only meaningful in conjuction + // Eagerly compile all of the node's lexically nested nodes. Only meaningful in conjunction // with NODE. DEPENDENCIES = NODE << 15, @@ -186,8 +250,8 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { SchemaLoader loader; class CompiledModule; - class Node; class Alias; + class ErrorIgnorer; void load(const SchemaLoader& loader, uint64_t id) const override; }; @@ -195,4 +259,4 @@ class Compiler final: private SchemaLoader::LazyLoadCallback { } // namespace compiler } // namespace capnp -#endif // CAPNP_COMPILER_COMPILER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/error-reporter.h b/c++/src/capnp/compiler/error-reporter.h index e4a04bfb62..1fc66c52cd 100644 --- a/c++/src/capnp/compiler/error-reporter.h +++ b/c++/src/capnp/compiler/error-reporter.h @@ -19,17 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef ERROR_REPORTER_H_ -#define ERROR_REPORTER_H_ +#pragma once -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif - -#include "../common.h" +#include #include #include #include +#include + +CAPNP_BEGIN_HEADER namespace capnp { namespace compiler { @@ -68,7 +66,8 @@ class GlobalErrorReporter { uint column; }; - virtual void addError(kj::StringPtr file, SourcePos start, SourcePos end, + virtual void addError(const kj::ReadableDirectory& directory, kj::PathPtr path, + SourcePos start, SourcePos end, kj::StringPtr message) = 0; // Report an error at the given location in the given file. @@ -94,4 +93,4 @@ class LineBreakTable { } // namespace compiler } // namespace capnp -#endif // ERROR_REPORTER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/evolution-test.c++ b/c++/src/capnp/compiler/evolution-test.c++ index b1daec42ec..964105a612 100644 --- a/c++/src/capnp/compiler/evolution-test.c++ +++ b/c++/src/capnp/compiler/evolution-test.c++ @@ -26,6 +26,10 @@ // the types are expected to be compatible, the test also constructs an instance of the old // type and reads it as the new type, and vice versa. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include #include #include @@ -678,7 +682,7 @@ static kj::Maybe loadFile( uint sharedOrdinalCount) { Compiler compiler; ModuleImpl module(file); - KJ_ASSERT(compiler.add(module) == 0x8123456789abcdefllu); + KJ_ASSERT(compiler.add(module).getId() == 0x8123456789abcdefllu); if (allNodes) { // Eagerly compile and load the whole thing. @@ -868,7 +872,7 @@ public: } kj::MainBuilder::Validity run() { - // https://github.com/sandstorm-io/capnproto/issues/344 describes an obscure bug in the layout + // https://github.com/capnproto/capnproto/issues/344 describes an obscure bug in the layout // algorithm, the fix for which breaks backwards-compatibility for any schema triggering the // bug. In order to avoid silently breaking protocols, we are temporarily throwing an exception // in cases where this bug would have occurred, so that people can decide what to do. diff --git a/c++/src/capnp/compiler/generics.c++ b/c++/src/capnp/compiler/generics.c++ new file mode 100644 index 0000000000..71df4e14da --- /dev/null +++ b/c++/src/capnp/compiler/generics.c++ @@ -0,0 +1,656 @@ +// Copyright (c) 2013-2020 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "generics.h" +#include "parser.h" // for expressionString() + +namespace capnp { +namespace compiler { + +BrandedDecl::BrandedDecl(BrandedDecl& other) + : body(other.body), + source(other.source) { + if (body.is()) { + brand = kj::addRef(*other.brand); + } +} + +BrandedDecl& BrandedDecl::operator=(BrandedDecl& other) { + body = other.body; + source = other.source; + if (body.is()) { + brand = kj::addRef(*other.brand); + } + return *this; +} + +kj::Maybe BrandedDecl::applyParams( + kj::Array params, Expression::Reader subSource) { + if (body.is()) { + return nullptr; + } else { + return brand->setParams(kj::mv(params), body.get().kind, subSource) + .map([&](kj::Own&& scope) { + BrandedDecl result = *this; + result.brand = kj::mv(scope); + result.source = subSource; + return result; + }); + } +} + +kj::Maybe BrandedDecl::getMember( + kj::StringPtr memberName, Expression::Reader subSource) { + if (body.is()) { + return nullptr; + } else KJ_IF_MAYBE(r, body.get().resolver->resolveMember(memberName)) { + return brand->interpretResolve(*body.get().resolver, *r, subSource); + } else { + return nullptr; + } +} + +kj::Maybe BrandedDecl::getKind() { + if (body.is()) { + return nullptr; + } else { + return body.get().kind; + } +} + +kj::Maybe BrandedDecl::getListParam() { + KJ_REQUIRE(body.is()); + + auto& decl = body.get(); + KJ_REQUIRE(decl.kind == Declaration::BUILTIN_LIST); + + auto params = KJ_ASSERT_NONNULL(brand->getParams(decl.id)); + if (params.size() != 1) { + return nullptr; + } else { + return params[0]; + } +} + +Resolver::ResolvedParameter BrandedDecl::asVariable() { + KJ_REQUIRE(body.is()); + + return body.get(); +} + +bool BrandedDecl::compileAsType( + ErrorReporter& errorReporter, schema::Type::Builder target) { + KJ_IF_MAYBE(kind, getKind()) { + switch (*kind) { + case Declaration::ENUM: { + auto enum_ = target.initEnum(); + enum_.setTypeId(getIdAndFillBrand([&]() { return enum_.initBrand(); })); + return true; + } + + case Declaration::STRUCT: { + auto struct_ = target.initStruct(); + struct_.setTypeId(getIdAndFillBrand([&]() { return struct_.initBrand(); })); + return true; + } + + case Declaration::INTERFACE: { + auto interface = target.initInterface(); + interface.setTypeId(getIdAndFillBrand([&]() { return interface.initBrand(); })); + return true; + } + + case Declaration::BUILTIN_LIST: { + auto elementType = target.initList().initElementType(); + + KJ_IF_MAYBE(param, getListParam()) { + if (!param->compileAsType(errorReporter, elementType)) { + return false; + } + } else { + addError(errorReporter, "'List' requires exactly one parameter."); + return false; + } + + if (elementType.isAnyPointer()) { + auto unconstrained = elementType.getAnyPointer().getUnconstrained(); + + if (unconstrained.isAnyKind()) { + addError(errorReporter, "'List(AnyPointer)' is not supported."); + // Seeing List(AnyPointer) later can mess things up, so change the type to Void. + elementType.setVoid(); + return false; + } else if (unconstrained.isStruct()) { + addError(errorReporter, "'List(AnyStruct)' is not supported."); + // Seeing List(AnyStruct) later can mess things up, so change the type to Void. + elementType.setVoid(); + return false; + } + } + + return true; + } + + case Declaration::BUILTIN_VOID: target.setVoid(); return true; + case Declaration::BUILTIN_BOOL: target.setBool(); return true; + case Declaration::BUILTIN_INT8: target.setInt8(); return true; + case Declaration::BUILTIN_INT16: target.setInt16(); return true; + case Declaration::BUILTIN_INT32: target.setInt32(); return true; + case Declaration::BUILTIN_INT64: target.setInt64(); return true; + case Declaration::BUILTIN_U_INT8: target.setUint8(); return true; + case Declaration::BUILTIN_U_INT16: target.setUint16(); return true; + case Declaration::BUILTIN_U_INT32: target.setUint32(); return true; + case Declaration::BUILTIN_U_INT64: target.setUint64(); return true; + case Declaration::BUILTIN_FLOAT32: target.setFloat32(); return true; + case Declaration::BUILTIN_FLOAT64: target.setFloat64(); return true; + case Declaration::BUILTIN_TEXT: target.setText(); return true; + case Declaration::BUILTIN_DATA: target.setData(); return true; + + case Declaration::BUILTIN_OBJECT: + addError(errorReporter, + "As of Cap'n Proto 0.4, 'Object' has been renamed to 'AnyPointer'. Sorry for the " + "inconvenience, and thanks for being an early adopter. :)"); + KJ_FALLTHROUGH; + case Declaration::BUILTIN_ANY_POINTER: + target.initAnyPointer().initUnconstrained().setAnyKind(); + return true; + case Declaration::BUILTIN_ANY_STRUCT: + target.initAnyPointer().initUnconstrained().setStruct(); + return true; + case Declaration::BUILTIN_ANY_LIST: + target.initAnyPointer().initUnconstrained().setList(); + return true; + case Declaration::BUILTIN_CAPABILITY: + target.initAnyPointer().initUnconstrained().setCapability(); + return true; + + case Declaration::FILE: + case Declaration::USING: + case Declaration::CONST: + case Declaration::ENUMERANT: + case Declaration::FIELD: + case Declaration::UNION: + case Declaration::GROUP: + case Declaration::METHOD: + case Declaration::ANNOTATION: + case Declaration::NAKED_ID: + case Declaration::NAKED_ANNOTATION: + addError(errorReporter, kj::str("'", toString(), "' is not a type.")); + return false; + } + + KJ_UNREACHABLE; + } else { + // Oh, this is a type variable. + auto var = asVariable(); + if (var.id == 0) { + // This is actually a method implicit parameter. + auto builder = target.initAnyPointer().initImplicitMethodParameter(); + builder.setParameterIndex(var.index); + return true; + } else { + auto builder = target.initAnyPointer().initParameter(); + builder.setScopeId(var.id); + builder.setParameterIndex(var.index); + return true; + } + } +} + +Resolver::ResolveResult BrandedDecl::asResolveResult( + uint64_t scopeId, schema::Brand::Builder brandBuilder) { + auto result = body; + if (result.is()) { + // May need to compile our context as the "brand". + + result.get().scopeId = scopeId; + + getIdAndFillBrand([&]() { + result.get().brand = brandBuilder.asReader(); + return brandBuilder; + }); + } + return result; +} + +kj::String BrandedDecl::toString() { + return expressionString(source); +} + +kj::String BrandedDecl::toDebugString() { + if (body.is()) { + auto variable = body.get(); + return kj::str("variable(", variable.id, ", ", variable.index, ")"); + } else { + auto decl = body.get(); + return kj::str("decl(", decl.id, ", ", (uint)decl.kind, "')"); + } +} + +BrandScope::BrandScope(ErrorReporter& errorReporter, uint64_t startingScopeId, + uint startingScopeParamCount, Resolver& startingScope) + : errorReporter(errorReporter), parent(nullptr), leafId(startingScopeId), + leafParamCount(startingScopeParamCount), inherited(true) { + // Create all lexical parent scopes, all with no brand bindings. + KJ_IF_MAYBE(p, startingScope.getParent()) { + parent = kj::refcounted( + errorReporter, p->id, p->genericParamCount, *p->resolver); + } +} + +bool BrandScope::isGeneric() { + if (leafParamCount > 0) return true; + + KJ_IF_MAYBE(p, parent) { + return p->get()->isGeneric(); + } else { + return false; + } +} + +kj::Own BrandScope::push(uint64_t typeId, uint paramCount) { + return kj::refcounted(kj::addRef(*this), typeId, paramCount); +} + +kj::Maybe> BrandScope::setParams( + kj::Array params, Declaration::Which genericType, Expression::Reader source) { + if (this->params.size() != 0) { + errorReporter.addErrorOn(source, "Double-application of generic parameters."); + return nullptr; + } else if (params.size() > leafParamCount) { + if (leafParamCount == 0) { + errorReporter.addErrorOn(source, "Declaration does not accept generic parameters."); + } else { + errorReporter.addErrorOn(source, "Too many generic parameters."); + } + return nullptr; + } else if (params.size() < leafParamCount) { + errorReporter.addErrorOn(source, "Not enough generic parameters."); + return nullptr; + } else { + if (genericType != Declaration::BUILTIN_LIST) { + for (auto& param: params) { + KJ_IF_MAYBE(kind, param.getKind()) { + switch (*kind) { + case Declaration::BUILTIN_LIST: + case Declaration::BUILTIN_TEXT: + case Declaration::BUILTIN_DATA: + case Declaration::BUILTIN_ANY_POINTER: + case Declaration::STRUCT: + case Declaration::INTERFACE: + break; + + default: + param.addError(errorReporter, + "Sorry, only pointer types can be used as generic parameters."); + break; + } + } + } + } + + return kj::refcounted(*this, kj::mv(params)); + } +} + +kj::Own BrandScope::pop(uint64_t newLeafId) { + if (leafId == newLeafId) { + return kj::addRef(*this); + } + KJ_IF_MAYBE(p, parent) { + return (*p)->pop(newLeafId); + } else { + // Looks like we're moving into a whole top-level scope. + return kj::refcounted(errorReporter, newLeafId); + } +} + +kj::Maybe BrandScope::lookupParameter( + Resolver& resolver, uint64_t scopeId, uint index) { + // Returns null if the param should be inherited from the client scope. + + if (scopeId == leafId) { + if (index < params.size()) { + return params[index]; + } else if (inherited) { + return nullptr; + } else { + // Unbound and not inherited, so return AnyPointer. + auto decl = resolver.resolveBuiltin(Declaration::BUILTIN_ANY_POINTER); + return BrandedDecl(decl, + evaluateBrand(resolver, decl, List::Reader()), + Expression::Reader()); + } + } else KJ_IF_MAYBE(p, parent) { + return p->get()->lookupParameter(resolver, scopeId, index); + } else { + KJ_FAIL_REQUIRE("scope is not a parent"); + } +} + +kj::Maybe> BrandScope::getParams(uint64_t scopeId) { + // Returns null if params at the requested scope should be inherited from the client scope. + + if (scopeId == leafId) { + if (inherited) { + return nullptr; + } else { + return params.asPtr(); + } + } else KJ_IF_MAYBE(p, parent) { + return p->get()->getParams(scopeId); + } else { + KJ_FAIL_REQUIRE("scope is not a parent"); + } +} + +BrandedDecl BrandScope::interpretResolve( + Resolver& resolver, Resolver::ResolveResult& result, Expression::Reader source) { + if (result.is()) { + auto& decl = result.get(); + + auto scope = pop(decl.scopeId); + KJ_IF_MAYBE(brand, decl.brand) { + scope = scope->evaluateBrand(resolver, decl, brand->getScopes()); + } else { + scope = scope->push(decl.id, decl.genericParamCount); + } + + return BrandedDecl(decl, kj::mv(scope), source); + } else { + auto& param = result.get(); + KJ_IF_MAYBE(p, lookupParameter(resolver, param.id, param.index)) { + return *p; + } else { + return BrandedDecl(param, source); + } + } +} + +kj::Own BrandScope::evaluateBrand( + Resolver& resolver, Resolver::ResolvedDecl decl, + List::Reader brand, uint index) { + auto result = kj::refcounted(errorReporter, decl.id); + result->leafParamCount = decl.genericParamCount; + + // Fill in `params`. + if (index < brand.size()) { + auto nextScope = brand[index]; + if (decl.id == nextScope.getScopeId()) { + // Initialize our parameters. + + switch (nextScope.which()) { + case schema::Brand::Scope::BIND: { + auto bindings = nextScope.getBind(); + auto params = kj::heapArrayBuilder(bindings.size()); + for (auto binding: bindings) { + switch (binding.which()) { + case schema::Brand::Binding::UNBOUND: { + // Build an AnyPointer-equivalent. + auto anyPointerDecl = resolver.resolveBuiltin(Declaration::BUILTIN_ANY_POINTER); + params.add(BrandedDecl(anyPointerDecl, + kj::refcounted(errorReporter, anyPointerDecl.scopeId), + Expression::Reader())); + break; + } + + case schema::Brand::Binding::TYPE: + // Reverse this schema::Type back into a BrandedDecl. + params.add(decompileType(resolver, binding.getType())); + break; + } + } + result->params = params.finish(); + break; + } + + case schema::Brand::Scope::INHERIT: + KJ_IF_MAYBE(p, getParams(decl.id)) { + result->params = kj::heapArray(*p); + } else { + result->inherited = true; + } + break; + } + + // Parent should start one level deeper in the list. + ++index; + } + } + + // Fill in `parent`. + KJ_IF_MAYBE(parent, decl.resolver->getParent()) { + result->parent = evaluateBrand(resolver, *parent, brand, index); + } + + return result; +} + +BrandedDecl BrandScope::decompileType( + Resolver& resolver, schema::Type::Reader type) { + auto builtin = [&](Declaration::Which which) -> BrandedDecl { + auto decl = resolver.resolveBuiltin(which); + return BrandedDecl(decl, + evaluateBrand(resolver, decl, List::Reader()), + Expression::Reader()); + }; + + switch (type.which()) { + case schema::Type::VOID: return builtin(Declaration::BUILTIN_VOID); + case schema::Type::BOOL: return builtin(Declaration::BUILTIN_BOOL); + case schema::Type::INT8: return builtin(Declaration::BUILTIN_INT8); + case schema::Type::INT16: return builtin(Declaration::BUILTIN_INT16); + case schema::Type::INT32: return builtin(Declaration::BUILTIN_INT32); + case schema::Type::INT64: return builtin(Declaration::BUILTIN_INT64); + case schema::Type::UINT8: return builtin(Declaration::BUILTIN_U_INT8); + case schema::Type::UINT16: return builtin(Declaration::BUILTIN_U_INT16); + case schema::Type::UINT32: return builtin(Declaration::BUILTIN_U_INT32); + case schema::Type::UINT64: return builtin(Declaration::BUILTIN_U_INT64); + case schema::Type::FLOAT32: return builtin(Declaration::BUILTIN_FLOAT32); + case schema::Type::FLOAT64: return builtin(Declaration::BUILTIN_FLOAT64); + case schema::Type::TEXT: return builtin(Declaration::BUILTIN_TEXT); + case schema::Type::DATA: return builtin(Declaration::BUILTIN_DATA); + + case schema::Type::ENUM: { + auto enumType = type.getEnum(); + Resolver::ResolvedDecl decl = resolver.resolveId(enumType.getTypeId()); + return BrandedDecl(decl, + evaluateBrand(resolver, decl, enumType.getBrand().getScopes()), + Expression::Reader()); + } + + case schema::Type::INTERFACE: { + auto interfaceType = type.getInterface(); + Resolver::ResolvedDecl decl = resolver.resolveId(interfaceType.getTypeId()); + return BrandedDecl(decl, + evaluateBrand(resolver, decl, interfaceType.getBrand().getScopes()), + Expression::Reader()); + } + + case schema::Type::STRUCT: { + auto structType = type.getStruct(); + Resolver::ResolvedDecl decl = resolver.resolveId(structType.getTypeId()); + return BrandedDecl(decl, + evaluateBrand(resolver, decl, structType.getBrand().getScopes()), + Expression::Reader()); + } + + case schema::Type::LIST: { + auto elementType = decompileType(resolver, type.getList().getElementType()); + return KJ_ASSERT_NONNULL(builtin(Declaration::BUILTIN_LIST) + .applyParams(kj::heapArray(&elementType, 1), Expression::Reader())); + } + + case schema::Type::ANY_POINTER: { + auto anyPointer = type.getAnyPointer(); + switch (anyPointer.which()) { + case schema::Type::AnyPointer::UNCONSTRAINED: + return builtin(Declaration::BUILTIN_ANY_POINTER); + + case schema::Type::AnyPointer::PARAMETER: { + auto param = anyPointer.getParameter(); + auto id = param.getScopeId(); + uint index = param.getParameterIndex(); + KJ_IF_MAYBE(binding, lookupParameter(resolver, id, index)) { + return *binding; + } else { + return BrandedDecl(Resolver::ResolvedParameter {id, index}, Expression::Reader()); + } + } + + case schema::Type::AnyPointer::IMPLICIT_METHOD_PARAMETER: + KJ_FAIL_ASSERT("Alias pointed to implicit method type parameter?"); + } + + KJ_UNREACHABLE; + } + } + + KJ_UNREACHABLE; +} + +kj::Maybe BrandScope::compileDeclExpression( + Expression::Reader source, Resolver& resolver, + ImplicitParams implicitMethodParams) { + switch (source.which()) { + case Expression::UNKNOWN: + // Error reported earlier. + return nullptr; + + case Expression::POSITIVE_INT: + case Expression::NEGATIVE_INT: + case Expression::FLOAT: + case Expression::STRING: + case Expression::BINARY: + case Expression::LIST: + case Expression::TUPLE: + case Expression::EMBED: + errorReporter.addErrorOn(source, "Expected name."); + return nullptr; + + case Expression::RELATIVE_NAME: { + auto name = source.getRelativeName(); + auto nameValue = name.getValue(); + + // Check implicit method params first. + for (auto i: kj::indices(implicitMethodParams.params)) { + if (implicitMethodParams.params[i].getName() == nameValue) { + if (implicitMethodParams.scopeId == 0) { + return BrandedDecl::implicitMethodParam(i); + } else { + return BrandedDecl(Resolver::ResolvedParameter { + implicitMethodParams.scopeId, static_cast(i) }, + Expression::Reader()); + } + } + } + + KJ_IF_MAYBE(r, resolver.resolve(nameValue)) { + return interpretResolve(resolver, *r, source); + } else { + errorReporter.addErrorOn(name, kj::str("Not defined: ", nameValue)); + return nullptr; + } + } + + case Expression::ABSOLUTE_NAME: { + auto name = source.getAbsoluteName(); + KJ_IF_MAYBE(r, resolver.getTopScope().resolver->resolveMember(name.getValue())) { + return interpretResolve(resolver, *r, source); + } else { + errorReporter.addErrorOn(name, kj::str("Not defined: ", name.getValue())); + return nullptr; + } + } + + case Expression::IMPORT: { + auto filename = source.getImport(); + KJ_IF_MAYBE(decl, resolver.resolveImport(filename.getValue())) { + // Import is always a root scope, so create a fresh BrandScope. + return BrandedDecl(*decl, kj::refcounted( + errorReporter, decl->id, decl->genericParamCount, *decl->resolver), source); + } else { + errorReporter.addErrorOn(filename, kj::str("Import failed: ", filename.getValue())); + return nullptr; + } + } + + case Expression::APPLICATION: { + auto app = source.getApplication(); + KJ_IF_MAYBE(decl, compileDeclExpression(app.getFunction(), resolver, implicitMethodParams)) { + // Compile all params. + auto params = app.getParams(); + auto compiledParams = kj::heapArrayBuilder(params.size()); + bool paramFailed = false; + for (auto param: params) { + if (param.isNamed()) { + errorReporter.addErrorOn(param.getNamed(), "Named parameter not allowed here."); + } + + KJ_IF_MAYBE(d, compileDeclExpression(param.getValue(), resolver, implicitMethodParams)) { + compiledParams.add(kj::mv(*d)); + } else { + // Param failed to compile. Error was already reported. + paramFailed = true; + } + }; + + if (paramFailed) { + return kj::mv(*decl); + } + + // Add the parameters to the brand. + KJ_IF_MAYBE(applied, decl->applyParams(compiledParams.finish(), source)) { + return kj::mv(*applied); + } else { + // Error already reported. Ignore parameters. + return kj::mv(*decl); + } + } else { + // error already reported + return nullptr; + } + } + + case Expression::MEMBER: { + auto member = source.getMember(); + KJ_IF_MAYBE(decl, compileDeclExpression(member.getParent(), resolver, implicitMethodParams)) { + auto name = member.getName(); + KJ_IF_MAYBE(memberDecl, decl->getMember(name.getValue(), source)) { + return kj::mv(*memberDecl); + } else { + errorReporter.addErrorOn(name, kj::str( + "'", expressionString(member.getParent()), + "' has no member named '", name.getValue(), "'")); + return nullptr; + } + } else { + // error already reported + return nullptr; + } + } + } + + KJ_UNREACHABLE; +} + +} // namespace compiler +} // namespace capnp diff --git a/c++/src/capnp/compiler/generics.h b/c++/src/capnp/compiler/generics.h new file mode 100644 index 0000000000..fbdbddb488 --- /dev/null +++ b/c++/src/capnp/compiler/generics.h @@ -0,0 +1,310 @@ +// Copyright (c) 2013-2020 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "error-reporter.h" +#include "resolver.h" + +CAPNP_BEGIN_HEADER + +namespace capnp { +namespace compiler { + +class BrandedDecl; +class BrandScope; + +struct ImplicitParams { + // Represents a set of implicit brand parameters visible in the current context. + // + // As of this writing, implicit parameters occur only in the context of RPC methods. That is, + // like this: + // + // makeBox @0 [T] (value :T) -> Box(T); + // + // Here, `T` is an implicit parameter. + + uint64_t scopeId; + // If zero, then any reference to an implicit param in this context should be compiled to a + // `implicitMethodParam` AnyPointer. If non-zero, it should be compiled to a `parameter` + // AnyPointer using this scopeId. This comes into play when compiling the implicitly-generated + // struct types corresponding to a method's params or results; these implicitly-generated types + // themselves have *explicit* brand parameters corresponding to the *implicit* brand parameters + // of the method. + // + // TODO(cleanup): Unclear why ImplicitParams is even used when compiling the implicit structs + // with explicit params. Missing abstraction? + + List::Reader params; + // Name and metadata about the parameter declaration. + + static inline ImplicitParams none() { + // Convenience helper to create an empty `ImplicitParams`. + return { 0, List::Reader() }; + } +}; + +class BrandedDecl { + // Represents a declaration possibly with generic parameter bindings. + +public: + inline BrandedDecl(Resolver::ResolvedDecl decl, + kj::Own&& brand, + Expression::Reader source) + : brand(kj::mv(brand)), source(source) { + // `source`, is the expression which specified this branded decl. It is provided so that errors + // can be reported against it. It is acceptable to pass a default-initialized reader if there's + // no source expression; errors will then be reported at 0, 0. + + body.init(kj::mv(decl)); + } + inline BrandedDecl(Resolver::ResolvedParameter variable, Expression::Reader source) + : source(source) { + body.init(kj::mv(variable)); + } + inline BrandedDecl(decltype(nullptr)) {} + inline BrandedDecl() {} // exists only for ExternalMutexGuarded to work... + + static BrandedDecl implicitMethodParam(uint index) { + // Get a BrandedDecl referring to an implicit method parameter. + // (As a hack, we internally represent this as a ResolvedParameter. Sorry.) + return BrandedDecl(Resolver::ResolvedParameter { 0, index }, Expression::Reader()); + } + + BrandedDecl(BrandedDecl& other); + BrandedDecl(BrandedDecl&& other) = default; + + BrandedDecl& operator=(BrandedDecl& other); + BrandedDecl& operator=(BrandedDecl&& other) = default; + + kj::Maybe applyParams(kj::Array params, Expression::Reader subSource); + // Treat the declaration as a generic and apply it to the given parameter list. + + kj::Maybe getMember(kj::StringPtr memberName, Expression::Reader subSource); + // Get a member of this declaration. + + kj::Maybe getKind(); + // Returns the kind of declaration, or null if this is an unbound generic variable. + + template + uint64_t getIdAndFillBrand(InitBrandFunc&& initBrand); + // Returns the type ID of this node. `initBrand` is a zero-arg functor which returns + // schema::Brand::Builder; this will be called if this decl has brand bindings, and + // the returned builder filled in to reflect those bindings. + // + // It is an error to call this when `getKind()` returns null. + + kj::Maybe getListParam(); + // Only if the kind is BUILTIN_LIST: Get the list's type parameter. + + Resolver::ResolvedParameter asVariable(); + // If this is an unbound generic variable (i.e. `getKind()` returns null), return information + // about the variable. + // + // It is an error to call this when `getKind()` does not return null. + + bool compileAsType(ErrorReporter& errorReporter, schema::Type::Builder target); + // Compile this decl to a schema::Type. + + inline void addError(ErrorReporter& errorReporter, kj::StringPtr message) { + errorReporter.addErrorOn(source, message); + } + + Resolver::ResolveResult asResolveResult(uint64_t scopeId, schema::Brand::Builder brandBuilder); + // Reverse this into a ResolveResult. If necessary, use `brandBuilder` to fill in + // ResolvedDecl.brand. + + kj::String toString(); + kj::String toDebugString(); + +private: + Resolver::ResolveResult body; + kj::Own brand; // null if parameter + Expression::Reader source; +}; + +class BrandScope: public kj::Refcounted { + // Tracks the brand parameter bindings affecting the scope specified by some expression. For + // example, if we are interpreting the type expression "Foo(Text).Bar", we would start with the + // current scope's BrandScope, create a new child BrandScope representing "Foo", add the "(Text)" + // parameter bindings to it, then create a further child scope for "Bar". Thus the BrandScope for + // Bar knows that Foo's parameter list has been bound to "(Text)". + +public: + BrandScope(ErrorReporter& errorReporter, uint64_t startingScopeId, + uint startingScopeParamCount, Resolver& startingScope); + // TODO(bug): Passing an `errorReporter` to the constructor of `BrandScope` turns out not to + // make a ton of sense, as an `errorReporter` is meant to report errors in a specific module, + // but `BrandScope` might be constructed while compiling one module but then used when + // compiling a different module, or not compiling a module at all. Note, though, that it DOES + // make sense for BrandedDecl to have an ErrorReporter, specifically associated with its + // `source` expression. + + bool isGeneric(); + // Returns true if this scope or any parent scope is a generic (has brand parameters). + + kj::Own push(uint64_t typeId, uint paramCount); + // Creates a new child scope with the given type ID and number of brand parameters. + + kj::Maybe> setParams( + kj::Array params, Declaration::Which genericType, Expression::Reader source); + // Create a new BrandScope representing the same scope, but with parameters filled in. + // + // This should only be called on the generic version of the scope. If called on a branded + // version, an error will be reported. + // + // Returns null if an error occurred that prevented creating the BrandScope; the error will have + // been reported to the ErrorReporter. + + kj::Own pop(uint64_t newLeafId); + // Return the parent scope. + + kj::Maybe lookupParameter(Resolver& resolver, uint64_t scopeId, uint index); + // Search up the scope chain for the scope matching `scopeId`, and return its `index`th parameter + // binding. Returns null if the parameter is from a scope that we are currently compiling, and + // hasn't otherwise been bound to any argument (see Brand.Scope.inherit in schema.capnp). + // + // In the case that a parameter wasn't specified, but isn't part of the current scope, this + // returns the declaration for `AnyPointer`. + // + // TODO(cleanup): Should be called lookupArgument()? + + kj::Maybe> getParams(uint64_t scopeId); + // Get the whole list of parameter bindings at the given scope. Returns null if the scope is + // currently be compiled and the parameters are unbound. + // + // Note that it's possible that not all declared parameters were actually specified for a given + // scope. For example, if you declare a generic `Foo(T, U)`, and then you intiantiate it + // somewhere as `Foo(Text)`, then `U` is unspecified -- this is not an error, because Cap'n + // Proto allows new type parameters to be added over time. `U` should be treated as `AnyPointer` + // in this case, but `getParams()` doesn't know how many parameters are expected, so it will + // return an array that only contains one item. Use `lookupParameter()` if you want unspecified + // parameters to be filled in with `AnyPointer` automatically. + // + // TODO(cleanup): Should be called getArguments()? + + template + void compile(InitBrandFunc&& initBrand); + // Constructs the schema::Brand corresponding to this brand scope. + // + // `initBrand` is a zero-arg functor which returns an empty schema::Brand::Builder, into which + // the brand is constructed. If no generics are present, then `initBrand` is never called. + // + // TODO(cleanup): Should this return Maybe> instead? + + kj::Maybe compileDeclExpression( + Expression::Reader source, Resolver& resolver, + ImplicitParams implicitMethodParams); + // Interpret a type expression within this branded scope. + + BrandedDecl interpretResolve( + Resolver& resolver, Resolver::ResolveResult& result, Expression::Reader source); + // After using a Resolver to resolve a symbol, call interpretResolve() to interpret the result + // within the current brand scope. For example, if a name resolved to a brand parameter, this + // replaces it with the appropriate argument from the scope. + + inline uint64_t getScopeId() { return leafId; } + +private: + ErrorReporter& errorReporter; + kj::Maybe> parent; + uint64_t leafId; // zero = this is the root + uint leafParamCount; // number of generic parameters on this leaf + bool inherited; + kj::Array params; + + BrandScope(kj::Own parent, uint64_t leafId, uint leafParamCount) + : errorReporter(parent->errorReporter), + parent(kj::mv(parent)), leafId(leafId), leafParamCount(leafParamCount), + inherited(false) {} + BrandScope(BrandScope& base, kj::Array params) + : errorReporter(base.errorReporter), + leafId(base.leafId), leafParamCount(base.leafParamCount), + inherited(false), params(kj::mv(params)) { + KJ_IF_MAYBE(p, base.parent) { + parent = kj::addRef(**p); + } + } + BrandScope(ErrorReporter& errorReporter, uint64_t scopeId) + : errorReporter(errorReporter), leafId(scopeId), leafParamCount(0), inherited(false) {} + + kj::Own evaluateBrand( + Resolver& resolver, Resolver::ResolvedDecl decl, + List::Reader brand, uint index = 0); + + BrandedDecl decompileType(Resolver& resolver, schema::Type::Reader type); + + template + friend kj::Own kj::refcounted(Params&&... params); + friend class BrandedDecl; +}; + +template +uint64_t BrandedDecl::getIdAndFillBrand(InitBrandFunc&& initBrand) { + KJ_REQUIRE(body.is()); + + brand->compile(kj::fwd(initBrand)); + return body.get().id; +} + +template +void BrandScope::compile(InitBrandFunc&& initBrand) { + kj::Vector levels; + BrandScope* ptr = this; + for (;;) { + if (ptr->params.size() > 0 || (ptr->inherited && ptr->leafParamCount > 0)) { + levels.add(ptr); + } + KJ_IF_MAYBE(p, ptr->parent) { + ptr = *p; + } else { + break; + } + } + + if (levels.size() > 0) { + auto scopes = initBrand().initScopes(levels.size()); + for (uint i: kj::indices(levels)) { + auto scope = scopes[i]; + scope.setScopeId(levels[i]->leafId); + + if (levels[i]->inherited) { + scope.setInherit(); + } else { + auto bindings = scope.initBind(levels[i]->params.size()); + for (uint j: kj::indices(bindings)) { + levels[i]->params[j].compileAsType(errorReporter, bindings[j].initType()); + } + } + } + } +} + +} // namespace compiler +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/grammar.capnp b/c++/src/capnp/compiler/grammar.capnp index 209a327901..4434f2d944 100644 --- a/c++/src/capnp/compiler/grammar.capnp +++ b/c++/src/capnp/compiler/grammar.capnp @@ -244,6 +244,9 @@ struct Declaration { type @1 :Expression; # Specified some other struct type instead of a named list. + + stream @4 :Void; + # The keyword "stream". } startByte @2 :UInt32; diff --git a/c++/src/capnp/compiler/grammar.capnp.c++ b/c++/src/capnp/compiler/grammar.capnp.c++ index f84b43703f..f433a1cf14 100644 --- a/c++/src/capnp/compiler/grammar.capnp.c++ +++ b/c++/src/capnp/compiler/grammar.capnp.c++ @@ -79,7 +79,7 @@ static const uint16_t m_e75816b56529d464[] = {2, 1, 0}; static const uint16_t i_e75816b56529d464[] = {0, 1, 2}; const ::capnp::_::RawSchema s_e75816b56529d464 = { 0xe75816b56529d464, b_e75816b56529d464.words, 66, nullptr, m_e75816b56529d464, - 0, 3, i_e75816b56529d464, nullptr, nullptr, { &s_e75816b56529d464, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_e75816b56529d464, nullptr, nullptr, { &s_e75816b56529d464, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<66> b_991c7a3693d62cf2 = { @@ -156,7 +156,7 @@ static const uint16_t m_991c7a3693d62cf2[] = {2, 1, 0}; static const uint16_t i_991c7a3693d62cf2[] = {0, 1, 2}; const ::capnp::_::RawSchema s_991c7a3693d62cf2 = { 0x991c7a3693d62cf2, b_991c7a3693d62cf2.words, 66, nullptr, m_991c7a3693d62cf2, - 0, 3, i_991c7a3693d62cf2, nullptr, nullptr, { &s_991c7a3693d62cf2, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_991c7a3693d62cf2, nullptr, nullptr, { &s_991c7a3693d62cf2, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<66> b_90f2a60678fd2367 = { @@ -233,7 +233,7 @@ static const uint16_t m_90f2a60678fd2367[] = {2, 1, 0}; static const uint16_t i_90f2a60678fd2367[] = {0, 1, 2}; const ::capnp::_::RawSchema s_90f2a60678fd2367 = { 0x90f2a60678fd2367, b_90f2a60678fd2367.words, 66, nullptr, m_90f2a60678fd2367, - 0, 3, i_90f2a60678fd2367, nullptr, nullptr, { &s_90f2a60678fd2367, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_90f2a60678fd2367, nullptr, nullptr, { &s_90f2a60678fd2367, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<262> b_8e207d4dfe54d0de = { @@ -513,7 +513,7 @@ static const uint16_t m_8e207d4dfe54d0de[] = {13, 11, 10, 15, 9, 3, 14, 6, 12, 2 static const uint16_t i_8e207d4dfe54d0de[] = {0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 8, 9}; const ::capnp::_::RawSchema s_8e207d4dfe54d0de = { 0x8e207d4dfe54d0de, b_8e207d4dfe54d0de.words, 262, d_8e207d4dfe54d0de, m_8e207d4dfe54d0de, - 5, 16, i_8e207d4dfe54d0de, nullptr, nullptr, { &s_8e207d4dfe54d0de, nullptr, nullptr, 0, 0, nullptr } + 5, 16, i_8e207d4dfe54d0de, nullptr, nullptr, { &s_8e207d4dfe54d0de, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_c90246b71adedbaa = { @@ -593,7 +593,7 @@ static const uint16_t m_c90246b71adedbaa[] = {1, 0, 2}; static const uint16_t i_c90246b71adedbaa[] = {0, 1, 2}; const ::capnp::_::RawSchema s_c90246b71adedbaa = { 0xc90246b71adedbaa, b_c90246b71adedbaa.words, 65, d_c90246b71adedbaa, m_c90246b71adedbaa, - 2, 3, i_c90246b71adedbaa, nullptr, nullptr, { &s_c90246b71adedbaa, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_c90246b71adedbaa, nullptr, nullptr, { &s_c90246b71adedbaa, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<55> b_aee8397040b0df7a = { @@ -663,7 +663,7 @@ static const uint16_t m_aee8397040b0df7a[] = {0, 1}; static const uint16_t i_aee8397040b0df7a[] = {0, 1}; const ::capnp::_::RawSchema s_aee8397040b0df7a = { 0xaee8397040b0df7a, b_aee8397040b0df7a.words, 55, d_aee8397040b0df7a, m_aee8397040b0df7a, - 2, 2, i_aee8397040b0df7a, nullptr, nullptr, { &s_aee8397040b0df7a, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_aee8397040b0df7a, nullptr, nullptr, { &s_aee8397040b0df7a, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_aa28e1400d793359 = { @@ -727,7 +727,7 @@ static const uint16_t m_aa28e1400d793359[] = {1, 0}; static const uint16_t i_aa28e1400d793359[] = {0, 1}; const ::capnp::_::RawSchema s_aa28e1400d793359 = { 0xaa28e1400d793359, b_aa28e1400d793359.words, 49, d_aa28e1400d793359, m_aa28e1400d793359, - 2, 2, i_aa28e1400d793359, nullptr, nullptr, { &s_aa28e1400d793359, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_aa28e1400d793359, nullptr, nullptr, { &s_aa28e1400d793359, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<677> b_96efe787c17e83bb = { @@ -1429,7 +1429,7 @@ static const uint16_t m_96efe787c17e83bb[] = {18, 3, 40, 37, 39, 22, 41, 34, 31, static const uint16_t i_96efe787c17e83bb[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 0, 1, 2, 3, 4, 5, 6, 38}; const ::capnp::_::RawSchema s_96efe787c17e83bb = { 0x96efe787c17e83bb, b_96efe787c17e83bb.words, 677, d_96efe787c17e83bb, m_96efe787c17e83bb, - 12, 42, i_96efe787c17e83bb, nullptr, nullptr, { &s_96efe787c17e83bb, nullptr, nullptr, 0, 0, nullptr } + 12, 42, i_96efe787c17e83bb, nullptr, nullptr, { &s_96efe787c17e83bb, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<67> b_d5e71144af1ce175 = { @@ -1507,7 +1507,7 @@ static const uint16_t m_d5e71144af1ce175[] = {2, 0, 1}; static const uint16_t i_d5e71144af1ce175[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d5e71144af1ce175 = { 0xd5e71144af1ce175, b_d5e71144af1ce175.words, 67, nullptr, m_d5e71144af1ce175, - 0, 3, i_d5e71144af1ce175, nullptr, nullptr, { &s_d5e71144af1ce175, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_d5e71144af1ce175, nullptr, nullptr, { &s_d5e71144af1ce175, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<45> b_d00489d473826290 = { @@ -1567,7 +1567,7 @@ static const uint16_t m_d00489d473826290[] = {0, 1}; static const uint16_t i_d00489d473826290[] = {0, 1}; const ::capnp::_::RawSchema s_d00489d473826290 = { 0xd00489d473826290, b_d00489d473826290.words, 45, d_d00489d473826290, m_d00489d473826290, - 2, 2, i_d00489d473826290, nullptr, nullptr, { &s_d00489d473826290, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_d00489d473826290, nullptr, nullptr, { &s_d00489d473826290, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<53> b_fb5aeed95cdf6af9 = { @@ -1635,7 +1635,7 @@ static const uint16_t m_fb5aeed95cdf6af9[] = {1, 0}; static const uint16_t i_fb5aeed95cdf6af9[] = {0, 1}; const ::capnp::_::RawSchema s_fb5aeed95cdf6af9 = { 0xfb5aeed95cdf6af9, b_fb5aeed95cdf6af9.words, 53, d_fb5aeed95cdf6af9, m_fb5aeed95cdf6af9, - 2, 2, i_fb5aeed95cdf6af9, nullptr, nullptr, { &s_fb5aeed95cdf6af9, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_fb5aeed95cdf6af9, nullptr, nullptr, { &s_fb5aeed95cdf6af9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<28> b_94099c3f9eb32d6b = { @@ -1672,20 +1672,20 @@ static const ::capnp::_::AlignedData<28> b_94099c3f9eb32d6b = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_94099c3f9eb32d6b = { 0x94099c3f9eb32d6b, b_94099c3f9eb32d6b.words, 28, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_94099c3f9eb32d6b, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_94099c3f9eb32d6b, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<87> b_b3f66e7a79d81bcd = { +static const ::capnp::_::AlignedData<102> b_b3f66e7a79d81bcd = { { 0, 0, 0, 0, 5, 0, 6, 0, 205, 27, 216, 121, 122, 110, 246, 179, 41, 0, 0, 0, 1, 0, 2, 0, 187, 131, 126, 193, 135, 231, 239, 150, - 1, 0, 7, 0, 0, 0, 2, 0, + 1, 0, 7, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 154, 1, 0, 0, 45, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 41, 0, 0, 0, 231, 0, 0, 0, + 41, 0, 0, 0, 31, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 99, 111, @@ -1696,35 +1696,42 @@ static const ::capnp::_::AlignedData<87> b_b3f66e7a79d81bcd = { 46, 80, 97, 114, 97, 109, 76, 105, 115, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 16, 0, 0, 0, 3, 0, 4, 0, + 20, 0, 0, 0, 3, 0, 4, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 97, 0, 0, 0, 82, 0, 0, 0, + 125, 0, 0, 0, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 96, 0, 0, 0, 3, 0, 1, 0, - 124, 0, 0, 0, 2, 0, 1, 0, + 124, 0, 0, 0, 3, 0, 1, 0, + 152, 0, 0, 0, 2, 0, 1, 0, 1, 0, 254, 255, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 121, 0, 0, 0, 42, 0, 0, 0, + 149, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 116, 0, 0, 0, 3, 0, 1, 0, - 128, 0, 0, 0, 2, 0, 1, 0, - 2, 0, 0, 0, 1, 0, 0, 0, + 144, 0, 0, 0, 3, 0, 1, 0, + 156, 0, 0, 0, 2, 0, 1, 0, + 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 125, 0, 0, 0, 82, 0, 0, 0, + 153, 0, 0, 0, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 124, 0, 0, 0, 3, 0, 1, 0, - 136, 0, 0, 0, 2, 0, 1, 0, - 3, 0, 0, 0, 2, 0, 0, 0, + 152, 0, 0, 0, 3, 0, 1, 0, + 164, 0, 0, 0, 2, 0, 1, 0, + 4, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 133, 0, 0, 0, 66, 0, 0, 0, + 161, 0, 0, 0, 66, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 156, 0, 0, 0, 3, 0, 1, 0, + 168, 0, 0, 0, 2, 0, 1, 0, + 2, 0, 253, 255, 0, 0, 0, 0, + 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 128, 0, 0, 0, 3, 0, 1, 0, - 140, 0, 0, 0, 2, 0, 1, 0, + 165, 0, 0, 0, 58, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 160, 0, 0, 0, 3, 0, 1, 0, + 172, 0, 0, 0, 2, 0, 1, 0, 110, 97, 109, 101, 100, 76, 105, 115, 116, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, @@ -1762,6 +1769,14 @@ static const ::capnp::_::AlignedData<87> b_b3f66e7a79d81bcd = { 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 115, 116, 114, 101, 97, 109, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; ::capnp::word const* const bp_b3f66e7a79d81bcd = b_b3f66e7a79d81bcd.words; @@ -1770,11 +1785,11 @@ static const ::capnp::_::RawSchema* const d_b3f66e7a79d81bcd[] = { &s_8e207d4dfe54d0de, &s_fffe08a9a697d2a5, }; -static const uint16_t m_b3f66e7a79d81bcd[] = {3, 0, 2, 1}; -static const uint16_t i_b3f66e7a79d81bcd[] = {0, 1, 2, 3}; +static const uint16_t m_b3f66e7a79d81bcd[] = {3, 0, 2, 4, 1}; +static const uint16_t i_b3f66e7a79d81bcd[] = {0, 1, 4, 2, 3}; const ::capnp::_::RawSchema s_b3f66e7a79d81bcd = { - 0xb3f66e7a79d81bcd, b_b3f66e7a79d81bcd.words, 87, d_b3f66e7a79d81bcd, m_b3f66e7a79d81bcd, - 2, 4, i_b3f66e7a79d81bcd, nullptr, nullptr, { &s_b3f66e7a79d81bcd, nullptr, nullptr, 0, 0, nullptr } + 0xb3f66e7a79d81bcd, b_b3f66e7a79d81bcd.words, 102, d_b3f66e7a79d81bcd, m_b3f66e7a79d81bcd, + 2, 5, i_b3f66e7a79d81bcd, nullptr, nullptr, { &s_b3f66e7a79d81bcd, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<110> b_fffe08a9a697d2a5 = { @@ -1901,7 +1916,7 @@ static const uint16_t m_fffe08a9a697d2a5[] = {2, 3, 5, 0, 4, 1}; static const uint16_t i_fffe08a9a697d2a5[] = {0, 1, 2, 3, 4, 5}; const ::capnp::_::RawSchema s_fffe08a9a697d2a5 = { 0xfffe08a9a697d2a5, b_fffe08a9a697d2a5.words, 110, d_fffe08a9a697d2a5, m_fffe08a9a697d2a5, - 4, 6, i_fffe08a9a697d2a5, nullptr, nullptr, { &s_fffe08a9a697d2a5, nullptr, nullptr, 0, 0, nullptr } + 4, 6, i_fffe08a9a697d2a5, nullptr, nullptr, { &s_fffe08a9a697d2a5, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_e5104515fd88ea47 = { @@ -1967,7 +1982,7 @@ static const uint16_t m_e5104515fd88ea47[] = {0, 1}; static const uint16_t i_e5104515fd88ea47[] = {0, 1}; const ::capnp::_::RawSchema s_e5104515fd88ea47 = { 0xe5104515fd88ea47, b_e5104515fd88ea47.words, 51, d_e5104515fd88ea47, m_e5104515fd88ea47, - 2, 2, i_e5104515fd88ea47, nullptr, nullptr, { &s_e5104515fd88ea47, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_e5104515fd88ea47, nullptr, nullptr, { &s_e5104515fd88ea47, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_89f0c973c103ae96 = { @@ -2047,7 +2062,7 @@ static const uint16_t m_89f0c973c103ae96[] = {2, 1, 0}; static const uint16_t i_89f0c973c103ae96[] = {0, 1, 2}; const ::capnp::_::RawSchema s_89f0c973c103ae96 = { 0x89f0c973c103ae96, b_89f0c973c103ae96.words, 65, d_89f0c973c103ae96, m_89f0c973c103ae96, - 2, 3, i_89f0c973c103ae96, nullptr, nullptr, { &s_89f0c973c103ae96, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_89f0c973c103ae96, nullptr, nullptr, { &s_89f0c973c103ae96, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_e93164a80bfe2ccf = { @@ -2096,7 +2111,7 @@ static const uint16_t m_e93164a80bfe2ccf[] = {0}; static const uint16_t i_e93164a80bfe2ccf[] = {0}; const ::capnp::_::RawSchema s_e93164a80bfe2ccf = { 0xe93164a80bfe2ccf, b_e93164a80bfe2ccf.words, 34, d_e93164a80bfe2ccf, m_e93164a80bfe2ccf, - 2, 1, i_e93164a80bfe2ccf, nullptr, nullptr, { &s_e93164a80bfe2ccf, nullptr, nullptr, 0, 0, nullptr } + 2, 1, i_e93164a80bfe2ccf, nullptr, nullptr, { &s_e93164a80bfe2ccf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_b348322a8dcf0d0c = { @@ -2160,7 +2175,7 @@ static const uint16_t m_b348322a8dcf0d0c[] = {0, 1}; static const uint16_t i_b348322a8dcf0d0c[] = {0, 1}; const ::capnp::_::RawSchema s_b348322a8dcf0d0c = { 0xb348322a8dcf0d0c, b_b348322a8dcf0d0c.words, 49, d_b348322a8dcf0d0c, m_b348322a8dcf0d0c, - 2, 2, i_b348322a8dcf0d0c, nullptr, nullptr, { &s_b348322a8dcf0d0c, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_b348322a8dcf0d0c, nullptr, nullptr, { &s_b348322a8dcf0d0c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<43> b_8f2622208fb358c8 = { @@ -2219,7 +2234,7 @@ static const uint16_t m_8f2622208fb358c8[] = {1, 0}; static const uint16_t i_8f2622208fb358c8[] = {0, 1}; const ::capnp::_::RawSchema s_8f2622208fb358c8 = { 0x8f2622208fb358c8, b_8f2622208fb358c8.words, 43, d_8f2622208fb358c8, m_8f2622208fb358c8, - 3, 2, i_8f2622208fb358c8, nullptr, nullptr, { &s_8f2622208fb358c8, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_8f2622208fb358c8, nullptr, nullptr, { &s_8f2622208fb358c8, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_d0d1a21de617951f = { @@ -2285,7 +2300,7 @@ static const uint16_t m_d0d1a21de617951f[] = {0, 1}; static const uint16_t i_d0d1a21de617951f[] = {0, 1}; const ::capnp::_::RawSchema s_d0d1a21de617951f = { 0xd0d1a21de617951f, b_d0d1a21de617951f.words, 51, d_d0d1a21de617951f, m_d0d1a21de617951f, - 2, 2, i_d0d1a21de617951f, nullptr, nullptr, { &s_d0d1a21de617951f, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_d0d1a21de617951f, nullptr, nullptr, { &s_d0d1a21de617951f, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<40> b_992a90eaf30235d3 = { @@ -2340,7 +2355,7 @@ static const uint16_t m_992a90eaf30235d3[] = {0}; static const uint16_t i_992a90eaf30235d3[] = {0}; const ::capnp::_::RawSchema s_992a90eaf30235d3 = { 0x992a90eaf30235d3, b_992a90eaf30235d3.words, 40, d_992a90eaf30235d3, m_992a90eaf30235d3, - 2, 1, i_992a90eaf30235d3, nullptr, nullptr, { &s_992a90eaf30235d3, nullptr, nullptr, 0, 0, nullptr } + 2, 1, i_992a90eaf30235d3, nullptr, nullptr, { &s_992a90eaf30235d3, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<42> b_eb971847d617c0b9 = { @@ -2398,7 +2413,7 @@ static const uint16_t m_eb971847d617c0b9[] = {0, 1}; static const uint16_t i_eb971847d617c0b9[] = {0, 1}; const ::capnp::_::RawSchema s_eb971847d617c0b9 = { 0xeb971847d617c0b9, b_eb971847d617c0b9.words, 42, d_eb971847d617c0b9, m_eb971847d617c0b9, - 3, 2, i_eb971847d617c0b9, nullptr, nullptr, { &s_eb971847d617c0b9, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_eb971847d617c0b9, nullptr, nullptr, { &s_eb971847d617c0b9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_c6238c7d62d65173 = { @@ -2464,7 +2479,7 @@ static const uint16_t m_c6238c7d62d65173[] = {1, 0}; static const uint16_t i_c6238c7d62d65173[] = {0, 1}; const ::capnp::_::RawSchema s_c6238c7d62d65173 = { 0xc6238c7d62d65173, b_c6238c7d62d65173.words, 51, d_c6238c7d62d65173, m_c6238c7d62d65173, - 2, 2, i_c6238c7d62d65173, nullptr, nullptr, { &s_c6238c7d62d65173, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_c6238c7d62d65173, nullptr, nullptr, { &s_c6238c7d62d65173, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<230> b_9cb9e86e3198037f = { @@ -2709,7 +2724,7 @@ static const uint16_t m_9cb9e86e3198037f[] = {12, 2, 3, 4, 6, 1, 8, 9, 10, 11, 5 static const uint16_t i_9cb9e86e3198037f[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; const ::capnp::_::RawSchema s_9cb9e86e3198037f = { 0x9cb9e86e3198037f, b_9cb9e86e3198037f.words, 230, d_9cb9e86e3198037f, m_9cb9e86e3198037f, - 2, 13, i_9cb9e86e3198037f, nullptr, nullptr, { &s_9cb9e86e3198037f, nullptr, nullptr, 0, 0, nullptr } + 2, 13, i_9cb9e86e3198037f, nullptr, nullptr, { &s_9cb9e86e3198037f, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_84e4f3f5a807605c = { @@ -2757,7 +2772,7 @@ static const uint16_t m_84e4f3f5a807605c[] = {0}; static const uint16_t i_84e4f3f5a807605c[] = {0}; const ::capnp::_::RawSchema s_84e4f3f5a807605c = { 0x84e4f3f5a807605c, b_84e4f3f5a807605c.words, 34, d_84e4f3f5a807605c, m_84e4f3f5a807605c, - 1, 1, i_84e4f3f5a807605c, nullptr, nullptr, { &s_84e4f3f5a807605c, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_84e4f3f5a807605c, nullptr, nullptr, { &s_84e4f3f5a807605c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -2769,195 +2784,291 @@ namespace capnp { namespace compiler { // LocatedText +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LocatedText::_capnpPrivate::dataWordSize; constexpr uint16_t LocatedText::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LocatedText::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LocatedText::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LocatedInteger +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LocatedInteger::_capnpPrivate::dataWordSize; constexpr uint16_t LocatedInteger::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LocatedInteger::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LocatedInteger::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LocatedFloat +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LocatedFloat::_capnpPrivate::dataWordSize; constexpr uint16_t LocatedFloat::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LocatedFloat::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LocatedFloat::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression::Param +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::Param::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::Param::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::Param::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::Param::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression::Application +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::Application::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::Application::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::Application::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::Application::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Expression::Member +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Expression::Member::_capnpPrivate::dataWordSize; constexpr uint16_t Expression::Member::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Expression::Member::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Expression::Member::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::BrandParameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::BrandParameter::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::BrandParameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::BrandParameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::BrandParameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::AnnotationApplication +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::AnnotationApplication::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::AnnotationApplication::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::AnnotationApplication::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::AnnotationApplication::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::AnnotationApplication::Value +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::AnnotationApplication::Value::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::AnnotationApplication::Value::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::AnnotationApplication::Value::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::AnnotationApplication::Value::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::ParamList +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::ParamList::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::ParamList::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::ParamList::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::ParamList::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Param +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Param::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Param::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Param::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Param::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Param::DefaultValue +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Param::DefaultValue::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Param::DefaultValue::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Param::DefaultValue::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Param::DefaultValue::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Id +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Id::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Id::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Id::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Id::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Using +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Using::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Using::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Using::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Using::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Const +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Const::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Const::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Const::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Const::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Field +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Field::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Field::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Field::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Field::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Field::DefaultValue +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Field::DefaultValue::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Field::DefaultValue::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Field::DefaultValue::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Field::DefaultValue::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Interface +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Interface::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Interface::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Interface::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Interface::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Method +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Method::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Method::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Method::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Method::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Method::Results +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Method::Results::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Method::Results::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Method::Results::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Method::Results::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Declaration::Annotation +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Declaration::Annotation::_capnpPrivate::dataWordSize; constexpr uint16_t Declaration::Annotation::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Declaration::Annotation::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Declaration::Annotation::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ParsedFile +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ParsedFile::_capnpPrivate::dataWordSize; constexpr uint16_t ParsedFile::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ParsedFile::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ParsedFile::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/c++/src/capnp/compiler/grammar.capnp.h b/c++/src/capnp/compiler/grammar.capnp.h index f529404b67..2183e98f4e 100644 --- a/c++/src/capnp/compiler/grammar.capnp.h +++ b/c++/src/capnp/compiler/grammar.capnp.h @@ -1,16 +1,20 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: grammar.capnp -#ifndef CAPNP_INCLUDED_c56be168dcbbc3c6_ -#define CAPNP_INCLUDED_c56be168dcbbc3c6_ +#pragma once #include +#include -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { @@ -295,6 +299,7 @@ struct Declaration::ParamList { enum Which: uint16_t { NAMED_LIST, TYPE, + STREAM, }; struct _capnpPrivate { @@ -810,11 +815,11 @@ class Expression::Reader { inline bool isList() const; inline bool hasList() const; - inline ::capnp::List< ::capnp::compiler::Expression>::Reader getList() const; + inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader getList() const; inline bool isTuple() const; inline bool hasTuple() const; - inline ::capnp::List< ::capnp::compiler::Expression::Param>::Reader getTuple() const; + inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader getTuple() const; inline ::uint32_t getStartByte() const; @@ -905,19 +910,19 @@ class Expression::Builder { inline bool isList(); inline bool hasList(); - inline ::capnp::List< ::capnp::compiler::Expression>::Builder getList(); - inline void setList( ::capnp::List< ::capnp::compiler::Expression>::Reader value); - inline ::capnp::List< ::capnp::compiler::Expression>::Builder initList(unsigned int size); - inline void adoptList(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>> disownList(); + inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder getList(); + inline void setList( ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder initList(unsigned int size); + inline void adoptList(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>> disownList(); inline bool isTuple(); inline bool hasTuple(); - inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder getTuple(); - inline void setTuple( ::capnp::List< ::capnp::compiler::Expression::Param>::Reader value); - inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder initTuple(unsigned int size); - inline void adoptTuple(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>> disownTuple(); + inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder getTuple(); + inline void setTuple( ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder initTuple(unsigned int size); + inline void adoptTuple(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>> disownTuple(); inline ::uint32_t getStartByte(); inline void setStartByte( ::uint32_t value); @@ -1115,7 +1120,7 @@ class Expression::Application::Reader { inline ::capnp::compiler::Expression::Reader getFunction() const; inline bool hasParams() const; - inline ::capnp::List< ::capnp::compiler::Expression::Param>::Reader getParams() const; + inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader getParams() const; private: ::capnp::_::StructReader _reader; @@ -1153,11 +1158,11 @@ class Expression::Application::Builder { inline ::capnp::Orphan< ::capnp::compiler::Expression> disownFunction(); inline bool hasParams(); - inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder getParams(); - inline void setParams( ::capnp::List< ::capnp::compiler::Expression::Param>::Reader value); - inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder initParams(unsigned int size); - inline void adoptParams(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>> disownParams(); + inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder getParams(); + inline void setParams( ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder initParams(unsigned int size); + inline void adoptParams(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>> disownParams(); private: ::capnp::_::StructBuilder _builder; @@ -1303,10 +1308,10 @@ class Declaration::Reader { inline typename Id::Reader getId() const; inline bool hasNestedDecls() const; - inline ::capnp::List< ::capnp::compiler::Declaration>::Reader getNestedDecls() const; + inline ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Reader getNestedDecls() const; inline bool hasAnnotations() const; - inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader getAnnotations() const; + inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader getAnnotations() const; inline ::uint32_t getStartByte() const; @@ -1411,7 +1416,7 @@ class Declaration::Reader { inline ::capnp::Void getBuiltinAnyPointer() const; inline bool hasParameters() const; - inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Reader getParameters() const; + inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Reader getParameters() const; inline bool isBuiltinAnyStruct() const; inline ::capnp::Void getBuiltinAnyStruct() const; @@ -1462,18 +1467,18 @@ class Declaration::Builder { inline typename Id::Builder initId(); inline bool hasNestedDecls(); - inline ::capnp::List< ::capnp::compiler::Declaration>::Builder getNestedDecls(); - inline void setNestedDecls( ::capnp::List< ::capnp::compiler::Declaration>::Reader value); - inline ::capnp::List< ::capnp::compiler::Declaration>::Builder initNestedDecls(unsigned int size); - inline void adoptNestedDecls(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration>> disownNestedDecls(); + inline ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Builder getNestedDecls(); + inline void setNestedDecls( ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Builder initNestedDecls(unsigned int size); + inline void adoptNestedDecls(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>> disownNestedDecls(); inline bool hasAnnotations(); - inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder getAnnotations(); - inline void setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader value); - inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder initAnnotations(unsigned int size); - inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>> disownAnnotations(); + inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder getAnnotations(); + inline void setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder initAnnotations(unsigned int size); + inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>> disownAnnotations(); inline ::uint32_t getStartByte(); inline void setStartByte( ::uint32_t value); @@ -1621,11 +1626,11 @@ class Declaration::Builder { inline void setBuiltinAnyPointer( ::capnp::Void value = ::capnp::VOID); inline bool hasParameters(); - inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Builder getParameters(); - inline void setParameters( ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Reader value); - inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Builder initParameters(unsigned int size); - inline void adoptParameters(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>> disownParameters(); + inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Builder getParameters(); + inline void setParameters( ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Builder initParameters(unsigned int size); + inline void adoptParameters(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>> disownParameters(); inline bool isBuiltinAnyStruct(); inline ::capnp::Void getBuiltinAnyStruct(); @@ -1958,7 +1963,7 @@ class Declaration::ParamList::Reader { inline Which which() const; inline bool isNamedList() const; inline bool hasNamedList() const; - inline ::capnp::List< ::capnp::compiler::Declaration::Param>::Reader getNamedList() const; + inline ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Reader getNamedList() const; inline bool isType() const; inline bool hasType() const; @@ -1968,6 +1973,9 @@ class Declaration::ParamList::Reader { inline ::uint32_t getEndByte() const; + inline bool isStream() const; + inline ::capnp::Void getStream() const; + private: ::capnp::_::StructReader _reader; template @@ -1999,11 +2007,11 @@ class Declaration::ParamList::Builder { inline Which which(); inline bool isNamedList(); inline bool hasNamedList(); - inline ::capnp::List< ::capnp::compiler::Declaration::Param>::Builder getNamedList(); - inline void setNamedList( ::capnp::List< ::capnp::compiler::Declaration::Param>::Reader value); - inline ::capnp::List< ::capnp::compiler::Declaration::Param>::Builder initNamedList(unsigned int size); - inline void adoptNamedList(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param>> disownNamedList(); + inline ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Builder getNamedList(); + inline void setNamedList( ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Builder initNamedList(unsigned int size); + inline void adoptNamedList(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>> disownNamedList(); inline bool isType(); inline bool hasType(); @@ -2019,6 +2027,10 @@ class Declaration::ParamList::Builder { inline ::uint32_t getEndByte(); inline void setEndByte( ::uint32_t value); + inline bool isStream(); + inline ::capnp::Void getStream(); + inline void setStream( ::capnp::Void value = ::capnp::VOID); + private: ::capnp::_::StructBuilder _builder; template @@ -2069,7 +2081,7 @@ class Declaration::Param::Reader { inline ::capnp::compiler::Expression::Reader getType() const; inline bool hasAnnotations() const; - inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader getAnnotations() const; + inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader getAnnotations() const; inline typename DefaultValue::Reader getDefaultValue() const; @@ -2120,11 +2132,11 @@ class Declaration::Param::Builder { inline ::capnp::Orphan< ::capnp::compiler::Expression> disownType(); inline bool hasAnnotations(); - inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder getAnnotations(); - inline void setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader value); - inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder initAnnotations(unsigned int size); - inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>> disownAnnotations(); + inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder getAnnotations(); + inline void setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder initAnnotations(unsigned int size); + inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>> disownAnnotations(); inline typename DefaultValue::Builder getDefaultValue(); inline typename DefaultValue::Builder initDefaultValue(); @@ -2733,7 +2745,7 @@ class Declaration::Interface::Reader { #endif // !CAPNP_LITE inline bool hasSuperclasses() const; - inline ::capnp::List< ::capnp::compiler::Expression>::Reader getSuperclasses() const; + inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader getSuperclasses() const; private: ::capnp::_::StructReader _reader; @@ -2764,11 +2776,11 @@ class Declaration::Interface::Builder { #endif // !CAPNP_LITE inline bool hasSuperclasses(); - inline ::capnp::List< ::capnp::compiler::Expression>::Builder getSuperclasses(); - inline void setSuperclasses( ::capnp::List< ::capnp::compiler::Expression>::Reader value); - inline ::capnp::List< ::capnp::compiler::Expression>::Builder initSuperclasses(unsigned int size); - inline void adoptSuperclasses(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>> disownSuperclasses(); + inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder getSuperclasses(); + inline void setSuperclasses( ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder initSuperclasses(unsigned int size); + inline void adoptSuperclasses(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>> disownSuperclasses(); private: ::capnp::_::StructBuilder _builder; @@ -3585,41 +3597,41 @@ inline bool Expression::Builder::hasList() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Expression>::Reader Expression::Reader::getList() const { +inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader Expression::Reader::getList() const { KJ_IREQUIRE((which() == Expression::LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Expression>::Builder Expression::Builder::getList() { +inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder Expression::Builder::getList() { KJ_IREQUIRE((which() == Expression::LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Expression::Builder::setList( ::capnp::List< ::capnp::compiler::Expression>::Reader value) { +inline void Expression::Builder::setList( ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Expression::LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Expression>::Builder Expression::Builder::initList(unsigned int size) { +inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder Expression::Builder::initList(unsigned int size) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Expression::LIST); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Expression::Builder::adoptList( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>&& value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Expression::LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>> Expression::Builder::disownList() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>> Expression::Builder::disownList() { KJ_IREQUIRE((which() == Expression::LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -3639,41 +3651,41 @@ inline bool Expression::Builder::hasTuple() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Expression::Param>::Reader Expression::Reader::getTuple() const { +inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader Expression::Reader::getTuple() const { KJ_IREQUIRE((which() == Expression::TUPLE), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder Expression::Builder::getTuple() { +inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder Expression::Builder::getTuple() { KJ_IREQUIRE((which() == Expression::TUPLE), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Expression::Builder::setTuple( ::capnp::List< ::capnp::compiler::Expression::Param>::Reader value) { +inline void Expression::Builder::setTuple( ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Expression::TUPLE); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder Expression::Builder::initTuple(unsigned int size) { +inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder Expression::Builder::initTuple(unsigned int size) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Expression::TUPLE); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Expression::Builder::adoptTuple( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>&& value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Expression::TUPLE); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>> Expression::Builder::disownTuple() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>> Expression::Builder::disownTuple() { KJ_IREQUIRE((which() == Expression::TUPLE), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -4142,29 +4154,29 @@ inline bool Expression::Application::Builder::hasParams() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Expression::Param>::Reader Expression::Application::Reader::getParams() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader Expression::Application::Reader::getParams() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder Expression::Application::Builder::getParams() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder Expression::Application::Builder::getParams() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Expression::Application::Builder::setParams( ::capnp::List< ::capnp::compiler::Expression::Param>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::set(_builder.getPointerField( +inline void Expression::Application::Builder::setParams( ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Expression::Param>::Builder Expression::Application::Builder::initParams(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>::Builder Expression::Application::Builder::initParams(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Expression::Application::Builder::adoptParams( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param>> Expression::Application::Builder::disownParams() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>> Expression::Application::Builder::disownParams() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression::Param, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -4318,29 +4330,29 @@ inline bool Declaration::Builder::hasNestedDecls() { return !_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Declaration>::Reader Declaration::Reader::getNestedDecls() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Reader Declaration::Reader::getNestedDecls() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Declaration>::Builder Declaration::Builder::getNestedDecls() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Builder Declaration::Builder::getNestedDecls() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } -inline void Declaration::Builder::setNestedDecls( ::capnp::List< ::capnp::compiler::Declaration>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration>>::set(_builder.getPointerField( +inline void Declaration::Builder::setNestedDecls( ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Declaration>::Builder Declaration::Builder::initNestedDecls(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>::Builder Declaration::Builder::initNestedDecls(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), size); } inline void Declaration::Builder::adoptNestedDecls( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration>> Declaration::Builder::disownNestedDecls() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>> Declaration::Builder::disownNestedDecls() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } @@ -4352,29 +4364,29 @@ inline bool Declaration::Builder::hasAnnotations() { return !_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader Declaration::Reader::getAnnotations() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader Declaration::Reader::getAnnotations() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder Declaration::Builder::getAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder Declaration::Builder::getAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline void Declaration::Builder::setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::set(_builder.getPointerField( +inline void Declaration::Builder::setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder Declaration::Builder::initAnnotations(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder Declaration::Builder::initAnnotations(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), size); } inline void Declaration::Builder::adoptAnnotations( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>> Declaration::Builder::disownAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>> Declaration::Builder::disownAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } @@ -5303,29 +5315,29 @@ inline bool Declaration::Builder::hasParameters() { return !_builder.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Reader Declaration::Reader::getParameters() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Reader Declaration::Reader::getParameters() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Builder Declaration::Builder::getParameters() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Builder Declaration::Builder::getParameters() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS)); } -inline void Declaration::Builder::setParameters( ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>::set(_builder.getPointerField( +inline void Declaration::Builder::setParameters( ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>::Builder Declaration::Builder::initParameters(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>::Builder Declaration::Builder::initParameters(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS), size); } inline void Declaration::Builder::adoptParameters( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>> Declaration::Builder::disownParameters() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>> Declaration::Builder::disownParameters() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::BrandParameter, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<7>() * ::capnp::POINTERS)); } @@ -5638,41 +5650,41 @@ inline bool Declaration::ParamList::Builder::hasNamedList() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Declaration::Param>::Reader Declaration::ParamList::Reader::getNamedList() const { +inline ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Reader Declaration::ParamList::Reader::getNamedList() const { KJ_IREQUIRE((which() == Declaration::ParamList::NAMED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Declaration::Param>::Builder Declaration::ParamList::Builder::getNamedList() { +inline ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Builder Declaration::ParamList::Builder::getNamedList() { KJ_IREQUIRE((which() == Declaration::ParamList::NAMED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Declaration::ParamList::Builder::setNamedList( ::capnp::List< ::capnp::compiler::Declaration::Param>::Reader value) { +inline void Declaration::ParamList::Builder::setNamedList( ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Reader value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Declaration::ParamList::NAMED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Declaration::Param>::Builder Declaration::ParamList::Builder::initNamedList(unsigned int size) { +inline ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>::Builder Declaration::ParamList::Builder::initNamedList(unsigned int size) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Declaration::ParamList::NAMED_LIST); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Declaration::ParamList::Builder::adoptNamedList( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>&& value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Declaration::ParamList::NAMED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param>> Declaration::ParamList::Builder::disownNamedList() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>> Declaration::ParamList::Builder::disownNamedList() { KJ_IREQUIRE((which() == Declaration::ParamList::NAMED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::Param, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -5758,6 +5770,32 @@ inline void Declaration::ParamList::Builder::setEndByte( ::uint32_t value) { ::capnp::bounded<2>() * ::capnp::ELEMENTS, value); } +inline bool Declaration::ParamList::Reader::isStream() const { + return which() == Declaration::ParamList::STREAM; +} +inline bool Declaration::ParamList::Builder::isStream() { + return which() == Declaration::ParamList::STREAM; +} +inline ::capnp::Void Declaration::ParamList::Reader::getStream() const { + KJ_IREQUIRE((which() == Declaration::ParamList::STREAM), + "Must check which() before get()ing a union member."); + return _reader.getDataField< ::capnp::Void>( + ::capnp::bounded<0>() * ::capnp::ELEMENTS); +} + +inline ::capnp::Void Declaration::ParamList::Builder::getStream() { + KJ_IREQUIRE((which() == Declaration::ParamList::STREAM), + "Must check which() before get()ing a union member."); + return _builder.getDataField< ::capnp::Void>( + ::capnp::bounded<0>() * ::capnp::ELEMENTS); +} +inline void Declaration::ParamList::Builder::setStream( ::capnp::Void value) { + _builder.setDataField( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, Declaration::ParamList::STREAM); + _builder.setDataField< ::capnp::Void>( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, value); +} + inline bool Declaration::Param::Reader::hasName() const { return !_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); @@ -5844,29 +5882,29 @@ inline bool Declaration::Param::Builder::hasAnnotations() { return !_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader Declaration::Param::Reader::getAnnotations() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader Declaration::Param::Reader::getAnnotations() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder Declaration::Param::Builder::getAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder Declaration::Param::Builder::getAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } -inline void Declaration::Param::Builder::setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::set(_builder.getPointerField( +inline void Declaration::Param::Builder::setAnnotations( ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>::Builder Declaration::Param::Builder::initAnnotations(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>::Builder Declaration::Param::Builder::initAnnotations(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), size); } inline void Declaration::Param::Builder::adoptAnnotations( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>> Declaration::Param::Builder::disownAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>> Declaration::Param::Builder::disownAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Declaration::AnnotationApplication, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } @@ -6415,29 +6453,29 @@ inline bool Declaration::Interface::Builder::hasSuperclasses() { return !_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Expression>::Reader Declaration::Interface::Reader::getSuperclasses() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader Declaration::Interface::Reader::getSuperclasses() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Expression>::Builder Declaration::Interface::Builder::getSuperclasses() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder Declaration::Interface::Builder::getSuperclasses() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS)); } -inline void Declaration::Interface::Builder::setSuperclasses( ::capnp::List< ::capnp::compiler::Expression>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::set(_builder.getPointerField( +inline void Declaration::Interface::Builder::setSuperclasses( ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Expression>::Builder Declaration::Interface::Builder::initSuperclasses(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>::Builder Declaration::Interface::Builder::initSuperclasses(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS), size); } inline void Declaration::Interface::Builder::adoptSuperclasses( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression>> Declaration::Interface::Builder::disownSuperclasses() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>> Declaration::Interface::Builder::disownSuperclasses() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Expression, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS)); } @@ -6834,4 +6872,5 @@ inline ::capnp::Orphan< ::capnp::compiler::Declaration> ParsedFile::Builder::dis } // namespace } // namespace -#endif // CAPNP_INCLUDED_c56be168dcbbc3c6_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/compiler/lexer.c++ b/c++/src/capnp/compiler/lexer.c++ index 02a1dab572..22bd04d66d 100644 --- a/c++/src/capnp/compiler/lexer.c++ +++ b/c++/src/capnp/compiler/lexer.c++ @@ -186,6 +186,20 @@ Lexer::Lexer(Orphanage orphanageParam, ErrorReporter& errorReporter) initTok(t, loc).setStringLiteral(text); return t; }), + p::transformWithLocation( + sequence(p::exactChar<'`'>(), p::many(p::anyOfChars("\r\n").invert())), + [this](Location loc, kj::Array text) -> Orphan { + // Backtick-quoted line. Note that we assume either `\r` or `\n` is a valid line + // ending (to cover all known line ending formats) but we replace the line ending + // with `\n`. This way, changing the line endings of your source code doesn't affect + // the compiled code. + auto t = orphanage.newOrphan(); + // Append '\n' to the text. + auto out = initTok(t, loc).initStringLiteral(text.size() + 1); + memcpy(out.begin(), text.begin(), text.size()); + out[out.size() - 1] = '\n'; + return t; + }), p::transformWithLocation(p::doubleQuotedHexBinary, [this](Location loc, kj::Array data) -> Orphan { auto t = orphanage.newOrphan(); diff --git a/c++/src/capnp/compiler/lexer.capnp.c++ b/c++/src/capnp/compiler/lexer.capnp.c++ index 7cad320e27..316255fc91 100644 --- a/c++/src/capnp/compiler/lexer.capnp.c++ +++ b/c++/src/capnp/compiler/lexer.capnp.c++ @@ -211,7 +211,7 @@ static const uint16_t m_91cc55cd57de5419[] = {9, 6, 8, 3, 0, 2, 4, 5, 7, 1}; static const uint16_t i_91cc55cd57de5419[] = {0, 1, 2, 3, 4, 5, 6, 9, 7, 8}; const ::capnp::_::RawSchema s_91cc55cd57de5419 = { 0x91cc55cd57de5419, b_91cc55cd57de5419.words, 195, d_91cc55cd57de5419, m_91cc55cd57de5419, - 1, 10, i_91cc55cd57de5419, nullptr, nullptr, { &s_91cc55cd57de5419, nullptr, nullptr, 0, 0, nullptr } + 1, 10, i_91cc55cd57de5419, nullptr, nullptr, { &s_91cc55cd57de5419, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<119> b_c6725e678d60fa37 = { @@ -345,7 +345,7 @@ static const uint16_t m_c6725e678d60fa37[] = {2, 3, 5, 1, 4, 0}; static const uint16_t i_c6725e678d60fa37[] = {1, 2, 0, 3, 4, 5}; const ::capnp::_::RawSchema s_c6725e678d60fa37 = { 0xc6725e678d60fa37, b_c6725e678d60fa37.words, 119, d_c6725e678d60fa37, m_c6725e678d60fa37, - 2, 6, i_c6725e678d60fa37, nullptr, nullptr, { &s_c6725e678d60fa37, nullptr, nullptr, 0, 0, nullptr } + 2, 6, i_c6725e678d60fa37, nullptr, nullptr, { &s_c6725e678d60fa37, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<38> b_9e69a92512b19d18 = { @@ -397,7 +397,7 @@ static const uint16_t m_9e69a92512b19d18[] = {0}; static const uint16_t i_9e69a92512b19d18[] = {0}; const ::capnp::_::RawSchema s_9e69a92512b19d18 = { 0x9e69a92512b19d18, b_9e69a92512b19d18.words, 38, d_9e69a92512b19d18, m_9e69a92512b19d18, - 1, 1, i_9e69a92512b19d18, nullptr, nullptr, { &s_9e69a92512b19d18, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_9e69a92512b19d18, nullptr, nullptr, { &s_9e69a92512b19d18, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<40> b_a11f97b9d6c73dd4 = { @@ -451,7 +451,7 @@ static const uint16_t m_a11f97b9d6c73dd4[] = {0}; static const uint16_t i_a11f97b9d6c73dd4[] = {0}; const ::capnp::_::RawSchema s_a11f97b9d6c73dd4 = { 0xa11f97b9d6c73dd4, b_a11f97b9d6c73dd4.words, 40, d_a11f97b9d6c73dd4, m_a11f97b9d6c73dd4, - 1, 1, i_a11f97b9d6c73dd4, nullptr, nullptr, { &s_a11f97b9d6c73dd4, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_a11f97b9d6c73dd4, nullptr, nullptr, { &s_a11f97b9d6c73dd4, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -463,35 +463,51 @@ namespace capnp { namespace compiler { // Token +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Token::_capnpPrivate::dataWordSize; constexpr uint16_t Token::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Token::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Token::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Statement +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Statement::_capnpPrivate::dataWordSize; constexpr uint16_t Statement::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Statement::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Statement::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LexedTokens +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LexedTokens::_capnpPrivate::dataWordSize; constexpr uint16_t LexedTokens::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LexedTokens::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LexedTokens::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // LexedStatements +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t LexedStatements::_capnpPrivate::dataWordSize; constexpr uint16_t LexedStatements::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind LexedStatements::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* LexedStatements::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/c++/src/capnp/compiler/lexer.capnp.h b/c++/src/capnp/compiler/lexer.capnp.h index c9bd07e613..6dbd7621b7 100644 --- a/c++/src/capnp/compiler/lexer.capnp.h +++ b/c++/src/capnp/compiler/lexer.capnp.h @@ -1,16 +1,20 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: lexer.capnp -#ifndef CAPNP_INCLUDED_a73956d2621fc3ee_ -#define CAPNP_INCLUDED_a73956d2621fc3ee_ +#pragma once #include +#include -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { @@ -139,11 +143,11 @@ class Token::Reader { inline bool isParenthesizedList() const; inline bool hasParenthesizedList() const; - inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader getParenthesizedList() const; + inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader getParenthesizedList() const; inline bool isBracketedList() const; inline bool hasBracketedList() const; - inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader getBracketedList() const; + inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader getBracketedList() const; inline ::uint32_t getStartByte() const; @@ -216,21 +220,21 @@ class Token::Builder { inline bool isParenthesizedList(); inline bool hasParenthesizedList(); - inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder getParenthesizedList(); - inline void setParenthesizedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader value); - inline void setParenthesizedList(::kj::ArrayPtr::Reader> value); - inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder initParenthesizedList(unsigned int size); - inline void adoptParenthesizedList(::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>> disownParenthesizedList(); + inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder getParenthesizedList(); + inline void setParenthesizedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader value); + inline void setParenthesizedList(::kj::ArrayPtr::Reader> value); + inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder initParenthesizedList(unsigned int size); + inline void adoptParenthesizedList(::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>> disownParenthesizedList(); inline bool isBracketedList(); inline bool hasBracketedList(); - inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder getBracketedList(); - inline void setBracketedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader value); - inline void setBracketedList(::kj::ArrayPtr::Reader> value); - inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder initBracketedList(unsigned int size); - inline void adoptBracketedList(::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>> disownBracketedList(); + inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder getBracketedList(); + inline void setBracketedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader value); + inline void setBracketedList(::kj::ArrayPtr::Reader> value); + inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder initBracketedList(unsigned int size); + inline void adoptBracketedList(::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>> disownBracketedList(); inline ::uint32_t getStartByte(); inline void setStartByte( ::uint32_t value); @@ -291,14 +295,14 @@ class Statement::Reader { inline Which which() const; inline bool hasTokens() const; - inline ::capnp::List< ::capnp::compiler::Token>::Reader getTokens() const; + inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader getTokens() const; inline bool isLine() const; inline ::capnp::Void getLine() const; inline bool isBlock() const; inline bool hasBlock() const; - inline ::capnp::List< ::capnp::compiler::Statement>::Reader getBlock() const; + inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader getBlock() const; inline bool hasDocComment() const; inline ::capnp::Text::Reader getDocComment() const; @@ -337,11 +341,11 @@ class Statement::Builder { inline Which which(); inline bool hasTokens(); - inline ::capnp::List< ::capnp::compiler::Token>::Builder getTokens(); - inline void setTokens( ::capnp::List< ::capnp::compiler::Token>::Reader value); - inline ::capnp::List< ::capnp::compiler::Token>::Builder initTokens(unsigned int size); - inline void adoptTokens(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>> disownTokens(); + inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder getTokens(); + inline void setTokens( ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder initTokens(unsigned int size); + inline void adoptTokens(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>> disownTokens(); inline bool isLine(); inline ::capnp::Void getLine(); @@ -349,11 +353,11 @@ class Statement::Builder { inline bool isBlock(); inline bool hasBlock(); - inline ::capnp::List< ::capnp::compiler::Statement>::Builder getBlock(); - inline void setBlock( ::capnp::List< ::capnp::compiler::Statement>::Reader value); - inline ::capnp::List< ::capnp::compiler::Statement>::Builder initBlock(unsigned int size); - inline void adoptBlock(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>> disownBlock(); + inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder getBlock(); + inline void setBlock( ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder initBlock(unsigned int size); + inline void adoptBlock(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>> disownBlock(); inline bool hasDocComment(); inline ::capnp::Text::Builder getDocComment(); @@ -412,7 +416,7 @@ class LexedTokens::Reader { #endif // !CAPNP_LITE inline bool hasTokens() const; - inline ::capnp::List< ::capnp::compiler::Token>::Reader getTokens() const; + inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader getTokens() const; private: ::capnp::_::StructReader _reader; @@ -443,11 +447,11 @@ class LexedTokens::Builder { #endif // !CAPNP_LITE inline bool hasTokens(); - inline ::capnp::List< ::capnp::compiler::Token>::Builder getTokens(); - inline void setTokens( ::capnp::List< ::capnp::compiler::Token>::Reader value); - inline ::capnp::List< ::capnp::compiler::Token>::Builder initTokens(unsigned int size); - inline void adoptTokens(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>> disownTokens(); + inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder getTokens(); + inline void setTokens( ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder initTokens(unsigned int size); + inline void adoptTokens(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>> disownTokens(); private: ::capnp::_::StructBuilder _builder; @@ -493,7 +497,7 @@ class LexedStatements::Reader { #endif // !CAPNP_LITE inline bool hasStatements() const; - inline ::capnp::List< ::capnp::compiler::Statement>::Reader getStatements() const; + inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader getStatements() const; private: ::capnp::_::StructReader _reader; @@ -524,11 +528,11 @@ class LexedStatements::Builder { #endif // !CAPNP_LITE inline bool hasStatements(); - inline ::capnp::List< ::capnp::compiler::Statement>::Builder getStatements(); - inline void setStatements( ::capnp::List< ::capnp::compiler::Statement>::Reader value); - inline ::capnp::List< ::capnp::compiler::Statement>::Builder initStatements(unsigned int size); - inline void adoptStatements(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>> disownStatements(); + inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder getStatements(); + inline void setStatements( ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder initStatements(unsigned int size); + inline void adoptStatements(::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>> disownStatements(); private: ::capnp::_::StructBuilder _builder; @@ -797,47 +801,47 @@ inline bool Token::Builder::hasParenthesizedList() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader Token::Reader::getParenthesizedList() const { +inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader Token::Reader::getParenthesizedList() const { KJ_IREQUIRE((which() == Token::PARENTHESIZED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder Token::Builder::getParenthesizedList() { +inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder Token::Builder::getParenthesizedList() { KJ_IREQUIRE((which() == Token::PARENTHESIZED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Token::Builder::setParenthesizedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader value) { +inline void Token::Builder::setParenthesizedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::PARENTHESIZED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline void Token::Builder::setParenthesizedList(::kj::ArrayPtr::Reader> value) { +inline void Token::Builder::setParenthesizedList(::kj::ArrayPtr::Reader> value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::PARENTHESIZED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder Token::Builder::initParenthesizedList(unsigned int size) { +inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder Token::Builder::initParenthesizedList(unsigned int size) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::PARENTHESIZED_LIST); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Token::Builder::adoptParenthesizedList( - ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>&& value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::PARENTHESIZED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>> Token::Builder::disownParenthesizedList() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>> Token::Builder::disownParenthesizedList() { KJ_IREQUIRE((which() == Token::PARENTHESIZED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -857,47 +861,47 @@ inline bool Token::Builder::hasBracketedList() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader Token::Reader::getBracketedList() const { +inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader Token::Reader::getBracketedList() const { KJ_IREQUIRE((which() == Token::BRACKETED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder Token::Builder::getBracketedList() { +inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder Token::Builder::getBracketedList() { KJ_IREQUIRE((which() == Token::BRACKETED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Token::Builder::setBracketedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Reader value) { +inline void Token::Builder::setBracketedList( ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Reader value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::BRACKETED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline void Token::Builder::setBracketedList(::kj::ArrayPtr::Reader> value) { +inline void Token::Builder::setBracketedList(::kj::ArrayPtr::Reader> value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::BRACKETED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>::Builder Token::Builder::initBracketedList(unsigned int size) { +inline ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>::Builder Token::Builder::initBracketedList(unsigned int size) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::BRACKETED_LIST); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Token::Builder::adoptBracketedList( - ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>&& value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Token::BRACKETED_LIST); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>> Token::Builder::disownBracketedList() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>> Token::Builder::disownBracketedList() { KJ_IREQUIRE((which() == Token::BRACKETED_LIST), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token>>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>, ::capnp::Kind::LIST>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -1000,29 +1004,29 @@ inline bool Statement::Builder::hasTokens() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Token>::Reader Statement::Reader::getTokens() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader Statement::Reader::getTokens() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Token>::Builder Statement::Builder::getTokens() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder Statement::Builder::getTokens() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Statement::Builder::setTokens( ::capnp::List< ::capnp::compiler::Token>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::set(_builder.getPointerField( +inline void Statement::Builder::setTokens( ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Token>::Builder Statement::Builder::initTokens(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder Statement::Builder::initTokens(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Statement::Builder::adoptTokens( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>> Statement::Builder::disownTokens() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>> Statement::Builder::disownTokens() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -1068,41 +1072,41 @@ inline bool Statement::Builder::hasBlock() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Statement>::Reader Statement::Reader::getBlock() const { +inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader Statement::Reader::getBlock() const { KJ_IREQUIRE((which() == Statement::BLOCK), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Statement>::Builder Statement::Builder::getBlock() { +inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder Statement::Builder::getBlock() { KJ_IREQUIRE((which() == Statement::BLOCK), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Statement::Builder::setBlock( ::capnp::List< ::capnp::compiler::Statement>::Reader value) { +inline void Statement::Builder::setBlock( ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Statement::BLOCK); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Statement>::Builder Statement::Builder::initBlock(unsigned int size) { +inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder Statement::Builder::initBlock(unsigned int size) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Statement::BLOCK); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Statement::Builder::adoptBlock( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>&& value) { _builder.setDataField( ::capnp::bounded<0>() * ::capnp::ELEMENTS, Statement::BLOCK); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>> Statement::Builder::disownBlock() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>> Statement::Builder::disownBlock() { KJ_IREQUIRE((which() == Statement::BLOCK), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -1176,29 +1180,29 @@ inline bool LexedTokens::Builder::hasTokens() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Token>::Reader LexedTokens::Reader::getTokens() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader LexedTokens::Reader::getTokens() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Token>::Builder LexedTokens::Builder::getTokens() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder LexedTokens::Builder::getTokens() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void LexedTokens::Builder::setTokens( ::capnp::List< ::capnp::compiler::Token>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::set(_builder.getPointerField( +inline void LexedTokens::Builder::setTokens( ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Token>::Builder LexedTokens::Builder::initTokens(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>::Builder LexedTokens::Builder::initTokens(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void LexedTokens::Builder::adoptTokens( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token>> LexedTokens::Builder::disownTokens() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>> LexedTokens::Builder::disownTokens() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Token, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -1210,33 +1214,34 @@ inline bool LexedStatements::Builder::hasStatements() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::compiler::Statement>::Reader LexedStatements::Reader::getStatements() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader LexedStatements::Reader::getStatements() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::compiler::Statement>::Builder LexedStatements::Builder::getStatements() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder LexedStatements::Builder::getStatements() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void LexedStatements::Builder::setStatements( ::capnp::List< ::capnp::compiler::Statement>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::set(_builder.getPointerField( +inline void LexedStatements::Builder::setStatements( ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::compiler::Statement>::Builder LexedStatements::Builder::initStatements(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>::Builder LexedStatements::Builder::initStatements(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void LexedStatements::Builder::adoptStatements( - ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement>> LexedStatements::Builder::disownStatements() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>> LexedStatements::Builder::disownStatements() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::compiler::Statement, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } } // namespace } // namespace -#endif // CAPNP_INCLUDED_a73956d2621fc3ee_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/compiler/lexer.h b/c++/src/capnp/compiler/lexer.h index d18923d0e3..a1c05418b6 100644 --- a/c++/src/capnp/compiler/lexer.h +++ b/c++/src/capnp/compiler/lexer.h @@ -19,18 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_COMPILER_LEXER_H_ -#define CAPNP_COMPILER_LEXER_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include #include #include "error-reporter.h" +CAPNP_BEGIN_HEADER + namespace capnp { namespace compiler { @@ -99,4 +96,4 @@ class Lexer { } // namespace compiler } // namespace capnp -#endif // CAPNP_COMPILER_LEXER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/md5.h b/c++/src/capnp/compiler/md5.h deleted file mode 100644 index 3c19d4bdbd..0000000000 --- a/c++/src/capnp/compiler/md5.h +++ /dev/null @@ -1,80 +0,0 @@ -// This file was modified by Kenton Varda from code placed in the public domain. -// The code, which was originally C, was modified to give it a C++ interface. -// The original code bore the following notice: - -/* - * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc. - * MD5 Message-Digest Algorithm (RFC 1321). - * - * Homepage: - * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5 - * - * Author: - * Alexander Peslyak, better known as Solar Designer - * - * This software was written by Alexander Peslyak in 2001. No copyright is - * claimed, and the software is hereby placed in the public domain. - * In case this attempt to disclaim copyright and place the software in the - * public domain is deemed null and void, then the software is - * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the - * general public under the following terms: - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted. - * - * There's ABSOLUTELY NO WARRANTY, express or implied. - * - * See md5.c for more information. - */ - -// TODO(someday): Put in KJ? - -#ifndef CAPNP_COMPILER_MD5_H -#define CAPNP_COMPILER_MD5_H - -#include -#include - -namespace capnp { -namespace compiler { - -class Md5 { -public: - Md5(); - - void update(kj::ArrayPtr data); - inline void update(kj::ArrayPtr data) { - return update(data.asBytes()); - } - inline void update(kj::StringPtr data) { - return update(data.asArray()); - } - inline void update(const char* data) { - return update(kj::StringPtr(data)); - } - - kj::ArrayPtr finish(); - kj::StringPtr finishAsHex(); - -private: - /* Any 32-bit or wider unsigned integer data type will do */ - typedef unsigned int MD5_u32plus; - - bool finished = false; - - typedef struct { - MD5_u32plus lo, hi; - MD5_u32plus a, b, c, d; - kj::byte buffer[64]; - MD5_u32plus block[16]; - } MD5_CTX; - - MD5_CTX ctx; - - const kj::byte* body(const kj::byte* ptr, size_t size); -}; - -} // namespace compiler -} // namespace capnp - -#endif // CAPNP_COMPILER_MD5_H diff --git a/c++/src/capnp/compiler/module-loader.c++ b/c++/src/capnp/compiler/module-loader.c++ index f0d5b12b75..d72b0fce0c 100644 --- a/c++/src/capnp/compiler/module-loader.c++ +++ b/c++/src/capnp/compiler/module-loader.c++ @@ -27,227 +27,145 @@ #include #include #include -#include -#include -#include -#include -#include -#include - -#if _WIN32 -#include -#else -#include -#endif +#include namespace capnp { namespace compiler { namespace { -class MmapDisposer: public kj::ArrayDisposer { -protected: - void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, - size_t capacity, void (*destroyElement)(void*)) const { -#if _WIN32 - KJ_ASSERT(UnmapViewOfFile(firstElement)); -#else - munmap(firstElement, elementSize * elementCount); -#endif - } -}; - -KJ_CONSTEXPR(static const) MmapDisposer mmapDisposer = MmapDisposer(); - -kj::Array mmapForRead(kj::StringPtr filename) { - int fd; - // We already established that the file exists, so this should not fail. - KJ_SYSCALL(fd = open(filename.cStr(), O_RDONLY), filename); - kj::AutoCloseFd closer(fd); - - struct stat stats; - KJ_SYSCALL(fstat(fd, &stats)); - - if (S_ISREG(stats.st_mode)) { - if (stats.st_size == 0) { - // mmap()ing zero bytes will fail. - return nullptr; +struct FileKey { + // Key type for the modules map. We need to implement some complicated heuristics to detect when + // two files are actually the same underlying file on disk, in order to handle the case where + // people have mapped the same file into multiple locations in the import tree, whether by + // passing overlapping import paths, weird symlinks, or whatever. + // + // This is probably over-engineered. + + const kj::ReadableDirectory& baseDir; + kj::PathPtr path; + kj::Maybe file; + uint64_t hashCode; + uint64_t size; + kj::Date lastModified; + + FileKey(const kj::ReadableDirectory& baseDir, kj::PathPtr path) + : baseDir(baseDir), path(path), file(nullptr), + hashCode(0), size(0), lastModified(kj::UNIX_EPOCH) {} + FileKey(const kj::ReadableDirectory& baseDir, kj::PathPtr path, const kj::ReadableFile& file) + : FileKey(baseDir, path, file, file.stat()) {} + + FileKey(const kj::ReadableDirectory& baseDir, kj::PathPtr path, const kj::ReadableFile& file, + kj::FsNode::Metadata meta) + : baseDir(baseDir), path(path), file(&file), + hashCode(meta.hashCode), size(meta.size), lastModified(meta.lastModified) {} + + bool operator==(const FileKey& other) const { + // Allow matching on baseDir and path without a file. + if (&baseDir == &other.baseDir && path == other.path) return true; + if (file == nullptr || other.file == nullptr) return false; + + // Try comparing various file metadata to rule out obvious differences. + if (hashCode != other.hashCode) return false; + if (size != other.size || lastModified != other.lastModified) return false; + if (path.size() > 0 && other.path.size() > 0 && + path[path.size() - 1] != other.path[other.path.size() - 1]) { + // Names differ, so probably not the same file. + return false; } - // Regular file. Just mmap() it. -#if _WIN32 - HANDLE handle = reinterpret_cast(_get_osfhandle(fd)); - KJ_ASSERT(handle != INVALID_HANDLE_VALUE); - HANDLE mappingHandle = CreateFileMapping( - handle, NULL, PAGE_READONLY, 0, stats.st_size, NULL); - KJ_ASSERT(mappingHandle != INVALID_HANDLE_VALUE); - KJ_DEFER(KJ_ASSERT(CloseHandle(mappingHandle))); - const void* mapping = MapViewOfFile(mappingHandle, FILE_MAP_READ, 0, 0, stats.st_size); -#else // _WIN32 - const void* mapping = mmap(NULL, stats.st_size, PROT_READ, MAP_SHARED, fd, 0); - if (mapping == MAP_FAILED) { - KJ_FAIL_SYSCALL("mmap", errno, filename); + // Same file hash, but different paths, but same size and modification date. This could be a + // case of two different import paths overlapping and containing the same file. We'll need to + // check the content. + auto mapping1 = KJ_ASSERT_NONNULL(file).mmap(0, size); + auto mapping2 = KJ_ASSERT_NONNULL(other.file).mmap(0, size); + if (memcmp(mapping1.begin(), mapping2.begin(), size) != 0) return false; + + if (path == other.path) { + // Exactly the same content was mapped at exactly the same path relative to two different + // import directories. This can only really happen if this was one of the files passed on + // the command line, but its --src-prefix is not also an import path, but some other + // directory containing the same file was given as an import path. Whatever, we'll ignore + // this. + return true; } -#endif // _WIN32, else - return kj::Array( - reinterpret_cast(mapping), stats.st_size, mmapDisposer); - } else { - // This could be a stream of some sort, like a pipe. Fall back to read(). - // TODO(cleanup): This does a lot of copies. Not sure I care. - kj::Vector data(8192); - - byte buffer[4096]; - for (;;) { - kj::miniposix::ssize_t n; - KJ_SYSCALL(n = read(fd, buffer, sizeof(buffer))); - if (n == 0) break; - data.addAll(buffer, buffer + n); + // Exactly the same content! + static bool warned = false; + if (!warned) { + KJ_LOG(WARNING, + "Found exactly the same source file mapped at two different paths. This suggests " + "that your -I and --src-prefix flags are overlapping or inconsistent. Remember, these " + "flags should only specify directories that are logical 'roots' of the source tree. " + "It should never be the case that one of the import directories contains another one of " + "them.", + path, other.path); + warned = true; } - return data.releaseAsArray(); + return true; } -} - -static char* canonicalizePath(char* path) { - // Taken from some old C code of mine. - - // Preconditions: - // - path has already been determined to be relative, perhaps because the pointer actually points - // into the middle of some larger path string, in which case it must point to the character - // immediately after a '/'. - - // Invariants: - // - src points to the beginning of a path component. - // - dst points to the location where the path component should end up, if it is not special. - // - src == path or src[-1] == '/'. - // - dst == path or dst[-1] == '/'. - - char* src = path; - char* dst = path; - char* locked = dst; // dst cannot backtrack past this - char* partEnd; - bool hasMore; - - for (;;) { - while (*src == '/') { - // Skip duplicate slash. - ++src; - } - - partEnd = strchr(src, '/'); - hasMore = partEnd != NULL; - if (hasMore) { - *partEnd = '\0'; - } else { - partEnd = src + strlen(src); - } - - if (strcmp(src, ".") == 0) { - // Skip it. - } else if (strcmp(src, "..") == 0) { - if (dst > locked) { - // Backtrack over last path component. - --dst; - while (dst > locked && dst[-1] != '/') --dst; - } else { - locked += 3; - goto copy; - } - } else { - // Copy if needed. - copy: - if (dst < src) { - memmove(dst, src, partEnd - src); - dst += partEnd - src; - } else { - dst = partEnd; - } - *dst++ = '/'; - } +}; - if (hasMore) { - src = partEnd + 1; +struct FileKeyHash { + size_t operator()(const FileKey& key) const { + if (sizeof(size_t) < sizeof(key.hashCode)) { + // 32-bit system, do more mixing + return (key.hashCode >> 32) * 31 + static_cast(key.hashCode) + + key.size * 103 + (key.lastModified - kj::UNIX_EPOCH) / kj::MILLISECONDS * 73; } else { - // Oops, we have to remove the trailing '/'. - if (dst == path) { - // Oops, there is no trailing '/'. We have to return ".". - strcpy(path, "."); - return path + 1; - } else { - // Remove the trailing '/'. Note that this means that opening the file will work even - // if it is not a directory, where normally it should fail on non-directories when a - // trailing '/' is present. If this is a problem, we need to add some sort of special - // handling for this case where we stat() it separately to check if it is a directory, - // because Ekam findInput will not accept a trailing '/'. - --dst; - *dst = '\0'; - return dst; - } + return key.hashCode + key.size * 103 + + (key.lastModified - kj::UNIX_EPOCH) / kj::NANOSECONDS * 73ull; } } -} - -kj::String canonicalizePath(kj::StringPtr path) { - KJ_STACK_ARRAY(char, result, path.size() + 1, 128, 512); - strcpy(result.begin(), path.begin()); - - char* start = path.startsWith("/") ? result.begin() + 1 : result.begin(); - char* end = canonicalizePath(start); - return kj::heapString(result.slice(0, end - result.begin())); -} - -kj::String catPath(kj::StringPtr base, kj::StringPtr add) { - if (add.size() > 0 && add[0] == '/') { - return kj::heapString(add); - } - - const char* pos = base.end(); - while (pos > base.begin() && pos[-1] != '/') { - --pos; - } - - return kj::str(base.slice(0, pos - base.begin()), add); -} - -} // namespace +}; +}; class ModuleLoader::Impl { public: - Impl(GlobalErrorReporter& errorReporter): errorReporter(errorReporter) {} + Impl(GlobalErrorReporter& errorReporter) + : errorReporter(errorReporter) {} - void addImportPath(kj::String path) { - searchPath.add(kj::heapString(kj::mv(path))); + void addImportPath(const kj::ReadableDirectory& dir) { + searchPath.add(&dir); } - kj::Maybe loadModule(kj::StringPtr localName, kj::StringPtr sourceName); - kj::Maybe loadModuleFromSearchPath(kj::StringPtr sourceName); - kj::Maybe> readEmbed(kj::StringPtr localName, kj::StringPtr sourceName); - kj::Maybe> readEmbedFromSearchPath(kj::StringPtr sourceName); + kj::Maybe loadModule(const kj::ReadableDirectory& dir, kj::PathPtr path); + kj::Maybe loadModuleFromSearchPath(kj::PathPtr path); + kj::Maybe> readEmbed(const kj::ReadableDirectory& dir, kj::PathPtr path); + kj::Maybe> readEmbedFromSearchPath(kj::PathPtr path); GlobalErrorReporter& getErrorReporter() { return errorReporter; } + void setFileIdsRequired(bool value) { fileIdsRequired = value; } + bool areFileIdsRequired() { return fileIdsRequired; } + private: GlobalErrorReporter& errorReporter; - kj::Vector searchPath; - std::map> modules; + kj::Vector searchPath; + std::unordered_map, FileKeyHash> modules; + bool fileIdsRequired = true; }; class ModuleLoader::ModuleImpl final: public Module { public: - ModuleImpl(ModuleLoader::Impl& loader, kj::String localName, kj::String sourceName) - : loader(loader), localName(kj::mv(localName)), sourceName(kj::mv(sourceName)) {} + ModuleImpl(ModuleLoader::Impl& loader, kj::Own file, + const kj::ReadableDirectory& sourceDir, kj::Path pathParam) + : loader(loader), file(kj::mv(file)), sourceDir(sourceDir), path(kj::mv(pathParam)), + sourceNameStr(path.toString()) { + KJ_REQUIRE(path.size() > 0); + } - kj::StringPtr getLocalName() { - return localName; + kj::PathPtr getPath() { + return path; } kj::StringPtr getSourceName() override { - return sourceName; + return sourceNameStr; } Orphan loadContent(Orphanage orphanage) override { - kj::Array content = mmapForRead(localName).releaseAsChars(); + kj::Array content = file->mmap(0, file->stat().size).releaseAsChars(); lineBreaks = nullptr; // In case loadContent() is called multiple times. lineBreaks = lineBreaksSpace.construct(content); @@ -257,23 +175,23 @@ public: lex(content, statements, *this); auto parsed = orphanage.newOrphan(); - parseFile(statements.getStatements(), parsed.get(), *this); + parseFile(statements.getStatements(), parsed.get(), *this, loader.areFileIdsRequired()); return parsed; } kj::Maybe importRelative(kj::StringPtr importPath) override { if (importPath.size() > 0 && importPath[0] == '/') { - return loader.loadModuleFromSearchPath(importPath.slice(1)); + return loader.loadModuleFromSearchPath(kj::Path::parse(importPath.slice(1))); } else { - return loader.loadModule(catPath(localName, importPath), catPath(sourceName, importPath)); + return loader.loadModule(sourceDir, path.parent().eval(importPath)); } } kj::Maybe> embedRelative(kj::StringPtr embedPath) override { if (embedPath.size() > 0 && embedPath[0] == '/') { - return loader.readEmbedFromSearchPath(embedPath.slice(1)); + return loader.readEmbedFromSearchPath(kj::Path::parse(embedPath.slice(1))); } else { - return loader.readEmbed(catPath(localName, embedPath), catPath(sourceName, embedPath)); + return loader.readEmbed(sourceDir, path.parent().eval(embedPath)); } } @@ -281,8 +199,8 @@ public: auto& lines = *KJ_REQUIRE_NONNULL(lineBreaks, "Can't report errors until loadContent() is called."); - loader.getErrorReporter().addError( - localName, lines.toSourcePos(startByte), lines.toSourcePos(endByte), message); + loader.getErrorReporter().addError(sourceDir, path, + lines.toSourcePos(startByte), lines.toSourcePos(endByte), message); } bool hadErrors() override { @@ -291,8 +209,10 @@ public: private: ModuleLoader::Impl& loader; - kj::String localName; - kj::String sourceName; + kj::Own file; + const kj::ReadableDirectory& sourceDir; + kj::Path path; + kj::String sourceNameStr; kj::SpaceFor lineBreaksSpace; kj::Maybe> lineBreaks; @@ -301,35 +221,34 @@ private: // ======================================================================================= kj::Maybe ModuleLoader::Impl::loadModule( - kj::StringPtr localName, kj::StringPtr sourceName) { - kj::String canonicalLocalName = canonicalizePath(localName); - kj::String canonicalSourceName = canonicalizePath(sourceName); - - auto iter = modules.find(canonicalLocalName); + const kj::ReadableDirectory& dir, kj::PathPtr path) { + auto iter = modules.find(FileKey(dir, path)); if (iter != modules.end()) { // Return existing file. return *iter->second; } - if (access(canonicalLocalName.cStr(), F_OK) < 0) { + KJ_IF_MAYBE(file, dir.tryOpenFile(path)) { + auto pathCopy = path.clone(); + auto key = FileKey(dir, pathCopy, **file); + auto module = kj::heap(*this, kj::mv(*file), dir, kj::mv(pathCopy)); + auto& result = *module; + auto insertResult = modules.insert(std::make_pair(key, kj::mv(module))); + if (insertResult.second) { + return result; + } else { + // Now that we have the file open, we noticed a collision. Return the old file. + return *insertResult.first->second; + } + } else { // No such file. return nullptr; } - - auto module = kj::heap( - *this, kj::mv(canonicalLocalName), kj::mv(canonicalSourceName)); - auto& result = *module; - modules.insert(std::make_pair(result.getLocalName(), kj::mv(module))); - return result; } -kj::Maybe ModuleLoader::Impl::loadModuleFromSearchPath(kj::StringPtr sourceName) { - for (auto& search: searchPath) { - kj::String candidate = kj::str(search, "/", sourceName); - char* end = canonicalizePath(candidate.begin() + (candidate[0] == '/')); - - KJ_IF_MAYBE(module, loadModule( - kj::heapString(candidate.slice(0, end - candidate.begin())), sourceName)) { +kj::Maybe ModuleLoader::Impl::loadModuleFromSearchPath(kj::PathPtr path) { + for (auto candidate: searchPath) { + KJ_IF_MAYBE(module, loadModule(*candidate, path)) { return *module; } } @@ -337,26 +256,16 @@ kj::Maybe ModuleLoader::Impl::loadModuleFromSearchPath(kj::StringPtr so } kj::Maybe> ModuleLoader::Impl::readEmbed( - kj::StringPtr localName, kj::StringPtr sourceName) { - kj::String canonicalLocalName = canonicalizePath(localName); - kj::String canonicalSourceName = canonicalizePath(sourceName); - - if (access(canonicalLocalName.cStr(), F_OK) < 0) { - // No such file. - return nullptr; + const kj::ReadableDirectory& dir, kj::PathPtr path) { + KJ_IF_MAYBE(file, dir.tryOpenFile(path)) { + return file->get()->mmap(0, file->get()->stat().size); } - - return mmapForRead(localName); + return nullptr; } -kj::Maybe> ModuleLoader::Impl::readEmbedFromSearchPath( - kj::StringPtr sourceName) { - for (auto& search: searchPath) { - kj::String candidate = kj::str(search, "/", sourceName); - char* end = canonicalizePath(candidate.begin() + (candidate[0] == '/')); - - KJ_IF_MAYBE(module, readEmbed( - kj::heapString(candidate.slice(0, end - candidate.begin())), sourceName)) { +kj::Maybe> ModuleLoader::Impl::readEmbedFromSearchPath(kj::PathPtr path) { + for (auto candidate: searchPath) { + KJ_IF_MAYBE(module, readEmbed(*candidate, path)) { return kj::mv(*module); } } @@ -369,10 +278,16 @@ ModuleLoader::ModuleLoader(GlobalErrorReporter& errorReporter) : impl(kj::heap(errorReporter)) {} ModuleLoader::~ModuleLoader() noexcept(false) {} -void ModuleLoader::addImportPath(kj::String path) { impl->addImportPath(kj::mv(path)); } +void ModuleLoader::addImportPath(const kj::ReadableDirectory& dir) { + impl->addImportPath(dir); +} + +kj::Maybe ModuleLoader::loadModule(const kj::ReadableDirectory& dir, kj::PathPtr path) { + return impl->loadModule(dir, path); +} -kj::Maybe ModuleLoader::loadModule(kj::StringPtr localName, kj::StringPtr sourceName) { - return impl->loadModule(localName, sourceName); +void ModuleLoader::setFileIdsRequired(bool value) { + return impl->setFileIdsRequired(value); } } // namespace compiler diff --git a/c++/src/capnp/compiler/module-loader.h b/c++/src/capnp/compiler/module-loader.h index 86e6db247b..da0d6daf29 100644 --- a/c++/src/capnp/compiler/module-loader.h +++ b/c++/src/capnp/compiler/module-loader.h @@ -19,18 +19,16 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_COMPILER_MODULE_LOADER_H_ -#define CAPNP_COMPILER_MODULE_LOADER_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "compiler.h" #include "error-reporter.h" #include #include #include +#include + +CAPNP_BEGIN_HEADER namespace capnp { namespace compiler { @@ -40,17 +38,20 @@ class ModuleLoader { explicit ModuleLoader(GlobalErrorReporter& errorReporter); // Create a ModuleLoader that reports error messages to the given reporter. - KJ_DISALLOW_COPY(ModuleLoader); + KJ_DISALLOW_COPY_AND_MOVE(ModuleLoader); ~ModuleLoader() noexcept(false); - void addImportPath(kj::String path); + void addImportPath(const kj::ReadableDirectory& dir); // Add a directory to the list of paths that is searched for imports that start with a '/'. - kj::Maybe loadModule(kj::StringPtr localName, kj::StringPtr sourceName); - // Tries to load the module with the given filename. `localName` is the path to the file on - // disk (as you'd pass to open(2)), and `sourceName` is the canonical name it should be given - // in the schema (this is used e.g. to decide output file locations). Often, these are the same. + kj::Maybe loadModule(const kj::ReadableDirectory& dir, kj::PathPtr path); + // Tries to load a module with the given path inside the given directory. Returns nullptr if the + // file doesn't exist. + + void setFileIdsRequired(bool value); + // Same as SchemaParser::setFileIdsRequired(). If set false, files will not be required to have + // a top-level file ID; if missing a random one will be assigned. private: class Impl; @@ -62,4 +63,4 @@ class ModuleLoader { } // namespace compiler } // namespace capnp -#endif // CAPNP_COMPILER_MODULE_LOADER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/node-translator.c++ b/c++/src/capnp/compiler/node-translator.c++ index 7b123d13c6..797cfd7dd0 100644 --- a/c++/src/capnp/compiler/node-translator.c++ +++ b/c++/src/capnp/compiler/node-translator.c++ @@ -20,13 +20,15 @@ // THE SOFTWARE. #include "node-translator.h" -#include "parser.h" // only for generateGroupId() +#include "parser.h" // only for generateGroupId() and expressionString() #include #include #include +#include #include #include #include +#include namespace capnp { namespace compiler { @@ -108,7 +110,7 @@ public: // from the given offset. The idea is that you just allocated an lgSize-sized field from // an limitLgSize-sized space, such as a newly-added word on the end of the data segment. - KJ_DREQUIRE(limitLgSize <= kj::size(holes)); + KJ_ASSUME(limitLgSize <= kj::size(holes)); while (lgSize < limitLgSize) { KJ_DREQUIRE(holes[lgSize] == 0); @@ -128,6 +130,11 @@ public: // No expansion requested. return true; } + if (oldLgSize == kj::size(holes)) { + // Old value is already a full word. Further expansion is impossible. + return false; + } + KJ_ASSERT(oldLgSize < kj::size(holes)); if (holes[oldLgSize] != oldOffset + 1) { // The space immediately after the location is not a hole. return false; @@ -210,7 +217,7 @@ public: } Top() = default; - KJ_DISALLOW_COPY(Top); + KJ_DISALLOW_COPY_AND_MOVE(Top); }; struct Union { @@ -238,7 +245,7 @@ public: kj::Vector pointerLocations; inline Union(StructOrGroup& parent): parent(parent) {} - KJ_DISALLOW_COPY(Union); + KJ_DISALLOW_COPY_AND_MOVE(Union); uint addNewDataLocation(uint lgSize) { // Add a whole new data location to the union with the given size. @@ -422,11 +429,11 @@ public: // in cases involving unions nested in other unions. The bug could lead to multiple // fields in a group incorrectly being assigned overlapping offsets. Although the bug // is now fixed by adding the `newHoles` parameter, this silently breaks - // backwards-compatibilty with affected schemas. Therefore, for now, we throw an + // backwards-compatibility with affected schemas. Therefore, for now, we throw an // exception to alert developers of the problem. // // TODO(cleanup): Once sufficient time has elapsed, remove this assert. - KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/sandstorm-io/capnproto/issues/344"); + KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/capnproto/capnproto/issues/344"); } lgSizeUsed = desiredUsage; return true; @@ -445,7 +452,7 @@ public: bool hasMembers = false; inline Group(Union& parent): parent(parent) {} - KJ_DISALLOW_COPY(Group); + KJ_DISALLOW_COPY_AND_MOVE(Group); void addMember() { if (!hasMembers) { @@ -552,7 +559,7 @@ public: bool result = usage.tryExpand( *this, location, oldLgSize, localOldOffset, expansionFactor); if (mustFail && result) { - KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/sandstorm-io/capnproto/issues/344"); + KJ_FAIL_ASSERT("Bad news: Cap'n Proto 0.5.x and previous contained a bug which would cause this schema to be compiled incorrectly. Please see: https://github.com/capnproto/capnproto/issues/344"); } return result; } @@ -571,802 +578,6 @@ private: // ======================================================================================= -class NodeTranslator::BrandedDecl { - // Represents a declaration possibly with generic parameter bindings. - // - // TODO(cleaup): This is too complicated to live here. We should refactor this class and - // BrandScope out into their own file, independent of NodeTranslator. - -public: - inline BrandedDecl(Resolver::ResolvedDecl decl, - kj::Own&& brand, - Expression::Reader source) - : brand(kj::mv(brand)), source(source) { - body.init(kj::mv(decl)); - } - inline BrandedDecl(Resolver::ResolvedParameter variable, Expression::Reader source) - : source(source) { - body.init(kj::mv(variable)); - } - inline BrandedDecl(decltype(nullptr)) {} - - static BrandedDecl implicitMethodParam(uint index) { - // Get a BrandedDecl referring to an implicit method parameter. - // (As a hack, we internally represent this as a ResolvedParameter. Sorry.) - return BrandedDecl(Resolver::ResolvedParameter { 0, index }, Expression::Reader()); - } - - BrandedDecl(BrandedDecl& other); - BrandedDecl(BrandedDecl&& other) = default; - - BrandedDecl& operator=(BrandedDecl& other); - BrandedDecl& operator=(BrandedDecl&& other) = default; - - // TODO(cleanup): A lot of the methods below are actually only called within compileAsType(), - // which was originally a method on NodeTranslator, but now is a method here and thus doesn't - // need these to be public. We should privatize most of these. - - kj::Maybe applyParams(kj::Array params, Expression::Reader subSource); - // Treat the declaration as a generic and apply it to the given parameter list. - - kj::Maybe getMember(kj::StringPtr memberName, Expression::Reader subSource); - // Get a member of this declaration. - - kj::Maybe getKind(); - // Returns the kind of declaration, or null if this is an unbound generic variable. - - template - uint64_t getIdAndFillBrand(InitBrandFunc&& initBrand); - // Returns the type ID of this node. `initBrand` is a zero-arg functor which returns - // schema::Brand::Builder; this will be called if this decl has brand bindings, and - // the returned builder filled in to reflect those bindings. - // - // It is an error to call this when `getKind()` returns null. - - kj::Maybe getListParam(); - // Only if the kind is BUILTIN_LIST: Get the list's type parameter. - - Resolver::ResolvedParameter asVariable(); - // If this is an unbound generic variable (i.e. `getKind()` returns null), return information - // about the variable. - // - // It is an error to call this when `getKind()` does not return null. - - bool compileAsType(ErrorReporter& errorReporter, schema::Type::Builder target); - // Compile this decl to a schema::Type. - - inline void addError(ErrorReporter& errorReporter, kj::StringPtr message) { - errorReporter.addErrorOn(source, message); - } - - Resolver::ResolveResult asResolveResult(uint64_t scopeId, schema::Brand::Builder brandBuilder); - // Reverse this into a ResolveResult. If necessary, use `brandBuilder` to fill in - // ResolvedDecl.brand. - - kj::String toString(); - kj::String toDebugString(); - -private: - Resolver::ResolveResult body; - kj::Own brand; // null if parameter - Expression::Reader source; -}; - -class NodeTranslator::BrandScope: public kj::Refcounted { - // Tracks the brand parameter bindings affecting the current scope. For example, if we are - // interpreting the type expression "Foo(Text).Bar", we would start with the current scopes - // BrandScope, create a new child BrandScope representing "Foo", add the "(Text)" parameter - // bindings to it, then create a further child scope for "Bar". Thus the BrandScope for Bar - // knows that Foo's parameter list has been bound to "(Text)". - // - // TODO(cleanup): This is too complicated to live here. We should refactor this class and - // BrandedDecl out into their own file, independent of NodeTranslator. - -public: - BrandScope(ErrorReporter& errorReporter, uint64_t startingScopeId, - uint startingScopeParamCount, Resolver& startingScope) - : errorReporter(errorReporter), parent(nullptr), leafId(startingScopeId), - leafParamCount(startingScopeParamCount), inherited(true) { - // Create all lexical parent scopes, all with no brand bindings. - KJ_IF_MAYBE(p, startingScope.getParent()) { - parent = kj::refcounted( - errorReporter, p->id, p->genericParamCount, *p->resolver); - } - } - - bool isGeneric() { - if (leafParamCount > 0) return true; - - KJ_IF_MAYBE(p, parent) { - return p->get()->isGeneric(); - } else { - return false; - } - } - - kj::Own push(uint64_t typeId, uint paramCount) { - return kj::refcounted(kj::addRef(*this), typeId, paramCount); - } - - kj::Maybe> setParams( - kj::Array params, Declaration::Which genericType, Expression::Reader source) { - if (this->params.size() != 0) { - errorReporter.addErrorOn(source, "Double-application of generic parameters."); - return nullptr; - } else if (params.size() > leafParamCount) { - if (leafParamCount == 0) { - errorReporter.addErrorOn(source, "Declaration does not accept generic parameters."); - } else { - errorReporter.addErrorOn(source, "Too many generic parameters."); - } - return nullptr; - } else if (params.size() < leafParamCount) { - errorReporter.addErrorOn(source, "Not enough generic parameters."); - return nullptr; - } else { - if (genericType != Declaration::BUILTIN_LIST) { - for (auto& param: params) { - KJ_IF_MAYBE(kind, param.getKind()) { - switch (*kind) { - case Declaration::BUILTIN_LIST: - case Declaration::BUILTIN_TEXT: - case Declaration::BUILTIN_DATA: - case Declaration::BUILTIN_ANY_POINTER: - case Declaration::STRUCT: - case Declaration::INTERFACE: - break; - - default: - param.addError(errorReporter, - "Sorry, only pointer types can be used as generic parameters."); - break; - } - } - } - } - - return kj::refcounted(*this, kj::mv(params)); - } - } - - kj::Own pop(uint64_t newLeafId) { - if (leafId == newLeafId) { - return kj::addRef(*this); - } - KJ_IF_MAYBE(p, parent) { - return (*p)->pop(newLeafId); - } else { - // Looks like we're moving into a whole top-level scope. - return kj::refcounted(errorReporter, newLeafId); - } - } - - kj::Maybe lookupParameter(Resolver& resolver, uint64_t scopeId, uint index) { - // Returns null if the param should be inherited from the client scope. - - if (scopeId == leafId) { - if (index < params.size()) { - return params[index]; - } else if (inherited) { - return nullptr; - } else { - // Unbound and not inherited, so return AnyPointer. - auto decl = resolver.resolveBuiltin(Declaration::BUILTIN_ANY_POINTER); - return BrandedDecl(decl, - evaluateBrand(resolver, decl, List::Reader()), - Expression::Reader()); - } - } else KJ_IF_MAYBE(p, parent) { - return p->get()->lookupParameter(resolver, scopeId, index); - } else { - KJ_FAIL_REQUIRE("scope is not a parent"); - } - } - - kj::Maybe> getParams(uint64_t scopeId) { - // Returns null if params at the requested scope should be inherited from the client scope. - - if (scopeId == leafId) { - if (inherited) { - return nullptr; - } else { - return params.asPtr(); - } - } else KJ_IF_MAYBE(p, parent) { - return p->get()->getParams(scopeId); - } else { - KJ_FAIL_REQUIRE("scope is not a parent"); - } - } - - template - void compile(InitBrandFunc&& initBrand) { - kj::Vector levels; - BrandScope* ptr = this; - for (;;) { - if (ptr->params.size() > 0 || (ptr->inherited && ptr->leafParamCount > 0)) { - levels.add(ptr); - } - KJ_IF_MAYBE(p, ptr->parent) { - ptr = *p; - } else { - break; - } - } - - if (levels.size() > 0) { - auto scopes = initBrand().initScopes(levels.size()); - for (uint i: kj::indices(levels)) { - auto scope = scopes[i]; - scope.setScopeId(levels[i]->leafId); - - if (levels[i]->inherited) { - scope.setInherit(); - } else { - auto bindings = scope.initBind(levels[i]->params.size()); - for (uint j: kj::indices(bindings)) { - levels[i]->params[j].compileAsType(errorReporter, bindings[j].initType()); - } - } - } - } - } - - kj::Maybe compileDeclExpression( - Expression::Reader source, Resolver& resolver, - ImplicitParams implicitMethodParams); - - NodeTranslator::BrandedDecl interpretResolve( - Resolver& resolver, Resolver::ResolveResult& result, Expression::Reader source); - - kj::Own evaluateBrand( - Resolver& resolver, Resolver::ResolvedDecl decl, - List::Reader brand, uint index = 0); - - BrandedDecl decompileType(Resolver& resolver, schema::Type::Reader type); - - inline uint64_t getScopeId() { return leafId; } - -private: - ErrorReporter& errorReporter; - kj::Maybe> parent; - uint64_t leafId; // zero = this is the root - uint leafParamCount; // number of generic parameters on this leaf - bool inherited; - kj::Array params; - - BrandScope(kj::Own parent, uint64_t leafId, uint leafParamCount) - : errorReporter(parent->errorReporter), - parent(kj::mv(parent)), leafId(leafId), leafParamCount(leafParamCount), - inherited(false) {} - BrandScope(BrandScope& base, kj::Array params) - : errorReporter(base.errorReporter), - leafId(base.leafId), leafParamCount(base.leafParamCount), - inherited(false), params(kj::mv(params)) { - KJ_IF_MAYBE(p, base.parent) { - parent = kj::addRef(**p); - } - } - BrandScope(ErrorReporter& errorReporter, uint64_t scopeId) - : errorReporter(errorReporter), leafId(scopeId), leafParamCount(0), inherited(false) {} - - template - friend kj::Own kj::refcounted(Params&&... params); -}; - -NodeTranslator::BrandedDecl::BrandedDecl(BrandedDecl& other) - : body(other.body), - source(other.source) { - if (body.is()) { - brand = kj::addRef(*other.brand); - } -} - -NodeTranslator::BrandedDecl& NodeTranslator::BrandedDecl::operator=(BrandedDecl& other) { - body = other.body; - source = other.source; - if (body.is()) { - brand = kj::addRef(*other.brand); - } - return *this; -} - -kj::Maybe NodeTranslator::BrandedDecl::applyParams( - kj::Array params, Expression::Reader subSource) { - if (body.is()) { - return nullptr; - } else { - return brand->setParams(kj::mv(params), body.get().kind, subSource) - .map([&](kj::Own&& scope) { - BrandedDecl result = *this; - result.brand = kj::mv(scope); - result.source = subSource; - return result; - }); - } -} - -kj::Maybe NodeTranslator::BrandedDecl::getMember( - kj::StringPtr memberName, Expression::Reader subSource) { - if (body.is()) { - return nullptr; - } else KJ_IF_MAYBE(r, body.get().resolver->resolveMember(memberName)) { - return brand->interpretResolve(*body.get().resolver, *r, subSource); - } else { - return nullptr; - } -} - -kj::Maybe NodeTranslator::BrandedDecl::getKind() { - if (body.is()) { - return nullptr; - } else { - return body.get().kind; - } -} - -template -uint64_t NodeTranslator::BrandedDecl::getIdAndFillBrand(InitBrandFunc&& initBrand) { - KJ_REQUIRE(body.is()); - - brand->compile(kj::fwd(initBrand)); - return body.get().id; -} - -kj::Maybe NodeTranslator::BrandedDecl::getListParam() { - KJ_REQUIRE(body.is()); - - auto& decl = body.get(); - KJ_REQUIRE(decl.kind == Declaration::BUILTIN_LIST); - - auto params = KJ_ASSERT_NONNULL(brand->getParams(decl.id)); - if (params.size() != 1) { - return nullptr; - } else { - return params[0]; - } -} - -NodeTranslator::Resolver::ResolvedParameter NodeTranslator::BrandedDecl::asVariable() { - KJ_REQUIRE(body.is()); - - return body.get(); -} - -bool NodeTranslator::BrandedDecl::compileAsType( - ErrorReporter& errorReporter, schema::Type::Builder target) { - KJ_IF_MAYBE(kind, getKind()) { - switch (*kind) { - case Declaration::ENUM: { - auto enum_ = target.initEnum(); - enum_.setTypeId(getIdAndFillBrand([&]() { return enum_.initBrand(); })); - return true; - } - - case Declaration::STRUCT: { - auto struct_ = target.initStruct(); - struct_.setTypeId(getIdAndFillBrand([&]() { return struct_.initBrand(); })); - return true; - } - - case Declaration::INTERFACE: { - auto interface = target.initInterface(); - interface.setTypeId(getIdAndFillBrand([&]() { return interface.initBrand(); })); - return true; - } - - case Declaration::BUILTIN_LIST: { - auto elementType = target.initList().initElementType(); - - KJ_IF_MAYBE(param, getListParam()) { - if (!param->compileAsType(errorReporter, elementType)) { - return false; - } - } else { - addError(errorReporter, "'List' requires exactly one parameter."); - return false; - } - - if (elementType.isAnyPointer()) { - addError(errorReporter, "'List(AnyPointer)' is not supported."); - // Seeing List(AnyPointer) later can mess things up, so change the type to Void. - elementType.setVoid(); - return false; - } - - return true; - } - - case Declaration::BUILTIN_VOID: target.setVoid(); return true; - case Declaration::BUILTIN_BOOL: target.setBool(); return true; - case Declaration::BUILTIN_INT8: target.setInt8(); return true; - case Declaration::BUILTIN_INT16: target.setInt16(); return true; - case Declaration::BUILTIN_INT32: target.setInt32(); return true; - case Declaration::BUILTIN_INT64: target.setInt64(); return true; - case Declaration::BUILTIN_U_INT8: target.setUint8(); return true; - case Declaration::BUILTIN_U_INT16: target.setUint16(); return true; - case Declaration::BUILTIN_U_INT32: target.setUint32(); return true; - case Declaration::BUILTIN_U_INT64: target.setUint64(); return true; - case Declaration::BUILTIN_FLOAT32: target.setFloat32(); return true; - case Declaration::BUILTIN_FLOAT64: target.setFloat64(); return true; - case Declaration::BUILTIN_TEXT: target.setText(); return true; - case Declaration::BUILTIN_DATA: target.setData(); return true; - - case Declaration::BUILTIN_OBJECT: - addError(errorReporter, - "As of Cap'n Proto 0.4, 'Object' has been renamed to 'AnyPointer'. Sorry for the " - "inconvenience, and thanks for being an early adopter. :)"); - // no break - case Declaration::BUILTIN_ANY_POINTER: - target.initAnyPointer().initUnconstrained().setAnyKind(); - return true; - case Declaration::BUILTIN_ANY_STRUCT: - target.initAnyPointer().initUnconstrained().setStruct(); - return true; - case Declaration::BUILTIN_ANY_LIST: - target.initAnyPointer().initUnconstrained().setList(); - return true; - case Declaration::BUILTIN_CAPABILITY: - target.initAnyPointer().initUnconstrained().setCapability(); - return true; - - case Declaration::FILE: - case Declaration::USING: - case Declaration::CONST: - case Declaration::ENUMERANT: - case Declaration::FIELD: - case Declaration::UNION: - case Declaration::GROUP: - case Declaration::METHOD: - case Declaration::ANNOTATION: - case Declaration::NAKED_ID: - case Declaration::NAKED_ANNOTATION: - addError(errorReporter, kj::str("'", toString(), "' is not a type.")); - return false; - } - - KJ_UNREACHABLE; - } else { - // Oh, this is a type variable. - auto var = asVariable(); - if (var.id == 0) { - // This is actually a method implicit parameter. - auto builder = target.initAnyPointer().initImplicitMethodParameter(); - builder.setParameterIndex(var.index); - return true; - } else { - auto builder = target.initAnyPointer().initParameter(); - builder.setScopeId(var.id); - builder.setParameterIndex(var.index); - return true; - } - } -} - -NodeTranslator::Resolver::ResolveResult NodeTranslator::BrandedDecl::asResolveResult( - uint64_t scopeId, schema::Brand::Builder brandBuilder) { - auto result = body; - if (result.is()) { - // May need to compile our context as the "brand". - - result.get().scopeId = scopeId; - - getIdAndFillBrand([&]() { - result.get().brand = brandBuilder.asReader(); - return brandBuilder; - }); - } - return result; -} - -static kj::String expressionString(Expression::Reader name); // defined later - -kj::String NodeTranslator::BrandedDecl::toString() { - return expressionString(source); -} - -kj::String NodeTranslator::BrandedDecl::toDebugString() { - if (body.is()) { - auto variable = body.get(); - return kj::str("varibale(", variable.id, ", ", variable.index, ")"); - } else { - auto decl = body.get(); - return kj::str("decl(", decl.id, ", ", (uint)decl.kind, "')"); - } -} - -NodeTranslator::BrandedDecl NodeTranslator::BrandScope::interpretResolve( - Resolver& resolver, Resolver::ResolveResult& result, Expression::Reader source) { - if (result.is()) { - auto& decl = result.get(); - - auto scope = pop(decl.scopeId); - KJ_IF_MAYBE(brand, decl.brand) { - scope = scope->evaluateBrand(resolver, decl, brand->getScopes()); - } else { - scope = scope->push(decl.id, decl.genericParamCount); - } - - return BrandedDecl(decl, kj::mv(scope), source); - } else { - auto& param = result.get(); - KJ_IF_MAYBE(p, lookupParameter(resolver, param.id, param.index)) { - return *p; - } else { - return BrandedDecl(param, source); - } - } -} - -kj::Own NodeTranslator::BrandScope::evaluateBrand( - Resolver& resolver, Resolver::ResolvedDecl decl, - List::Reader brand, uint index) { - auto result = kj::refcounted(errorReporter, decl.id); - result->leafParamCount = decl.genericParamCount; - - // Fill in `params`. - if (index < brand.size()) { - auto nextScope = brand[index]; - if (decl.id == nextScope.getScopeId()) { - // Initialize our parameters. - - switch (nextScope.which()) { - case schema::Brand::Scope::BIND: { - auto bindings = nextScope.getBind(); - auto params = kj::heapArrayBuilder(bindings.size()); - for (auto binding: bindings) { - switch (binding.which()) { - case schema::Brand::Binding::UNBOUND: { - // Build an AnyPointer-equivalent. - auto anyPointerDecl = resolver.resolveBuiltin(Declaration::BUILTIN_ANY_POINTER); - params.add(BrandedDecl(anyPointerDecl, - kj::refcounted(errorReporter, anyPointerDecl.scopeId), - Expression::Reader())); - break; - } - - case schema::Brand::Binding::TYPE: - // Reverse this schema::Type back into a BrandedDecl. - params.add(decompileType(resolver, binding.getType())); - break; - } - } - result->params = params.finish(); - break; - } - - case schema::Brand::Scope::INHERIT: - KJ_IF_MAYBE(p, getParams(decl.id)) { - result->params = kj::heapArray(*p); - } else { - result->inherited = true; - } - break; - } - - // Parent should start one level deeper in the list. - ++index; - } - } - - // Fill in `parent`. - KJ_IF_MAYBE(parent, decl.resolver->getParent()) { - result->parent = evaluateBrand(resolver, *parent, brand, index); - } - - return result; -} - -NodeTranslator::BrandedDecl NodeTranslator::BrandScope::decompileType( - Resolver& resolver, schema::Type::Reader type) { - auto builtin = [&](Declaration::Which which) -> BrandedDecl { - auto decl = resolver.resolveBuiltin(which); - return BrandedDecl(decl, - evaluateBrand(resolver, decl, List::Reader()), - Expression::Reader()); - }; - - switch (type.which()) { - case schema::Type::VOID: return builtin(Declaration::BUILTIN_VOID); - case schema::Type::BOOL: return builtin(Declaration::BUILTIN_BOOL); - case schema::Type::INT8: return builtin(Declaration::BUILTIN_INT8); - case schema::Type::INT16: return builtin(Declaration::BUILTIN_INT16); - case schema::Type::INT32: return builtin(Declaration::BUILTIN_INT32); - case schema::Type::INT64: return builtin(Declaration::BUILTIN_INT64); - case schema::Type::UINT8: return builtin(Declaration::BUILTIN_U_INT8); - case schema::Type::UINT16: return builtin(Declaration::BUILTIN_U_INT16); - case schema::Type::UINT32: return builtin(Declaration::BUILTIN_U_INT32); - case schema::Type::UINT64: return builtin(Declaration::BUILTIN_U_INT64); - case schema::Type::FLOAT32: return builtin(Declaration::BUILTIN_FLOAT32); - case schema::Type::FLOAT64: return builtin(Declaration::BUILTIN_FLOAT64); - case schema::Type::TEXT: return builtin(Declaration::BUILTIN_TEXT); - case schema::Type::DATA: return builtin(Declaration::BUILTIN_DATA); - - case schema::Type::ENUM: { - auto enumType = type.getEnum(); - Resolver::ResolvedDecl decl = resolver.resolveId(enumType.getTypeId()); - return BrandedDecl(decl, - evaluateBrand(resolver, decl, enumType.getBrand().getScopes()), - Expression::Reader()); - } - - case schema::Type::INTERFACE: { - auto interfaceType = type.getInterface(); - Resolver::ResolvedDecl decl = resolver.resolveId(interfaceType.getTypeId()); - return BrandedDecl(decl, - evaluateBrand(resolver, decl, interfaceType.getBrand().getScopes()), - Expression::Reader()); - } - - case schema::Type::STRUCT: { - auto structType = type.getStruct(); - Resolver::ResolvedDecl decl = resolver.resolveId(structType.getTypeId()); - return BrandedDecl(decl, - evaluateBrand(resolver, decl, structType.getBrand().getScopes()), - Expression::Reader()); - } - - case schema::Type::LIST: { - auto elementType = decompileType(resolver, type.getList().getElementType()); - return KJ_ASSERT_NONNULL(builtin(Declaration::BUILTIN_LIST) - .applyParams(kj::heapArray(&elementType, 1), Expression::Reader())); - } - - case schema::Type::ANY_POINTER: { - auto anyPointer = type.getAnyPointer(); - switch (anyPointer.which()) { - case schema::Type::AnyPointer::UNCONSTRAINED: - return builtin(Declaration::BUILTIN_ANY_POINTER); - - case schema::Type::AnyPointer::PARAMETER: { - auto param = anyPointer.getParameter(); - auto id = param.getScopeId(); - uint index = param.getParameterIndex(); - KJ_IF_MAYBE(binding, lookupParameter(resolver, id, index)) { - return *binding; - } else { - return BrandedDecl(Resolver::ResolvedParameter {id, index}, Expression::Reader()); - } - } - - case schema::Type::AnyPointer::IMPLICIT_METHOD_PARAMETER: - KJ_FAIL_ASSERT("Alias pointed to implicit method type parameter?"); - } - - KJ_UNREACHABLE; - } - } - - KJ_UNREACHABLE; -} - -kj::Maybe NodeTranslator::BrandScope::compileDeclExpression( - Expression::Reader source, Resolver& resolver, - ImplicitParams implicitMethodParams) { - switch (source.which()) { - case Expression::UNKNOWN: - // Error reported earlier. - return nullptr; - - case Expression::POSITIVE_INT: - case Expression::NEGATIVE_INT: - case Expression::FLOAT: - case Expression::STRING: - case Expression::BINARY: - case Expression::LIST: - case Expression::TUPLE: - case Expression::EMBED: - errorReporter.addErrorOn(source, "Expected name."); - return nullptr; - - case Expression::RELATIVE_NAME: { - auto name = source.getRelativeName(); - auto nameValue = name.getValue(); - - // Check implicit method params first. - for (auto i: kj::indices(implicitMethodParams.params)) { - if (implicitMethodParams.params[i].getName() == nameValue) { - if (implicitMethodParams.scopeId == 0) { - return BrandedDecl::implicitMethodParam(i); - } else { - return BrandedDecl(Resolver::ResolvedParameter { - implicitMethodParams.scopeId, static_cast(i) }, - Expression::Reader()); - } - } - } - - KJ_IF_MAYBE(r, resolver.resolve(nameValue)) { - return interpretResolve(resolver, *r, source); - } else { - errorReporter.addErrorOn(name, kj::str("Not defined: ", nameValue)); - return nullptr; - } - } - - case Expression::ABSOLUTE_NAME: { - auto name = source.getAbsoluteName(); - KJ_IF_MAYBE(r, resolver.getTopScope().resolver->resolveMember(name.getValue())) { - return interpretResolve(resolver, *r, source); - } else { - errorReporter.addErrorOn(name, kj::str("Not defined: ", name.getValue())); - return nullptr; - } - } - - case Expression::IMPORT: { - auto filename = source.getImport(); - KJ_IF_MAYBE(decl, resolver.resolveImport(filename.getValue())) { - // Import is always a root scope, so create a fresh BrandScope. - return BrandedDecl(*decl, kj::refcounted( - errorReporter, decl->id, decl->genericParamCount, *decl->resolver), source); - } else { - errorReporter.addErrorOn(filename, kj::str("Import failed: ", filename.getValue())); - return nullptr; - } - } - - case Expression::APPLICATION: { - auto app = source.getApplication(); - KJ_IF_MAYBE(decl, compileDeclExpression(app.getFunction(), resolver, implicitMethodParams)) { - // Compile all params. - auto params = app.getParams(); - auto compiledParams = kj::heapArrayBuilder(params.size()); - bool paramFailed = false; - for (auto param: params) { - if (param.isNamed()) { - errorReporter.addErrorOn(param.getNamed(), "Named parameter not allowed here."); - } - - KJ_IF_MAYBE(d, compileDeclExpression(param.getValue(), resolver, implicitMethodParams)) { - compiledParams.add(kj::mv(*d)); - } else { - // Param failed to compile. Error was already reported. - paramFailed = true; - } - }; - - if (paramFailed) { - return kj::mv(*decl); - } - - // Add the parameters to the brand. - KJ_IF_MAYBE(applied, decl->applyParams(compiledParams.finish(), source)) { - return kj::mv(*applied); - } else { - // Error already reported. Ignore parameters. - return kj::mv(*decl); - } - } else { - // error already reported - return nullptr; - } - } - - case Expression::MEMBER: { - auto member = source.getMember(); - KJ_IF_MAYBE(decl, compileDeclExpression(member.getParent(), resolver, implicitMethodParams)) { - auto name = member.getName(); - KJ_IF_MAYBE(memberDecl, decl->getMember(name.getValue(), source)) { - return kj::mv(*memberDecl); - } else { - errorReporter.addErrorOn(name, kj::str( - "'", expressionString(member.getParent()), - "' has no member named '", name.getValue(), "'")); - return nullptr; - } - } else { - // error already reported - return nullptr; - } - } - } - - KJ_UNREACHABLE; -} - -// ======================================================================================= - NodeTranslator::NodeTranslator( Resolver& resolver, ErrorReporter& errorReporter, const Declaration::Reader& decl, Orphan wipNodeParam, @@ -1377,33 +588,47 @@ NodeTranslator::NodeTranslator( localBrand(kj::refcounted( errorReporter, wipNodeParam.getReader().getId(), decl.getParameters().size(), resolver)), - wipNode(kj::mv(wipNodeParam)) { + wipNode(kj::mv(wipNodeParam)), + sourceInfo(orphanage.newOrphan()) { compileNode(decl, wipNode.get()); } NodeTranslator::~NodeTranslator() noexcept(false) {} NodeTranslator::NodeSet NodeTranslator::getBootstrapNode() { + auto sourceInfos = kj::heapArrayBuilder( + 1 + groups.size() + paramStructs.size()); + sourceInfos.add(sourceInfo.getReader()); + for (auto& group: groups) { + sourceInfos.add(group.sourceInfo.getReader()); + } + for (auto& paramStruct: paramStructs) { + sourceInfos.add(paramStruct.sourceInfo.getReader()); + } + auto nodeReader = wipNode.getReader(); if (nodeReader.isInterface()) { return NodeSet { nodeReader, - KJ_MAP(g, paramStructs) { return g.getReader(); } + KJ_MAP(g, paramStructs) { return g.node.getReader(); }, + sourceInfos.finish() }; } else { return NodeSet { nodeReader, - KJ_MAP(g, groups) { return g.getReader(); } + KJ_MAP(g, groups) { return g.node.getReader(); }, + sourceInfos.finish() }; } } -NodeTranslator::NodeSet NodeTranslator::finish() { +NodeTranslator::NodeSet NodeTranslator::finish(Schema selfBootstrapSchema) { // Careful about iteration here: compileFinalValue() may actually add more elements to // `unfinishedValues`, invalidating iterators in the process. for (size_t i = 0; i < unfinishedValues.size(); i++) { auto& value = unfinishedValues[i]; - compileValue(value.source, value.type, value.typeScope, value.target, false); + compileValue(value.source, value.type, value.typeScope.orDefault(selfBootstrapSchema), + value.target, false); } return getBootstrapNode(); @@ -1467,10 +692,15 @@ void NodeTranslator::compileNode(Declaration::Reader decl, schema::Node::Builder } builder.adoptAnnotations(compileAnnotationApplications(decl.getAnnotations(), targetsFlagName)); + + auto di = sourceInfo.get(); + di.setId(wipNode.getReader().getId()); + if (decl.hasDocComment()) { + di.setDocComment(decl.getDocComment()); + } } static kj::StringPtr getExpressionTargetName(Expression::Reader exp) { - kj::StringPtr targetName; switch (exp.which()) { case Expression::ABSOLUTE_NAME: return exp.getAbsoluteName().getValue(); @@ -1628,14 +858,14 @@ void NodeTranslator::DuplicateNameDetector::check( void NodeTranslator::compileConst(Declaration::Const::Reader decl, schema::Node::Const::Builder builder) { auto typeBuilder = builder.initType(); - if (compileType(decl.getType(), typeBuilder, noImplicitParams())) { + if (compileType(decl.getType(), typeBuilder, ImplicitParams::none())) { compileBootstrapValue(decl.getValue(), typeBuilder.asReader(), builder.initValue()); } } void NodeTranslator::compileAnnotation(Declaration::Annotation::Reader decl, schema::Node::Annotation::Builder builder) { - compileType(decl.getType(), builder.initType(), noImplicitParams()); + compileType(decl.getType(), builder.initType(), ImplicitParams::none()); // Dynamically copy over the values of all of the "targets" members. DynamicStruct::Reader src = decl; @@ -1695,6 +925,7 @@ void NodeTranslator::compileEnum(Void decl, } auto list = builder.initEnum().initEnumerants(enumerants.size()); + auto sourceInfoList = sourceInfo.get().initMembers(enumerants.size()); uint i = 0; DuplicateOrdinalDetector dupDetector(errorReporter); @@ -1704,6 +935,10 @@ void NodeTranslator::compileEnum(Void decl, dupDetector.check(enumerantDecl.getId().getOrdinal()); + if (enumerantDecl.hasDocComment()) { + sourceInfoList[i].setDocComment(enumerantDecl.getDocComment()); + } + auto enumerantBuilder = list[i++]; enumerantBuilder.setName(enumerantDecl.getName().getValue()); enumerantBuilder.setCodeOrder(codeOrder); @@ -1719,18 +954,20 @@ public: explicit StructTranslator(NodeTranslator& translator, ImplicitParams implicitMethodParams) : translator(translator), errorReporter(translator.errorReporter), implicitMethodParams(implicitMethodParams) {} - KJ_DISALLOW_COPY(StructTranslator); + KJ_DISALLOW_COPY_AND_MOVE(StructTranslator); - void translate(Void decl, List::Reader members, schema::Node::Builder builder) { + void translate(Void decl, List::Reader members, schema::Node::Builder builder, + schema::Node::SourceInfo::Builder sourceInfo) { // Build the member-info-by-ordinal map. - MemberInfo root(builder); + MemberInfo root(builder, sourceInfo); traverseTopOrGroup(members, root, layout.getTop()); translateInternal(root, builder); } - void translate(List::Reader params, schema::Node::Builder builder) { + void translate(List::Reader params, schema::Node::Builder builder, + schema::Node::SourceInfo::Builder sourceInfo) { // Build a struct from a method param / result list. - MemberInfo root(builder); + MemberInfo root(builder, sourceInfo); traverseParams(params, root, layout.getTop()); translateInternal(root, builder); } @@ -1742,6 +979,16 @@ private: StructLayout layout; kj::Arena arena; + struct NodeSourceInfoBuilderPair { + schema::Node::Builder node; + schema::Node::SourceInfo::Builder sourceInfo; + }; + + struct FieldSourceInfoBuilderPair { + schema::Field::Builder field; + schema::Node::SourceInfo::Member::Builder sourceInfo; + }; + struct MemberInfo { MemberInfo* parent; // The MemberInfo for the parent scope. @@ -1779,10 +1026,13 @@ private: // Information about the field declaration. We don't use Declaration::Reader because it might // have come from a Declaration::Param instead. + kj::Maybe docComment = nullptr; + kj::Maybe schema; // Schema for the field. Initialized when getSchema() is first called. schema::Node::Builder node; + schema::Node::SourceInfo::Builder sourceInfo; // If it's a group, or the top-level struct. union { @@ -1797,8 +1047,10 @@ private: // copy over the discriminant offset to the schema. }; - inline explicit MemberInfo(schema::Node::Builder node) - : parent(nullptr), codeOrder(0), isInUnion(false), node(node), unionScope(nullptr) {} + inline explicit MemberInfo(schema::Node::Builder node, + schema::Node::SourceInfo::Builder sourceInfo) + : parent(nullptr), codeOrder(0), isInUnion(false), node(node), sourceInfo(sourceInfo), + unionScope(nullptr) {} inline MemberInfo(MemberInfo& parent, uint codeOrder, const Declaration::Reader& decl, StructLayout::StructOrGroup& fieldScope, @@ -1807,7 +1059,7 @@ private: name(decl.getName().getValue()), declId(decl.getId()), declKind(Declaration::FIELD), declAnnotations(decl.getAnnotations()), startByte(decl.getStartByte()), endByte(decl.getEndByte()), - node(nullptr), fieldScope(&fieldScope) { + node(nullptr), sourceInfo(nullptr), fieldScope(&fieldScope) { KJ_REQUIRE(decl.which() == Declaration::FIELD); auto fieldDecl = decl.getField(); fieldType = fieldDecl.getType(); @@ -1815,6 +1067,9 @@ private: hasDefaultValue = true; fieldDefaultValue = fieldDecl.getDefaultValue().getValue(); } + if (decl.hasDocComment()) { + docComment = decl.getDocComment(); + } } inline MemberInfo(MemberInfo& parent, uint codeOrder, const Declaration::Param::Reader& decl, @@ -1824,7 +1079,7 @@ private: name(decl.getName().getValue()), declKind(Declaration::FIELD), isParam(true), declAnnotations(decl.getAnnotations()), startByte(decl.getStartByte()), endByte(decl.getEndByte()), - node(nullptr), fieldScope(&fieldScope) { + node(nullptr), sourceInfo(nullptr), fieldScope(&fieldScope) { fieldType = decl.getType(); if (decl.getDefaultValue().isValue()) { hasDefaultValue = true; @@ -1833,14 +1088,17 @@ private: } inline MemberInfo(MemberInfo& parent, uint codeOrder, const Declaration::Reader& decl, - schema::Node::Builder node, + NodeSourceInfoBuilderPair builderPair, bool isInUnion) : parent(&parent), codeOrder(codeOrder), isInUnion(isInUnion), name(decl.getName().getValue()), declId(decl.getId()), declKind(decl.which()), declAnnotations(decl.getAnnotations()), startByte(decl.getStartByte()), endByte(decl.getEndByte()), - node(node), unionScope(nullptr) { + node(builderPair.node), sourceInfo(builderPair.sourceInfo), unionScope(nullptr) { KJ_REQUIRE(decl.which() != Declaration::FIELD); + if (decl.hasDocComment()) { + docComment = decl.getDocComment(); + } } schema::Field::Builder getSchema() { @@ -1848,18 +1106,24 @@ private: return *result; } else { index = parent->childInitializedCount; - auto builder = parent->addMemberSchema(); + auto builderPair = parent->addMemberSchema(); + auto builder = builderPair.field; if (isInUnion) { builder.setDiscriminantValue(parent->unionDiscriminantCount++); } builder.setName(name); builder.setCodeOrder(codeOrder); + + KJ_IF_MAYBE(dc, docComment) { + builderPair.sourceInfo.setDocComment(*dc); + } + schema = builder; return builder; } } - schema::Field::Builder addMemberSchema() { + FieldSourceInfoBuilderPair addMemberSchema() { // Get the schema builder for the child member at the given index. This lazily/dynamically // builds the builder tree. @@ -1870,9 +1134,19 @@ private: if (parent != nullptr) { getSchema(); // Make sure field exists in parent once the first child is added. } - return structNode.initFields(childCount)[childInitializedCount++]; + FieldSourceInfoBuilderPair result { + structNode.initFields(childCount)[childInitializedCount], + sourceInfo.initMembers(childCount)[childInitializedCount] + }; + ++childInitializedCount; + return result; } else { - return structNode.getFields()[childInitializedCount++]; + FieldSourceInfoBuilderPair result { + structNode.getFields()[childInitializedCount], + sourceInfo.getMembers()[childInitializedCount] + }; + ++childInitializedCount; + return result; } } @@ -1889,6 +1163,11 @@ private: node.setId(groupId); node.setScopeId(parent->node.getId()); getSchema().initGroup().setTypeId(groupId); + + sourceInfo.setId(groupId); + KJ_IF_MAYBE(dc, docComment) { + sourceInfo.setDocComment(*dc); + } } } }; @@ -2061,9 +1340,13 @@ private: } } - schema::Node::Builder newGroupNode(schema::Node::Reader parent, kj::StringPtr name) { - auto orphan = translator.orphanage.newOrphan(); - auto node = orphan.get(); + NodeSourceInfoBuilderPair newGroupNode(schema::Node::Reader parent, kj::StringPtr name) { + AuxNode aux { + translator.orphanage.newOrphan(), + translator.orphanage.newOrphan() + }; + auto node = aux.node.get(); + auto sourceInfo = aux.sourceInfo.get(); // We'll set the ID and scope ID later. node.setDisplayName(kj::str(parent.getDisplayName(), '.', name)); @@ -2073,8 +1356,8 @@ private: // The remaining contents of node.struct will be filled in later. - translator.groups.add(kj::mv(orphan)); - return node; + translator.groups.add(kj::mv(aux)); + return { node, sourceInfo }; } void translateInternal(MemberInfo& root, schema::Node::Builder builder) { @@ -2086,7 +1369,7 @@ private: MemberInfo& member = *entry.second; // Make sure the exceptions added relating to - // https://github.com/sandstorm-io/capnproto/issues/344 identify the affected field. + // https://github.com/capnproto/capnproto/issues/344 identify the affected field. KJ_CONTEXT(member.name); if (member.declId.isOrdinal()) { @@ -2226,7 +1509,7 @@ private: structBuilder.setPreferredListEncoding(schema::ElementSize::INLINE_COMPOSITE); for (auto& group: translator.groups) { - auto groupBuilder = group.get().getStruct(); + auto groupBuilder = group.node.get().getStruct(); groupBuilder.setDataWordCount(structBuilder.getDataWordCount()); groupBuilder.setPointerCount(structBuilder.getPointerCount()); groupBuilder.setPreferredListEncoding(structBuilder.getPreferredListEncoding()); @@ -2236,13 +1519,12 @@ private: void NodeTranslator::compileStruct(Void decl, List::Reader members, schema::Node::Builder builder) { - StructTranslator(*this, noImplicitParams()).translate(decl, members, builder); + StructTranslator(*this, ImplicitParams::none()) + .translate(decl, members, builder, sourceInfo.get()); } // ------------------------------------------------------------------- -static kj::String expressionString(Expression::Reader name); - void NodeTranslator::compileInterface(Declaration::Interface::Reader decl, List::Reader members, schema::Node::Builder builder) { @@ -2253,7 +1535,7 @@ void NodeTranslator::compileInterface(Declaration::Interface::Reader decl, for (uint i: kj::indices(superclassesDecl)) { auto superclass = superclassesDecl[i]; - KJ_IF_MAYBE(decl, compileDeclExpression(superclass, noImplicitParams())) { + KJ_IF_MAYBE(decl, compileDeclExpression(superclass, ImplicitParams::none())) { KJ_IF_MAYBE(kind, decl->getKind()) { if (*kind == Declaration::INTERFACE) { auto s = superclassesBuilder[i]; @@ -2284,6 +1566,7 @@ void NodeTranslator::compileInterface(Declaration::Interface::Reader decl, } auto list = interfaceBuilder.initMethods(methods.size()); + auto sourceInfoList = sourceInfo.get().initMembers(methods.size()); uint i = 0; DuplicateOrdinalDetector dupDetector(errorReporter); @@ -2296,6 +1579,10 @@ void NodeTranslator::compileInterface(Declaration::Interface::Reader decl, dupDetector.check(ordinalDecl); uint16_t ordinal = ordinalDecl.getValue(); + if (methodDecl.hasDocComment()) { + sourceInfoList[i].setDocComment(methodDecl.getDocComment()); + } + auto methodBuilder = list[i++]; methodBuilder.setName(methodDecl.getName().getValue()); methodBuilder.setCodeOrder(codeOrder); @@ -2306,9 +1593,13 @@ void NodeTranslator::compileInterface(Declaration::Interface::Reader decl, implicitsBuilder[i].setName(implicits[i].getName()); } + auto params = methodReader.getParams(); + if (params.isStream()) { + errorReporter.addErrorOn(params, "'stream' can only appear after '->', not before."); + } methodBuilder.setParamStructType(compileParamList( methodDecl.getName().getValue(), ordinal, false, - methodReader.getParams(), implicits, + params, implicits, [&]() { return methodBuilder.initParamBrand(); })); auto results = methodReader.getResults(); @@ -2339,6 +1630,7 @@ uint64_t NodeTranslator::compileParamList( switch (paramList.which()) { case Declaration::ParamList::NAMED_LIST: { auto newStruct = orphanage.newOrphan(); + auto newSourceInfo = orphanage.newOrphan(); auto builder = newStruct.get(); auto parent = wipNode.getReader(); @@ -2357,9 +1649,9 @@ uint64_t NodeTranslator::compileParamList( // params as types actually need to refer to them as regular params, so we create an // ImplicitParams with a scopeId here. StructTranslator(*this, ImplicitParams { builder.getId(), implicitParams }) - .translate(paramList.getNamedList(), builder); + .translate(paramList.getNamedList(), builder, newSourceInfo.get()); uint64_t id = builder.getId(); - paramStructs.add(kj::mv(newStruct)); + paramStructs.add(AuxNode { kj::mv(newStruct), kj::mv(newSourceInfo) }); auto brand = localBrand->push(builder.getId(), implicitParams.size()); @@ -2399,142 +1691,37 @@ uint64_t NodeTranslator::compileParamList( } } return 0; - } - KJ_UNREACHABLE; -} - -// ------------------------------------------------------------------- - -static const char HEXDIGITS[] = "0123456789abcdef"; - -static kj::StringTree stringLiteral(kj::StringPtr chars) { - // TODO(cleanup): This code keeps coming up. Put somewhere common? - - kj::Vector escaped(chars.size()); - - for (char c: chars) { - switch (c) { - case '\a': escaped.addAll(kj::StringPtr("\\a")); break; - case '\b': escaped.addAll(kj::StringPtr("\\b")); break; - case '\f': escaped.addAll(kj::StringPtr("\\f")); break; - case '\n': escaped.addAll(kj::StringPtr("\\n")); break; - case '\r': escaped.addAll(kj::StringPtr("\\r")); break; - case '\t': escaped.addAll(kj::StringPtr("\\t")); break; - case '\v': escaped.addAll(kj::StringPtr("\\v")); break; - case '\'': escaped.addAll(kj::StringPtr("\\\'")); break; - case '\"': escaped.addAll(kj::StringPtr("\\\"")); break; - case '\\': escaped.addAll(kj::StringPtr("\\\\")); break; - default: - if (c < 0x20) { - escaped.add('\\'); - escaped.add('x'); - uint8_t c2 = c; - escaped.add(HEXDIGITS[c2 / 16]); - escaped.add(HEXDIGITS[c2 % 16]); - } else { - escaped.add(c); + case Declaration::ParamList::STREAM: + KJ_IF_MAYBE(streamCapnp, resolver.resolveImport("/capnp/stream.capnp")) { + if (streamCapnp->resolver->resolveMember("StreamResult") == nullptr) { + errorReporter.addErrorOn(paramList, + "The version of '/capnp/stream.capnp' found in your import path does not appear " + "to be the official one; it is missing the declaration of StreamResult."); } - break; - } - } - return kj::strTree('"', escaped, '"'); -} - -static kj::StringTree binaryLiteral(Data::Reader data) { - kj::Vector escaped(data.size() * 3); - - for (byte b: data) { - escaped.add(HEXDIGITS[b % 16]); - escaped.add(HEXDIGITS[b / 16]); - escaped.add(' '); - } - - escaped.removeLast(); - return kj::strTree("0x\"", escaped, '"'); -} - -static kj::StringTree expressionStringTree(Expression::Reader exp); - -static kj::StringTree tupleLiteral(List::Reader params) { - auto parts = kj::heapArrayBuilder(params.size()); - for (auto param: params) { - auto part = expressionStringTree(param.getValue()); - if (param.isNamed()) { - part = kj::strTree(param.getNamed().getValue(), " = ", kj::mv(part)); - } - parts.add(kj::mv(part)); - } - return kj::strTree("( ", kj::StringTree(parts.finish(), ", "), " )"); -} - -static kj::StringTree expressionStringTree(Expression::Reader exp) { - switch (exp.which()) { - case Expression::UNKNOWN: - return kj::strTree(""); - case Expression::POSITIVE_INT: - return kj::strTree(exp.getPositiveInt()); - case Expression::NEGATIVE_INT: - return kj::strTree('-', exp.getNegativeInt()); - case Expression::FLOAT: - return kj::strTree(exp.getFloat()); - case Expression::STRING: - return stringLiteral(exp.getString()); - case Expression::BINARY: - return binaryLiteral(exp.getBinary()); - case Expression::RELATIVE_NAME: - return kj::strTree(exp.getRelativeName().getValue()); - case Expression::ABSOLUTE_NAME: - return kj::strTree('.', exp.getAbsoluteName().getValue()); - case Expression::IMPORT: - return kj::strTree("import ", stringLiteral(exp.getImport().getValue())); - case Expression::EMBED: - return kj::strTree("embed ", stringLiteral(exp.getEmbed().getValue())); - - case Expression::LIST: { - auto list = exp.getList(); - auto parts = kj::heapArrayBuilder(list.size()); - for (auto element: list) { - parts.add(expressionStringTree(element)); + } else { + errorReporter.addErrorOn(paramList, + "A method declaration uses streaming, but '/capnp/stream.capnp' is not found " + "in the import path. This is a standard file that should always be installed " + "with the Cap'n Proto compiler."); } - return kj::strTree("[ ", kj::StringTree(parts.finish(), ", "), " ]"); - } - - case Expression::TUPLE: - return tupleLiteral(exp.getTuple()); - - case Expression::APPLICATION: { - auto app = exp.getApplication(); - return kj::strTree(expressionStringTree(app.getFunction()), - '(', tupleLiteral(app.getParams()), ')'); - } - - case Expression::MEMBER: { - auto member = exp.getMember(); - return kj::strTree(expressionStringTree(member.getParent()), '.', - member.getName().getValue()); - } + return typeId(); } - KJ_UNREACHABLE; } -static kj::String expressionString(Expression::Reader name) { - return expressionStringTree(name).flatten(); -} - // ------------------------------------------------------------------- -kj::Maybe +kj::Maybe NodeTranslator::compileDeclExpression( Expression::Reader source, ImplicitParams implicitMethodParams) { return localBrand->compileDeclExpression(source, resolver, implicitMethodParams); } -/* static */ kj::Maybe NodeTranslator::compileDecl( +/* static */ kj::Maybe NodeTranslator::compileDecl( uint64_t scopeId, uint scopeParameterCount, Resolver& resolver, ErrorReporter& errorReporter, Expression::Reader expression, schema::Brand::Builder brandBuilder) { auto scope = kj::refcounted(errorReporter, scopeId, scopeParameterCount, resolver); - KJ_IF_MAYBE(decl, scope->compileDeclExpression(expression, resolver, noImplicitParams())) { + KJ_IF_MAYBE(decl, scope->compileDeclExpression(expression, resolver, ImplicitParams::none())) { return decl->asResolveResult(scope->getScopeId(), brandBuilder); } else { return nullptr; @@ -2582,7 +1769,7 @@ void NodeTranslator::compileDefaultDefaultValue( void NodeTranslator::compileBootstrapValue( Expression::Reader source, schema::Type::Reader type, schema::Value::Builder target, - Schema typeScope) { + kj::Maybe typeScope) { // Start by filling in a default default value so that if for whatever reason we don't end up // initializing the value, this won't cause schema validation to fail. compileDefaultDefaultValue(type, target); @@ -2596,8 +1783,9 @@ void NodeTranslator::compileBootstrapValue( break; default: - // Primitive value. - compileValue(source, type, typeScope, target, true); + // Primitive value. (Note that the scope can't possibly matter since primitives are not + // generic.) + compileValue(source, type, typeScope.orDefault(Schema()), target, true); break; } } @@ -2641,29 +1829,63 @@ void NodeTranslator::compileValue(Expression::Reader source, schema::Type::Reade } kj::Maybe> ValueTranslator::compileValue(Expression::Reader src, Type type) { + if (type.isAnyPointer()) { + if (type.getBrandParameter() != nullptr || type.getImplicitParameter() != nullptr) { + errorReporter.addErrorOn(src, + "Cannot interpret value because the type is a generic type parameter which is not " + "yet bound. We don't know what type to expect here."); + return nullptr; + } + } + Orphan result = compileValueInner(src, type); + if (result.getType() == DynamicValue::UNKNOWN) { + // Error already reported. + return nullptr; + } else if (matchesType(src, type, result)) { + return kj::mv(result); + } else { + // If the expected type is a struct, we try matching its first field. + if (type.isStruct()) { + auto structType = type.asStruct(); + auto fields = structType.getFields(); + if (fields.size() > 0) { + auto field = fields[0]; + if (matchesType(src, field.getType(), result)) { + // Success. Wrap in a struct. + auto outer = orphanage.newOrphan(type.asStruct()); + outer.get().adopt(field, kj::mv(result)); + return Orphan(kj::mv(outer)); + } + } + } + + // That didn't work, so this is just a type mismatch. + errorReporter.addErrorOn(src, kj::str("Type mismatch; expected ", makeTypeName(type), ".")); + return nullptr; + } +} + +bool ValueTranslator::matchesType(Expression::Reader src, Type type, Orphan& result) { + // compileValueInner() evaluated `src` and only used `type` as a hint in interpreting `src` if + // `src`'s type wasn't already obvious. So, now we need to check that the resulting value + // actually matches `type`. + switch (result.getType()) { case DynamicValue::UNKNOWN: - // Error already reported. - return nullptr; + KJ_UNREACHABLE; case DynamicValue::VOID: - if (type.isVoid()) { - return kj::mv(result); - } - break; + return type.isVoid(); case DynamicValue::BOOL: - if (type.isBool()) { - return kj::mv(result); - } - break; + return type.isBool(); case DynamicValue::INT: { int64_t value = result.getReader().as(); if (value < 0) { - int64_t minValue = 1; + int64_t minValue; switch (type.which()) { case schema::Type::INT8: minValue = (int8_t)kj::minValue; break; case schema::Type::INT16: minValue = (int16_t)kj::minValue; break; @@ -2680,22 +1902,20 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader minValue = (int64_t)kj::minValue; break; - default: break; + default: return false; } - if (minValue == 1) break; if (value < minValue) { errorReporter.addErrorOn(src, "Integer value out of range."); result = minValue; } - return kj::mv(result); + return true; } - // No break -- value is positive, so we can just go on to the uint case below. - } + } KJ_FALLTHROUGH; // value is positive, so we can just go on to the uint case below. case DynamicValue::UINT: { - uint64_t maxValue = 0; + uint64_t maxValue; switch (type.which()) { case schema::Type::INT8: maxValue = (int8_t)kj::maxValue; break; case schema::Type::INT16: maxValue = (int16_t)kj::maxValue; break; @@ -2712,76 +1932,62 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader maxValue = (uint64_t)kj::maxValue; break; - default: break; + default: return false; } - if (maxValue == 0) break; if (result.getReader().as() > maxValue) { errorReporter.addErrorOn(src, "Integer value out of range."); result = maxValue; } - return kj::mv(result); + return true; } case DynamicValue::FLOAT: - if (type.isFloat32() || type.isFloat64()) { - return kj::mv(result); - } - break; + return type.isFloat32() || type.isFloat64(); case DynamicValue::TEXT: - if (type.isText()) { - return kj::mv(result); - } - break; + return type.isText(); case DynamicValue::DATA: - if (type.isData()) { - return kj::mv(result); - } - break; + return type.isData(); case DynamicValue::LIST: if (type.isList()) { - if (result.getReader().as().getSchema() == type.asList()) { - return kj::mv(result); - } + return result.getReader().as().getSchema() == type.asList(); } else if (type.isAnyPointer()) { switch (type.whichAnyPointerKind()) { case schema::Type::AnyPointer::Unconstrained::ANY_KIND: case schema::Type::AnyPointer::Unconstrained::LIST: - return kj::mv(result); + return true; case schema::Type::AnyPointer::Unconstrained::STRUCT: case schema::Type::AnyPointer::Unconstrained::CAPABILITY: - break; + return false; } + KJ_UNREACHABLE; + } else { + return false; } - break; case DynamicValue::ENUM: - if (type.isEnum()) { - if (result.getReader().as().getSchema() == type.asEnum()) { - return kj::mv(result); - } - } - break; + return type.isEnum() && + result.getReader().as().getSchema() == type.asEnum(); case DynamicValue::STRUCT: if (type.isStruct()) { - if (result.getReader().as().getSchema() == type.asStruct()) { - return kj::mv(result); - } + return result.getReader().as().getSchema() == type.asStruct(); } else if (type.isAnyPointer()) { switch (type.whichAnyPointerKind()) { case schema::Type::AnyPointer::Unconstrained::ANY_KIND: case schema::Type::AnyPointer::Unconstrained::STRUCT: - return kj::mv(result); + return true; case schema::Type::AnyPointer::Unconstrained::LIST: case schema::Type::AnyPointer::Unconstrained::CAPABILITY: - break; + return false; } + KJ_UNREACHABLE; + } else { + return false; } - break; case DynamicValue::CAPABILITY: KJ_FAIL_ASSERT("Interfaces can't have literal values."); @@ -2790,8 +1996,7 @@ kj::Maybe> ValueTranslator::compileValue(Expression::Reader KJ_FAIL_ASSERT("AnyPointers can't have literal values."); } - errorReporter.addErrorOn(src, kj::str("Type mismatch; expected ", makeTypeName(type), ".")); - return nullptr; + KJ_UNREACHABLE; } Orphan ValueTranslator::compileValueInner(Expression::Reader src, Type type) { @@ -2982,9 +2187,26 @@ void ValueTranslator::fillStructValue(DynamicStruct::Builder builder, break; case schema::Field::GROUP: + auto groupBuilder = builder.init(*field).as(); if (value.isTuple()) { - fillStructValue(builder.init(*field).as(), value.getTuple()); + fillStructValue(groupBuilder, value.getTuple()); } else { + auto groupFields = groupBuilder.getSchema().getFields(); + if (groupFields.size() > 0) { + auto groupField = groupFields[0]; + + // Call compileValueInner() using the group's type as `type`. Since we already + // established `value` is not a tuple, this will only return a valid result if + // the value has unambiguous type. + auto result = compileValueInner(value, field->getType()); + + // Does it match the first field? + if (matchesType(value, groupField.getType(), result)) { + groupBuilder.adopt(groupField, kj::mv(result)); + break; + } + } + errorReporter.addErrorOn(value, "Type mismatch; expected group."); } break; @@ -3033,8 +2255,8 @@ kj::String ValueTranslator::makeTypeName(Type type) { kj::Maybe NodeTranslator::readConstant( Expression::Reader source, bool isBootstrap) { // Look up the constant decl. - NodeTranslator::BrandedDecl constDecl = nullptr; - KJ_IF_MAYBE(decl, compileDeclExpression(source, noImplicitParams())) { + BrandedDecl constDecl = nullptr; + KJ_IF_MAYBE(decl, compileDeclExpression(source, ImplicitParams::none())) { constDecl = *decl; } else { // Lookup will have reported an error. @@ -3155,7 +2377,7 @@ Orphan> NodeTranslator::compileAnnotationApplications( annotationBuilder.initValue().setVoid(); auto name = annotation.getName(); - KJ_IF_MAYBE(decl, compileDeclExpression(name, noImplicitParams())) { + KJ_IF_MAYBE(decl, compileDeclExpression(name, ImplicitParams::none())) { KJ_IF_MAYBE(kind, decl->getKind()) { if (*kind != Declaration::ANNOTATION) { errorReporter.addErrorOn(name, kj::str( @@ -3194,7 +2416,7 @@ Orphan> NodeTranslator::compileAnnotationApplications( } } } - } else if (*kind != Declaration::ANNOTATION) { + } else { errorReporter.addErrorOn(name, kj::str( "'", expressionString(name), "' is not an annotation.")); } diff --git a/c++/src/capnp/compiler/node-translator.h b/c++/src/capnp/compiler/node-translator.h index 2041a13730..7562b08637 100644 --- a/c++/src/capnp/compiler/node-translator.h +++ b/c++/src/capnp/compiler/node-translator.h @@ -19,12 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_COMPILER_NODE_TRANSLATOR_H_ -#define CAPNP_COMPILER_NODE_TRANSLATOR_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include @@ -33,6 +28,11 @@ #include #include #include "error-reporter.h" +#include "resolver.h" +#include "generics.h" +#include + +CAPNP_BEGIN_HEADER namespace capnp { namespace compiler { @@ -41,84 +41,7 @@ class NodeTranslator { // Translates one node in the schema from AST form to final schema form. A "node" is anything // that has a unique ID, such as structs, enums, constants, and annotations, but not fields, // unions, enumerants, or methods (the latter set have 16-bit ordinals but not 64-bit global IDs). - public: - class Resolver { - // Callback class used to find other nodes relative to this one. - // - // TODO(cleanup): This has evolved into being a full interface for traversing the node tree. - // Maybe we should rename it as such, and move it out of NodeTranslator. See also - // TODO(cleanup) on NodeTranslator::BrandedDecl. - - public: - struct ResolvedDecl { - uint64_t id; - uint genericParamCount; - uint64_t scopeId; - Declaration::Which kind; - Resolver* resolver; - - kj::Maybe brand; - // If present, then it is necessary to replace the brand scope with the given brand before - // using the target type. This happens when the decl resolved to an alias; all other fields - // of `ResolvedDecl` refer to the target of the alias, except for `scopeId` which is the - // scope that contained the alias. - }; - - struct ResolvedParameter { - uint64_t id; // ID of the node declaring the parameter. - uint index; // Index of the parameter. - }; - - typedef kj::OneOf ResolveResult; - - virtual kj::Maybe resolve(kj::StringPtr name) = 0; - // Look up the given name, relative to this node, and return basic information about the - // target. - - virtual kj::Maybe resolveMember(kj::StringPtr name) = 0; - // Look up a member of this node. - - virtual ResolvedDecl resolveBuiltin(Declaration::Which which) = 0; - virtual ResolvedDecl resolveId(uint64_t id) = 0; - - virtual kj::Maybe getParent() = 0; - // Returns the parent of this scope, or null if this is the top scope. - - virtual ResolvedDecl getTopScope() = 0; - // Get the top-level scope containing this node. - - virtual kj::Maybe resolveBootstrapSchema(uint64_t id, schema::Brand::Reader brand) = 0; - // Get the schema for the given ID. If a schema is returned, it must be safe to traverse its - // dependencies via the Schema API. A schema that is only at the bootstrap stage is - // acceptable. - // - // Throws an exception if the id is not one that was found by calling resolve() or by - // traversing other schemas. Returns null if the ID is recognized, but the corresponding - // schema node failed to be built for reasons that were already reported. - - virtual kj::Maybe resolveFinalSchema(uint64_t id) = 0; - // Get the final schema for the given ID. A bootstrap schema is not acceptable. A raw - // node reader is returned rather than a Schema object because using a Schema object built - // by the final schema loader could trigger lazy initialization of dependencies which could - // lead to a cycle and deadlock. - // - // Throws an exception if the id is not one that was found by calling resolve() or by - // traversing other schemas. Returns null if the ID is recognized, but the corresponding - // schema node failed to be built for reasons that were already reported. - - virtual kj::Maybe resolveImport(kj::StringPtr name) = 0; - // Get the ID of an imported file given the import path. - - virtual kj::Maybe> readEmbed(kj::StringPtr name) = 0; - // Read and return the contents of a file for an `embed` expression. - - virtual kj::Maybe resolveBootstrapType(schema::Type::Reader type, Schema scope) = 0; - // Compile a schema::Type into a Type whose dependencies may safely be traversed via the schema - // API. These dependencies may have only bootstrap schemas. Returns null if the type could not - // be constructed due to already-reported errors. - }; - NodeTranslator(Resolver& resolver, ErrorReporter& errorReporter, const Declaration::Reader& decl, Orphan wipNode, bool compileAnnotations); @@ -136,6 +59,9 @@ class NodeTranslator { // Auxiliary nodes that were produced when translating this node and should be loaded along // with it. In particular, structs that contain groups (or named unions) spawn extra nodes // representing those, and interfaces spawn struct nodes representing method params/results. + + kj::Array sourceInfo; + // The SourceInfo for the node and all aux nodes. }; NodeSet getBootstrapNode(); @@ -147,9 +73,12 @@ class NodeTranslator { // If the final node has already been built, this will actually return the final node (in fact, // it's the same node object). - NodeSet finish(); + NodeSet finish(Schema selfUnboundBootstrap); // Finish translating the node (including filling in all the pieces that are missing from the // bootstrap node) and return it. + // + // `selfUnboundBootstrap` is a Schema build using the Node returned by getBootstrapNode(), and + // with generic parameters "unbound", i.e. it was returned by SchemaLoader::getUnbound(). static kj::Maybe compileDecl( uint64_t scopeId, uint scopeParameterCount, Resolver& resolver, ErrorReporter& errorReporter, @@ -165,8 +94,6 @@ class NodeTranslator { class DuplicateOrdinalDetector; class StructLayout; class StructTranslator; - class BrandedDecl; - class BrandScope; Resolver& resolver; ErrorReporter& errorReporter; @@ -177,17 +104,25 @@ class NodeTranslator { Orphan wipNode; // The work-in-progress schema node. - kj::Vector> groups; + Orphan sourceInfo; + // Doc comments and other source info for this node. + + struct AuxNode { + Orphan node; + Orphan sourceInfo; + }; + + kj::Vector groups; // If this is a struct node and it contains groups, these are the nodes for those groups, which // must be loaded together with the top-level node. - kj::Vector> paramStructs; + kj::Vector paramStructs; // If this is an interface, these are the auto-generated structs representing params and results. struct UnfinishedValue { Expression::Reader source; schema::Type::Reader type; - Schema typeScope; + kj::Maybe typeScope; schema::Value::Builder target; }; kj::Vector unfinishedValues; @@ -212,21 +147,6 @@ class NodeTranslator { // The `members` arrays contain only members with ordinal numbers, in code order. Other members // are handled elsewhere. - struct ImplicitParams { - // Represents a set of implicit parameters visible in the current context. - - uint64_t scopeId; - // If zero, then any reference to an implciit param in this context should be compiled to a - // `implicitMethodParam` AnyPointer. If non-zero, it should be complied to a `parameter` - // AnyPointer. - - List::Reader params; - }; - - static inline ImplicitParams noImplicitParams() { - return { 0, List::Reader() }; - } - template uint64_t compileParamList(kj::StringPtr methodName, uint16_t ordinal, bool isResults, Declaration::ParamList::Reader paramList, @@ -248,15 +168,13 @@ class NodeTranslator { void compileBootstrapValue( Expression::Reader source, schema::Type::Reader type, schema::Value::Builder target, - Schema typeScope = Schema()); - // Calls compileValue() if this value should be interpreted at bootstrap time. Otheriwse, + kj::Maybe typeScope = nullptr); + // Calls compileValue() if this value should be interpreted at bootstrap time. Otherwise, // adds the value to `unfinishedValues` for later evaluation. // - // If `type` comes from some other node, `typeScope` is the schema for that node. This is only - // really needed for looking up generic parameter bindings, therefore if the type comes from - // the node being built, an empty "Schema" (the default) works here because the node being built - // is of course being built for all possible bindings and thus none of its generic parameters are - // bound. + // If `type` comes from some other node, `typeScope` is the schema for that node. Otherwise the + // scope of the type expression is assumed to be this node (meaning, in particular, that no + // generic type parameters are bound). void compileValue(Expression::Reader source, schema::Type::Reader type, Schema typeScope, schema::Value::Builder target, bool isBootstrap); @@ -297,7 +215,8 @@ class ValueTranslator { Orphanage orphanage; Orphan compileValueInner(Expression::Reader src, Type type); - // Helper for compileValue(). + bool matchesType(Expression::Reader src, Type type, Orphan& result); + // Helpers for compileValue(). kj::String makeNodeName(Schema node); kj::String makeTypeName(Type type); @@ -308,4 +227,4 @@ class ValueTranslator { } // namespace compiler } // namespace capnp -#endif // CAPNP_COMPILER_NODE_TRANSLATOR_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/parser.c++ b/c++/src/capnp/compiler/parser.c++ index 43f0f7f575..91a8c5cf4c 100644 --- a/c++/src/capnp/compiler/parser.c++ +++ b/c++/src/capnp/compiler/parser.c++ @@ -19,10 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if _WIN32 +#include +#endif + #include "parser.h" -#include "md5.h" +#include "type-id.h" #include #include +#include #if !_MSC_VER #include #endif @@ -32,7 +37,9 @@ #if _WIN32 #include -#undef VOID +#include +#undef CONST +#include #endif namespace capnp { @@ -52,6 +59,7 @@ uint64_t generateRandomId() { #else int fd; KJ_SYSCALL(fd = open("/dev/urandom", O_RDONLY)); + KJ_DEFER(close(fd)); ssize_t n; KJ_SYSCALL(n = read(fd, &result, sizeof(result)), "/dev/urandom"); @@ -61,83 +69,8 @@ uint64_t generateRandomId() { return result | (1ull << 63); } -uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName) { - // Compute ID by MD5 hashing the concatenation of the parent ID and the declaration name, and - // then taking the first 8 bytes. - - kj::byte parentIdBytes[sizeof(uint64_t)]; - for (uint i = 0; i < sizeof(uint64_t); i++) { - parentIdBytes[i] = (parentId >> (i * 8)) & 0xff; - } - - Md5 md5; - md5.update(kj::arrayPtr(parentIdBytes, kj::size(parentIdBytes))); - md5.update(childName); - - kj::ArrayPtr resultBytes = md5.finish(); - - uint64_t result = 0; - for (uint i = 0; i < sizeof(uint64_t); i++) { - result = (result << 8) | resultBytes[i]; - } - - return result | (1ull << 63); -} - -uint64_t generateGroupId(uint64_t parentId, uint16_t groupIndex) { - // Compute ID by MD5 hashing the concatenation of the parent ID and the group index, and - // then taking the first 8 bytes. - - kj::byte bytes[sizeof(uint64_t) + sizeof(uint16_t)]; - for (uint i = 0; i < sizeof(uint64_t); i++) { - bytes[i] = (parentId >> (i * 8)) & 0xff; - } - for (uint i = 0; i < sizeof(uint16_t); i++) { - bytes[sizeof(uint64_t) + i] = (groupIndex >> (i * 8)) & 0xff; - } - - Md5 md5; - md5.update(bytes); - - kj::ArrayPtr resultBytes = md5.finish(); - - uint64_t result = 0; - for (uint i = 0; i < sizeof(uint64_t); i++) { - result = (result << 8) | resultBytes[i]; - } - - return result | (1ull << 63); -} - -uint64_t generateMethodParamsId(uint64_t parentId, uint16_t methodOrdinal, bool isResults) { - // Compute ID by MD5 hashing the concatenation of the parent ID, the method ordinal, and a - // boolean indicating whether this is the params or the results, and then taking the first 8 - // bytes. - - kj::byte bytes[sizeof(uint64_t) + sizeof(uint16_t) + 1]; - for (uint i = 0; i < sizeof(uint64_t); i++) { - bytes[i] = (parentId >> (i * 8)) & 0xff; - } - for (uint i = 0; i < sizeof(uint16_t); i++) { - bytes[sizeof(uint64_t) + i] = (methodOrdinal >> (i * 8)) & 0xff; - } - bytes[sizeof(bytes) - 1] = isResults; - - Md5 md5; - md5.update(bytes); - - kj::ArrayPtr resultBytes = md5.finish(); - - uint64_t result = 0; - for (uint i = 0; i < sizeof(uint64_t); i++) { - result = (result << 8) | resultBytes[i]; - } - - return result | (1ull << 63); -} - void parseFile(List::Reader statements, ParsedFile::Builder result, - ErrorReporter& errorReporter) { + ErrorReporter& errorReporter, bool requiresId) { CapnpParser parser(Orphanage::getForMessageContaining(result), errorReporter); kj::Vector> decls(statements.size()); @@ -178,7 +111,7 @@ void parseFile(List::Reader statements, ParsedFile::Builder result, // Don't report missing ID if there was a parse error, because quite often the parse error // prevents us from parsing the ID even though it is actually there. - if (!errorReporter.hadErrors()) { + if (requiresId && !errorReporter.hadErrors()) { errorReporter.addError(0, 0, kj::str("File does not declare an ID. I've generated one for you. Add this line to " "your file: @0x", kj::hex(id), ";")); @@ -291,6 +224,27 @@ constexpr auto op(const char* expected) return p::transformOrReject(operatorToken, ExactString(expected)); } +class LocatedExactString { +public: + constexpr LocatedExactString(const char* expected): expected(expected) {} + + kj::Maybe> operator()(Located&& text) const { + if (text.value == expected) { + return kj::mv(text); + } else { + return nullptr; + } + } + +private: + const char* expected; +}; + +constexpr auto locatedKeyword(const char* expected) + -> decltype(p::transformOrReject(identifier, LocatedExactString(expected))) { + return p::transformOrReject(identifier, LocatedExactString(expected)); +} + // ======================================================================================= template @@ -505,12 +459,14 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, ErrorReporter& errorReporterP initLocation(location, builder); return result; }), - p::transform(stringLiteral, - [this](Located&& value) -> Orphan { + p::transform(p::oneOrMore(stringLiteral), + [this](kj::Array>&& value) -> Orphan { auto result = orphanage.newOrphan(); auto builder = result.get(); - builder.setString(value.value); - value.copyLocationTo(builder); + builder.setString(kj::strArray( + KJ_MAP(part, value) { return part.value; }, "")); + builder.setStartByte(value.front().startByte); + builder.setEndByte(value.back().endByte); return result; }), p::transform(binaryLiteral, @@ -927,6 +883,14 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, ErrorReporter& errorReporterP } return decl; }), + p::transform(locatedKeyword("stream"), + [this](Located&& kw) -> Orphan { + auto decl = orphanage.newOrphan(); + auto builder = decl.get(); + kw.copyLocationTo(builder); + builder.setStream(); + return decl; + }), p::transform(parsers.expression, [this](Orphan&& name) -> Orphan { auto decl = orphanage.newOrphan(); @@ -1130,5 +1094,95 @@ kj::Maybe> CapnpParser::parseStatement( } } +// ======================================================================================= + +static const char HEXDIGITS[] = "0123456789abcdef"; + +static kj::StringTree stringLiteralStringTree(kj::StringPtr chars) { + return kj::strTree('"', kj::encodeCEscape(chars), '"'); +} + +static kj::StringTree binaryLiteralStringTree(Data::Reader data) { + kj::Vector escaped(data.size() * 3); + + for (byte b: data) { + escaped.add(HEXDIGITS[b % 16]); + escaped.add(HEXDIGITS[b / 16]); + escaped.add(' '); + } + + escaped.removeLast(); + return kj::strTree("0x\"", escaped, '"'); +} + +static kj::StringTree expressionStringTree(Expression::Reader exp); + +static kj::StringTree tupleLiteral(List::Reader params) { + auto parts = kj::heapArrayBuilder(params.size()); + for (auto param: params) { + auto part = expressionStringTree(param.getValue()); + if (param.isNamed()) { + part = kj::strTree(param.getNamed().getValue(), " = ", kj::mv(part)); + } + parts.add(kj::mv(part)); + } + return kj::strTree("( ", kj::StringTree(parts.finish(), ", "), " )"); +} + +static kj::StringTree expressionStringTree(Expression::Reader exp) { + switch (exp.which()) { + case Expression::UNKNOWN: + return kj::strTree(""); + case Expression::POSITIVE_INT: + return kj::strTree(exp.getPositiveInt()); + case Expression::NEGATIVE_INT: + return kj::strTree('-', exp.getNegativeInt()); + case Expression::FLOAT: + return kj::strTree(exp.getFloat()); + case Expression::STRING: + return stringLiteralStringTree(exp.getString()); + case Expression::BINARY: + return binaryLiteralStringTree(exp.getBinary()); + case Expression::RELATIVE_NAME: + return kj::strTree(exp.getRelativeName().getValue()); + case Expression::ABSOLUTE_NAME: + return kj::strTree('.', exp.getAbsoluteName().getValue()); + case Expression::IMPORT: + return kj::strTree("import ", stringLiteralStringTree(exp.getImport().getValue())); + case Expression::EMBED: + return kj::strTree("embed ", stringLiteralStringTree(exp.getEmbed().getValue())); + + case Expression::LIST: { + auto list = exp.getList(); + auto parts = kj::heapArrayBuilder(list.size()); + for (auto element: list) { + parts.add(expressionStringTree(element)); + } + return kj::strTree("[ ", kj::StringTree(parts.finish(), ", "), " ]"); + } + + case Expression::TUPLE: + return tupleLiteral(exp.getTuple()); + + case Expression::APPLICATION: { + auto app = exp.getApplication(); + return kj::strTree(expressionStringTree(app.getFunction()), + '(', tupleLiteral(app.getParams()), ')'); + } + + case Expression::MEMBER: { + auto member = exp.getMember(); + return kj::strTree(expressionStringTree(member.getParent()), '.', + member.getName().getValue()); + } + } + + KJ_UNREACHABLE; +} + +kj::String expressionString(Expression::Reader name) { + return expressionStringTree(name).flatten(); +} + } // namespace compiler } // namespace capnp diff --git a/c++/src/capnp/compiler/parser.h b/c++/src/capnp/compiler/parser.h index 50370affd1..7c798b2742 100644 --- a/c++/src/capnp/compiler/parser.h +++ b/c++/src/capnp/compiler/parser.h @@ -19,12 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_COMPILER_PARSER_H_ -#define CAPNP_COMPILER_PARSER_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include @@ -32,11 +27,13 @@ #include #include "error-reporter.h" +CAPNP_BEGIN_HEADER + namespace capnp { namespace compiler { void parseFile(List::Reader statements, ParsedFile::Builder result, - ErrorReporter& errorReporter); + ErrorReporter& errorReporter, bool requiresId); // Parse a list of statements to build a ParsedFile. // // If any errors are reported, then the output is not usable. However, it may be passed on through @@ -67,7 +64,7 @@ class CapnpParser { ~CapnpParser() noexcept(false); - KJ_DISALLOW_COPY(CapnpParser); + KJ_DISALLOW_COPY_AND_MOVE(CapnpParser); using ParserInput = kj::parse::IteratorInput::Reader::Iterator>; struct DeclParserResult; @@ -145,7 +142,10 @@ class CapnpParser { Parsers parsers; }; +kj::String expressionString(Expression::Reader name); +// Stringify the expression as code. + } // namespace compiler } // namespace capnp -#endif // CAPNP_COMPILER_PARSER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/resolver.h b/c++/src/capnp/compiler/resolver.h new file mode 100644 index 0000000000..b4a39423ba --- /dev/null +++ b/c++/src/capnp/compiler/resolver.h @@ -0,0 +1,132 @@ +// Copyright (c) 2013-2020 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include +#include + +CAPNP_BEGIN_HEADER + +namespace capnp { +namespace compiler { + +class Resolver { + // Callback class used to find other nodes relative to some existing node. + // + // `Resolver` is used when compiling one declaration requires inspecting the compiled versions + // of other declarations it depends on. For example, if struct type Foo contains a field of type + // Bar, and specifies a default value for that field, then to parse that default value we need + // the compiled version of `Bar`. Or, more commonly, if a struct type Foo refers to some other + // type `Bar.Baz`, this requires doing a lookup that depends on at least partial compilation of + // `Bar`, in order to discover its nested type `Baz`. + // + // Note that declarations are often compiled just-in-time the first time they are resolved. So, + // the methods of Resolver may recurse back into other parts of the compiler. It must detect when + // a dependency cycle occurs and report an error in order to prevent an infinite loop. + +public: + struct ResolvedDecl { + // Information about a resolved declaration. + + uint64_t id; + // Type ID / node ID of the resolved declaration. + + uint genericParamCount; + // If non-zero, the declaration is a generic with the given number of parameters. + + uint64_t scopeId; + // The ID of the parent scope of this declaration. + + Declaration::Which kind; + // What basic kind of declaration is this? E.g. struct, interface, const, etc. + + Resolver* resolver; + // `Resolver` instance that can be used to further resolve other declarations relative to this + // one. + + kj::Maybe brand; + // If present, then it is necessary to replace the brand scope with the given brand before + // using the target type. This happens when the decl resolved to an alias; all other fields + // of `ResolvedDecl` refer to the target of the alias, except for `scopeId` which is the + // scope that contained the alias. + }; + + struct ResolvedParameter { + uint64_t id; // ID of the node declaring the parameter. + uint index; // Index of the parameter. + }; + + typedef kj::OneOf ResolveResult; + + virtual kj::Maybe resolve(kj::StringPtr name) = 0; + // Look up the given name, relative to this node, and return basic information about the + // target. + + virtual kj::Maybe resolveMember(kj::StringPtr name) = 0; + // Look up a member of this node. + + virtual ResolvedDecl resolveBuiltin(Declaration::Which which) = 0; + virtual ResolvedDecl resolveId(uint64_t id) = 0; + + virtual kj::Maybe getParent() = 0; + // Returns the parent of this scope, or null if this is the top scope. + + virtual ResolvedDecl getTopScope() = 0; + // Get the top-level scope containing this node. + + virtual kj::Maybe resolveBootstrapSchema(uint64_t id, schema::Brand::Reader brand) = 0; + // Get the schema for the given ID. If a schema is returned, it must be safe to traverse its + // dependencies via the Schema API. A schema that is only at the bootstrap stage is + // acceptable. + // + // Throws an exception if the id is not one that was found by calling resolve() or by + // traversing other schemas. Returns null if the ID is recognized, but the corresponding + // schema node failed to be built for reasons that were already reported. + + virtual kj::Maybe resolveFinalSchema(uint64_t id) = 0; + // Get the final schema for the given ID. A bootstrap schema is not acceptable. A raw + // node reader is returned rather than a Schema object because using a Schema object built + // by the final schema loader could trigger lazy initialization of dependencies which could + // lead to a cycle and deadlock. + // + // Throws an exception if the id is not one that was found by calling resolve() or by + // traversing other schemas. Returns null if the ID is recognized, but the corresponding + // schema node failed to be built for reasons that were already reported. + + virtual kj::Maybe resolveImport(kj::StringPtr name) = 0; + // Get the ID of an imported file given the import path. + + virtual kj::Maybe> readEmbed(kj::StringPtr name) = 0; + // Read and return the contents of a file for an `embed` expression. + + virtual kj::Maybe resolveBootstrapType(schema::Type::Reader type, Schema scope) = 0; + // Compile a schema::Type into a Type whose dependencies may safely be traversed via the schema + // API. These dependencies may have only bootstrap schemas. Returns null if the type could not + // be constructed due to already-reported errors. +}; + +} // namespace compiler +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/compiler/md5-test.c++ b/c++/src/capnp/compiler/type-id-test.c++ similarity index 52% rename from c++/src/capnp/compiler/md5-test.c++ rename to c++/src/capnp/compiler/type-id-test.c++ index 76b2da0b19..04e6fde43b 100644 --- a/c++/src/capnp/compiler/md5-test.c++ +++ b/c++/src/capnp/compiler/type-id-test.c++ @@ -1,4 +1,4 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors // Licensed under the MIT License: // // Permission is hereby granted, free of charge, to any person obtaining a copy @@ -19,43 +19,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#include "md5.h" -#include +#include "type-id.h" +#include +#include namespace capnp { namespace compiler { namespace { -static kj::String doMd5(kj::StringPtr text) { - Md5 md5; - md5.update(text); - return kj::str(md5.finishAsHex().cStr()); -} - -TEST(Md5, Sum) { - EXPECT_STREQ("acbd18db4cc2f85cedef654fccc4a4d8", doMd5("foo").cStr()); - EXPECT_STREQ("37b51d194a7513e45b56f6524f2d51f2", doMd5("bar").cStr()); - EXPECT_STREQ("3858f62230ac3c915f300c664312c63f", doMd5("foobar").cStr()); - - { - Md5 md5; - md5.update("foo"); - md5.update("bar"); - EXPECT_STREQ("3858f62230ac3c915f300c664312c63f", md5.finishAsHex().cStr()); - } +KJ_TEST("type ID generation hasn't changed") { + KJ_EXPECT(generateChildId(0xa93fc509624c72d9ull, "Node") == 0xe682ab4cf923a417ull); + KJ_EXPECT(generateChildId(0xe682ab4cf923a417ull, "NestedNode") == 0xdebf55bbfa0fc242ull); + KJ_EXPECT(generateGroupId(0xe682ab4cf923a417ull, 7) == 0x9ea0b19b37fb4435ull); - EXPECT_STREQ("ebf2442d167a30ca4453f99abd8cddf4", doMd5( - "Hello, this is a long string that is more than 64 bytes because the md5 code uses a " - "buffer of 64 bytes.").cStr()); + KJ_EXPECT(typeId() == 0xe682ab4cf923a417ull); + KJ_EXPECT(typeId() == 0xdebf55bbfa0fc242ull); + KJ_EXPECT(typeId() == 0x9ea0b19b37fb4435ull); - { - Md5 md5; - md5.update("Hello, this is a long string "); - md5.update("that is more than 64 bytes "); - md5.update("because the md5 code uses a "); - md5.update("buffer of 64 bytes."); - EXPECT_STREQ("ebf2442d167a30ca4453f99abd8cddf4", md5.finishAsHex().cStr()); - } + // Methods of TestInterface. + KJ_EXPECT(generateMethodParamsId(0x88eb12a0e0af92b2ull, 0, false) == 0xb874edc0d559b391ull); + KJ_EXPECT(generateMethodParamsId(0x88eb12a0e0af92b2ull, 0, true) == 0xb04fcaddab714ba4ull); + KJ_EXPECT(generateMethodParamsId(0x88eb12a0e0af92b2ull, 1, false) == 0xd044893357b42568ull); + KJ_EXPECT(generateMethodParamsId(0x88eb12a0e0af92b2ull, 1, true) == 0x9bf141df4247d52full); } } // namespace diff --git a/c++/src/capnp/compiler/md5.c++ b/c++/src/capnp/compiler/type-id.c++ similarity index 62% rename from c++/src/capnp/compiler/md5.c++ rename to c++/src/capnp/compiler/type-id.c++ index dcf5832382..107b1cd6e1 100644 --- a/c++/src/capnp/compiler/md5.c++ +++ b/c++/src/capnp/compiler/type-id.c++ @@ -1,5 +1,142 @@ -// This file was modified by Kenton Varda from code placed in the public domain. -// The code, which was originally C, was modified to give it a C++ interface. +// Copyright (c) 2013-2017 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "type-id.h" +#include +#include + +namespace capnp { +namespace compiler { + +class TypeIdGenerator { + // A non-cryptographic deterministic random number generator used to generate type IDs when the + // developer did not specify one themselves. + // + // The underlying algorithm is MD5. MD5 is safe to use here because this is not intended to be a + // cryptographic random number generator. In retrospect it would have been nice to use something + // else just to avoid people freaking out about it, but changing the algorithm now would break + // backwards-compatibility. + +public: + TypeIdGenerator(); + + void update(kj::ArrayPtr data); + inline void update(kj::ArrayPtr data) { + return update(data.asBytes()); + } + inline void update(kj::StringPtr data) { + return update(data.asArray()); + } + + kj::ArrayPtr finish(); + +private: + bool finished = false; + + struct { + uint lo, hi; + uint a, b, c, d; + kj::byte buffer[64]; + uint block[16]; + } ctx; + + const kj::byte* body(const kj::byte* ptr, size_t size); +}; + +uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName) { + // Compute ID by hashing the concatenation of the parent ID and the declaration name, and + // then taking the first 8 bytes. + + kj::byte parentIdBytes[sizeof(uint64_t)]; + for (uint i = 0; i < sizeof(uint64_t); i++) { + parentIdBytes[i] = (parentId >> (i * 8)) & 0xff; + } + + TypeIdGenerator generator; + generator.update(kj::arrayPtr(parentIdBytes, kj::size(parentIdBytes))); + generator.update(childName); + + kj::ArrayPtr resultBytes = generator.finish(); + + uint64_t result = 0; + for (uint i = 0; i < sizeof(uint64_t); i++) { + result = (result << 8) | resultBytes[i]; + } + + return result | (1ull << 63); +} + +uint64_t generateGroupId(uint64_t parentId, uint16_t groupIndex) { + // Compute ID by hashing the concatenation of the parent ID and the group index, and + // then taking the first 8 bytes. + + kj::byte bytes[sizeof(uint64_t) + sizeof(uint16_t)]; + for (uint i = 0; i < sizeof(uint64_t); i++) { + bytes[i] = (parentId >> (i * 8)) & 0xff; + } + for (uint i = 0; i < sizeof(uint16_t); i++) { + bytes[sizeof(uint64_t) + i] = (groupIndex >> (i * 8)) & 0xff; + } + + TypeIdGenerator generator; + generator.update(bytes); + + kj::ArrayPtr resultBytes = generator.finish(); + + uint64_t result = 0; + for (uint i = 0; i < sizeof(uint64_t); i++) { + result = (result << 8) | resultBytes[i]; + } + + return result | (1ull << 63); +} + +uint64_t generateMethodParamsId(uint64_t parentId, uint16_t methodOrdinal, bool isResults) { + // Compute ID by hashing the concatenation of the parent ID, the method ordinal, and a + // boolean indicating whether this is the params or the results, and then taking the first 8 + // bytes. + + kj::byte bytes[sizeof(uint64_t) + sizeof(uint16_t) + 1]; + for (uint i = 0; i < sizeof(uint64_t); i++) { + bytes[i] = (parentId >> (i * 8)) & 0xff; + } + for (uint i = 0; i < sizeof(uint16_t); i++) { + bytes[sizeof(uint64_t) + i] = (methodOrdinal >> (i * 8)) & 0xff; + } + bytes[sizeof(bytes) - 1] = isResults; + + TypeIdGenerator generator; + generator.update(bytes); + + kj::ArrayPtr resultBytes = generator.finish(); + + uint64_t result = 0; + for (uint i = 0; i < sizeof(uint64_t); i++) { + result = (result << 8) | resultBytes[i]; + } + + return result | (1ull << 63); +} + +// The remainder of this file was derived from code placed in the public domain. // The original code bore the following notice: /* @@ -39,13 +176,6 @@ * compile-time configuration. */ -#include "md5.h" -#include -#include - -namespace capnp { -namespace compiler { - /* * The basic MD5 functions. * @@ -76,16 +206,16 @@ namespace compiler { */ #if defined(__i386__) || defined(__x86_64__) || defined(__vax__) #define SET(n) \ - (*(MD5_u32plus *)&ptr[(n) * 4]) + (*(uint *)&ptr[(n) * 4]) #define GET(n) \ SET(n) #else #define SET(n) \ (ctx.block[(n)] = \ - (MD5_u32plus)ptr[(n) * 4] | \ - ((MD5_u32plus)ptr[(n) * 4 + 1] << 8) | \ - ((MD5_u32plus)ptr[(n) * 4 + 2] << 16) | \ - ((MD5_u32plus)ptr[(n) * 4 + 3] << 24)) + (uint)ptr[(n) * 4] | \ + ((uint)ptr[(n) * 4 + 1] << 8) | \ + ((uint)ptr[(n) * 4 + 2] << 16) | \ + ((uint)ptr[(n) * 4 + 3] << 24)) #define GET(n) \ (ctx.block[(n)]) #endif @@ -94,10 +224,10 @@ namespace compiler { * This processes one or more 64-byte data blocks, but does NOT update * the bit counters. There are no alignment requirements. */ -const kj::byte* Md5::body(const kj::byte* ptr, size_t size) +const kj::byte* TypeIdGenerator::body(const kj::byte* ptr, size_t size) { - MD5_u32plus a, b, c, d; - MD5_u32plus saved_a, saved_b, saved_c, saved_d; + uint a, b, c, d; + uint saved_a, saved_b, saved_c, saved_d; a = ctx.a; b = ctx.b; @@ -198,7 +328,7 @@ const kj::byte* Md5::body(const kj::byte* ptr, size_t size) return ptr; } -Md5::Md5() +TypeIdGenerator::TypeIdGenerator() { ctx.a = 0x67452301; ctx.b = 0xefcdab89; @@ -209,14 +339,14 @@ Md5::Md5() ctx.hi = 0; } -void Md5::update(kj::ArrayPtr dataArray) +void TypeIdGenerator::update(kj::ArrayPtr dataArray) { - KJ_REQUIRE(!finished, "already called Md5::finish()"); + KJ_REQUIRE(!finished, "already called TypeIdGenerator::finish()"); const kj::byte* data = dataArray.begin(); unsigned long size = dataArray.size(); - MD5_u32plus saved_lo; + uint saved_lo; unsigned long used, free; saved_lo = ctx.lo; @@ -248,7 +378,7 @@ void Md5::update(kj::ArrayPtr dataArray) memcpy(ctx.buffer, data, size); } -kj::ArrayPtr Md5::finish() +kj::ArrayPtr TypeIdGenerator::finish() { if (!finished) { unsigned long used, free; @@ -304,21 +434,6 @@ kj::ArrayPtr Md5::finish() return kj::arrayPtr(ctx.buffer, 16); } -kj::StringPtr Md5::finishAsHex() { - static const char hexDigits[] = "0123456789abcdef"; - - kj::ArrayPtr bytes = finish(); - - char* chars = reinterpret_cast(ctx.buffer + 16); - char* pos = chars; - for (auto byte: bytes) { - *pos++ = hexDigits[byte / 16]; - *pos++ = hexDigits[byte % 16]; - } - *pos++ = '\0'; - - return kj::StringPtr(chars, 32); -} } // namespace compiler } // namespace capnp diff --git a/c++/src/capnp/compiler/type-id.h b/c++/src/capnp/compiler/type-id.h new file mode 100644 index 0000000000..5968a1762d --- /dev/null +++ b/c++/src/capnp/compiler/type-id.h @@ -0,0 +1,46 @@ +// Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include + +CAPNP_BEGIN_HEADER + +namespace capnp { +namespace compiler { + +uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName); +uint64_t generateGroupId(uint64_t parentId, uint16_t groupIndex); +uint64_t generateMethodParamsId(uint64_t parentId, uint16_t methodOrdinal, bool isResults); +// Generate a default type ID for various symbols. These are used only if the developer did not +// specify an ID explicitly. +// +// The returned ID always has the most-significant bit set. The remaining bits are generated +// pseudo-randomly from the input using an algorithm that should produce a uniform distribution of +// IDs. + +} // namespace compiler +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/dynamic-capability.c++ b/c++/src/capnp/dynamic-capability.c++ index e56d8fadc6..81a4ed3540 100644 --- a/c++/src/capnp/dynamic-capability.c++ +++ b/c++/src/capnp/dynamic-capability.c++ @@ -40,8 +40,11 @@ Request DynamicCapability::Client::newRequest( auto paramType = method.getParamType(); auto resultType = method.getResultType(); + CallHints hints; + hints.noPromisePipelining = !resultType.mayContainCapabilities(); + auto typeless = hook->newCall( - methodInterface.getProto().getId(), method.getIndex(), sizeHint); + methodInterface.getProto().getId(), method.getIndex(), sizeHint, hints); return Request( typeless.getAs(paramType), kj::mv(typeless.hook), resultType); @@ -52,15 +55,20 @@ Request DynamicCapability::Client::newRequest( return newRequest(schema.getMethodByName(methodName), sizeHint); } -kj::Promise DynamicCapability::Server::dispatchCall( +Capability::Server::DispatchCallResult DynamicCapability::Server::dispatchCall( uint64_t interfaceId, uint16_t methodId, CallContext context) { KJ_IF_MAYBE(interface, schema.findSuperclass(interfaceId)) { auto methods = interface->getMethods(); if (methodId < methods.size()) { auto method = methods[methodId]; - return call(method, CallContext(*context.hook, - method.getParamType(), method.getResultType())); + auto resultType = method.getResultType(); + return { + call(method, CallContext(*context.hook, + method.getParamType(), resultType)), + resultType.isStreamResult(), + options.allowCancellation + }; } else { return internalUnimplemented( interface->getProto().getDisplayName().cStr(), interfaceId, methodId); @@ -72,6 +80,7 @@ kj::Promise DynamicCapability::Server::dispatchCall( RemotePromise Request::send() { auto typelessPromise = hook->send(); + hook = nullptr; // prevent reuse auto resultSchemaCopy = resultSchema; // Convert the Promise to return the correct response type. @@ -90,4 +99,12 @@ RemotePromise Request::send() { return RemotePromise(kj::mv(typedPromise), kj::mv(typedPipeline)); } +kj::Promise Request::sendStreaming() { + KJ_REQUIRE(resultSchema.isStreamResult()); + + auto promise = hook->sendStreaming(); + hook = nullptr; // prevent reuse + return promise; +} + } // namespace capnp diff --git a/c++/src/capnp/dynamic-test.c++ b/c++/src/capnp/dynamic-test.c++ index e337cbdab5..1c08c206d4 100644 --- a/c++/src/capnp/dynamic-test.c++ +++ b/c++/src/capnp/dynamic-test.c++ @@ -422,15 +422,20 @@ TEST(DynamicApi, Has) { // Primitive fields are always present even if set to default. EXPECT_TRUE(root.has("int32Field")); + EXPECT_FALSE(root.has("int32Field", HasMode::NON_DEFAULT)); root.set("int32Field", 123); EXPECT_TRUE(root.has("int32Field")); + EXPECT_TRUE(root.has("int32Field", HasMode::NON_DEFAULT)); root.set("int32Field", -12345678); EXPECT_TRUE(root.has("int32Field")); + EXPECT_FALSE(root.has("int32Field", HasMode::NON_DEFAULT)); // Pointers are absent until initialized. EXPECT_FALSE(root.has("structField")); + EXPECT_FALSE(root.has("structField", HasMode::NON_DEFAULT)); root.init("structField"); EXPECT_TRUE(root.has("structField")); + EXPECT_TRUE(root.has("structField", HasMode::NON_DEFAULT)); } TEST(DynamicApi, HasWhenEmpty) { @@ -443,6 +448,11 @@ TEST(DynamicApi, HasWhenEmpty) { EXPECT_TRUE(root.has("int32Field")); EXPECT_FALSE(root.has("structField")); EXPECT_FALSE(root.has("int32List")); + + EXPECT_FALSE(root.has("voidField", HasMode::NON_DEFAULT)); + EXPECT_FALSE(root.has("int32Field", HasMode::NON_DEFAULT)); + EXPECT_FALSE(root.has("structField", HasMode::NON_DEFAULT)); + EXPECT_FALSE(root.has("int32List", HasMode::NON_DEFAULT)); } TEST(DynamicApi, SetEnumFromNative) { diff --git a/c++/src/capnp/dynamic.c++ b/c++/src/capnp/dynamic.c++ index a06429a22d..59e5d1268f 100644 --- a/c++/src/capnp/dynamic.c++ +++ b/c++/src/capnp/dynamic.c++ @@ -180,7 +180,7 @@ DynamicValue::Reader DynamicStruct::Reader::get(StructSchema::Field field) const case schema::Field::SLOT: { auto slot = proto.getSlot(); - // Note that the default value might be "anyPointer" even if the type is some poniter type + // Note that the default value might be "anyPointer" even if the type is some pointer type // *other than* anyPointer. This happens with generics -- the field is actually a generic // parameter that has been bound, but the default value was of course compiled without any // binding available. @@ -272,7 +272,7 @@ DynamicValue::Builder DynamicStruct::Builder::get(StructSchema::Field field) { case schema::Field::SLOT: { auto slot = proto.getSlot(); - // Note that the default value might be "anyPointer" even if the type is some poniter type + // Note that the default value might be "anyPointer" even if the type is some pointer type // *other than* anyPointer. This happens with generics -- the field is actually a generic // parameter that has been bound, but the default value was of course compiled without any // binding available. @@ -414,7 +414,7 @@ DynamicValue::Pipeline DynamicStruct::Pipeline::get(StructSchema::Field field) { KJ_UNREACHABLE; } -bool DynamicStruct::Reader::has(StructSchema::Field field) const { +bool DynamicStruct::Reader::has(StructSchema::Field field, HasMode mode) const { KJ_REQUIRE(field.getContainingStruct() == schema, "`field` is not a field of this struct."); auto proto = field.getProto(); @@ -441,20 +441,35 @@ bool DynamicStruct::Reader::has(StructSchema::Field field) const { switch (type.which()) { case schema::Type::VOID: + // Void is always equal to the default. + return mode == HasMode::NON_NULL; + case schema::Type::BOOL: + return mode == HasMode::NON_NULL || + reader.getDataField(assumeDataOffset(slot.getOffset()), 0) != 0; + case schema::Type::INT8: - case schema::Type::INT16: - case schema::Type::INT32: - case schema::Type::INT64: case schema::Type::UINT8: + return mode == HasMode::NON_NULL || + reader.getDataField(assumeDataOffset(slot.getOffset()), 0) != 0; + + case schema::Type::INT16: case schema::Type::UINT16: + case schema::Type::ENUM: + return mode == HasMode::NON_NULL || + reader.getDataField(assumeDataOffset(slot.getOffset()), 0) != 0; + + case schema::Type::INT32: case schema::Type::UINT32: - case schema::Type::UINT64: case schema::Type::FLOAT32: + return mode == HasMode::NON_NULL || + reader.getDataField(assumeDataOffset(slot.getOffset()), 0) != 0; + + case schema::Type::INT64: + case schema::Type::UINT64: case schema::Type::FLOAT64: - case schema::Type::ENUM: - // Primitive types are always present. - return true; + return mode == HasMode::NON_NULL || + reader.getDataField(assumeDataOffset(slot.getOffset()), 0) != 0; case schema::Type::TEXT: case schema::Type::DATA: @@ -725,6 +740,7 @@ DynamicValue::Builder DynamicStruct::Builder::init(StructSchema::Field field, ui (uint)type.which()); break; } + KJ_UNREACHABLE; } case schema::Field::GROUP: @@ -985,11 +1001,11 @@ DynamicValue::Builder DynamicStruct::Builder::get(kj::StringPtr name) { DynamicValue::Pipeline DynamicStruct::Pipeline::get(kj::StringPtr name) { return get(schema.getFieldByName(name)); } -bool DynamicStruct::Reader::has(kj::StringPtr name) const { - return has(schema.getFieldByName(name)); +bool DynamicStruct::Reader::has(kj::StringPtr name, HasMode mode) const { + return has(schema.getFieldByName(name), mode); } -bool DynamicStruct::Builder::has(kj::StringPtr name) { - return has(schema.getFieldByName(name)); +bool DynamicStruct::Builder::has(kj::StringPtr name, HasMode mode) { + return has(schema.getFieldByName(name), mode); } void DynamicStruct::Builder::set(kj::StringPtr name, const DynamicValue::Reader& value) { set(schema.getFieldByName(name), value); @@ -1451,6 +1467,14 @@ DynamicValue::Reader::Reader(ConstSchema constant): type(VOID) { } } +#if __GNUC__ && !__clang__ && __GNUC__ >= 9 +// In the copy constructors below, we use memcpy() to copy only after verifying that it is safe. +// But GCC 9 doesn't know we've checked, and whines. I suppose GCC is probably right: our checks +// probably don't technically make memcpy safe according to the standard. But it works in practice, +// and if it ever stops working, the tests will catch it. +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + DynamicValue::Reader::Reader(const Reader& other) { switch (other.type) { case UNKNOWN: @@ -1549,12 +1573,12 @@ DynamicValue::Builder::Builder(Builder& other) { // Unfortunately canMemcpy() doesn't work on these types due to the use of // DisallowConstCopy, but __has_trivial_destructor should detect if any of these types // become non-trivial. - static_assert(__has_trivial_destructor(Text::Builder) && - __has_trivial_destructor(Data::Builder) && - __has_trivial_destructor(DynamicList::Builder) && - __has_trivial_destructor(DynamicEnum) && - __has_trivial_destructor(DynamicStruct::Builder) && - __has_trivial_destructor(AnyPointer::Builder), + static_assert(KJ_HAS_TRIVIAL_DESTRUCTOR(Text::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(Data::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicList::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicEnum) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicStruct::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(AnyPointer::Builder), "Assumptions here don't hold."); break; @@ -1583,12 +1607,12 @@ DynamicValue::Builder::Builder(Builder&& other) noexcept { // Unfortunately __has_trivial_copy doesn't work on these types due to the use of // DisallowConstCopy, but __has_trivial_destructor should detect if any of these types // become non-trivial. - static_assert(__has_trivial_destructor(Text::Builder) && - __has_trivial_destructor(Data::Builder) && - __has_trivial_destructor(DynamicList::Builder) && - __has_trivial_destructor(DynamicEnum) && - __has_trivial_destructor(DynamicStruct::Builder) && - __has_trivial_destructor(AnyPointer::Builder), + static_assert(KJ_HAS_TRIVIAL_DESTRUCTOR(Text::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(Data::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicList::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicEnum) && + KJ_HAS_TRIVIAL_DESTRUCTOR(DynamicStruct::Builder) && + KJ_HAS_TRIVIAL_DESTRUCTOR(AnyPointer::Builder), "Assumptions here don't hold."); break; @@ -1709,11 +1733,32 @@ int64_t unsignedToSigned(unsigned long long value) { template T checkRoundTrip(U value) { - KJ_REQUIRE(T(value) == value, "Value out-of-range for requested type.", value) { + T result = value; + KJ_REQUIRE(U(result) == value, "Value out-of-range for requested type.", value) { // Use it anyway. break; } - return value; + return result; +} + +template +T checkRoundTripFromFloat(U value) { + // When `U` is `float` or `double`, we have to use a different approach, because casting an + // out-of-range float to an integer is, surprisingly, UB. + constexpr T MIN = kj::minValue; + constexpr T MAX = kj::maxValue; + KJ_REQUIRE(value >= U(MIN), "Value out-of-range for requested type.", value) { + return MIN; + } + KJ_REQUIRE(value <= U(MAX), "Value out-of-range for requested type.", value) { + return MAX; + } + T result = value; + KJ_REQUIRE(U(result) == value, "Value out-of-range for requested type.", value) { + // Use it anyway. + break; + } + return result; } } // namespace @@ -1748,14 +1793,14 @@ typeName DynamicValue::Builder::AsImpl::apply(Builder& builder) { \ } \ } -HANDLE_NUMERIC_TYPE(int8_t, checkRoundTrip, unsignedToSigned, checkRoundTrip) -HANDLE_NUMERIC_TYPE(int16_t, checkRoundTrip, unsignedToSigned, checkRoundTrip) -HANDLE_NUMERIC_TYPE(int32_t, checkRoundTrip, unsignedToSigned, checkRoundTrip) -HANDLE_NUMERIC_TYPE(int64_t, kj::implicitCast, unsignedToSigned, checkRoundTrip) -HANDLE_NUMERIC_TYPE(uint8_t, signedToUnsigned, checkRoundTrip, checkRoundTrip) -HANDLE_NUMERIC_TYPE(uint16_t, signedToUnsigned, checkRoundTrip, checkRoundTrip) -HANDLE_NUMERIC_TYPE(uint32_t, signedToUnsigned, checkRoundTrip, checkRoundTrip) -HANDLE_NUMERIC_TYPE(uint64_t, signedToUnsigned, kj::implicitCast, checkRoundTrip) +HANDLE_NUMERIC_TYPE(int8_t, checkRoundTrip, unsignedToSigned, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(int16_t, checkRoundTrip, unsignedToSigned, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(int32_t, checkRoundTrip, unsignedToSigned, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(int64_t, kj::implicitCast, unsignedToSigned, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(uint8_t, signedToUnsigned, checkRoundTrip, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(uint16_t, signedToUnsigned, checkRoundTrip, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(uint32_t, signedToUnsigned, checkRoundTrip, checkRoundTripFromFloat) +HANDLE_NUMERIC_TYPE(uint64_t, signedToUnsigned, kj::implicitCast, checkRoundTripFromFloat) HANDLE_NUMERIC_TYPE(float, kj::implicitCast, kj::implicitCast, kj::implicitCast) HANDLE_NUMERIC_TYPE(double, kj::implicitCast, kj::implicitCast, kj::implicitCast) diff --git a/c++/src/capnp/dynamic.h b/c++/src/capnp/dynamic.h index fcefcc3bf2..8aab1f7ad9 100644 --- a/c++/src/capnp/dynamic.h +++ b/c++/src/capnp/dynamic.h @@ -30,18 +30,16 @@ // As always, underlying data is validated lazily, so you have to actually traverse the whole // message if you want to validate all content. -#ifndef CAPNP_DYNAMIC_H_ -#define CAPNP_DYNAMIC_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "schema.h" #include "layout.h" #include "message.h" #include "any.h" #include "capability.h" +#include // work-around macro conflict with `VOID` + +CAPNP_BEGIN_HEADER namespace capnp { @@ -170,6 +168,21 @@ class DynamicEnum { // ------------------------------------------------------------------- +enum class HasMode: uint8_t { + // Specifies the meaning of "has(field)". + + NON_NULL, + // "has(field)" only returns false if the field is a pointer and the pointer is null. This is the + // default behavior. + + NON_DEFAULT + // "has(field)" returns false if the field is set to its default value. This differs from + // NON_NULL only in the handling of primitive values. + // + // "Equal to default value" is technically defined as the field value being encoded as all-zero + // on the wire (since primitive values are XORed by their defined default value when encoded). +}; + class DynamicStruct::Reader { public: typedef DynamicStruct Reads; @@ -179,6 +192,8 @@ class DynamicStruct::Reader { template >() == Kind::STRUCT>> inline Reader(T&& value): Reader(toDynamic(value)) {} + inline operator AnyStruct::Reader() const { return AnyStruct::Reader(reader); } + inline MessageSize totalSize() const { return reader.totalSize().asPublic(); } template @@ -190,12 +205,10 @@ class DynamicStruct::Reader { DynamicValue::Reader get(StructSchema::Field field) const; // Read the given field value. - bool has(StructSchema::Field field) const; - // Tests whether the given field is set to its default value. For pointer values, this does - // not actually traverse the value comparing it with the default, but simply returns true if the - // pointer is non-null. For members of unions, has() returns false if the union member is not - // active, but does not necessarily return true if the member is active (depends on the field's - // value). + bool has(StructSchema::Field field, HasMode mode = HasMode::NON_NULL) const; + // Tests whether the given field is "present". If the field is a union member and is not the + // active member, this always returns false. Otherwise, the field's value is interpreted + // according to `mode`. kj::Maybe which() const; // If the struct contains an (unnamed) union, and the currently-active field within that union @@ -205,7 +218,7 @@ class DynamicStruct::Reader { // newer version of the protocol and is using a field of the union that you don't know about yet. DynamicValue::Reader get(kj::StringPtr name) const; - bool has(kj::StringPtr name) const; + bool has(kj::StringPtr name, HasMode mode = HasMode::NON_NULL) const; // Shortcuts to access fields by name. These throw exceptions if no such field exists. private: @@ -234,6 +247,7 @@ class DynamicStruct::Reader { friend class Orphan; friend class Orphan; friend class Orphan; + friend class AnyStruct::Reader; }; class DynamicStruct::Builder { @@ -246,6 +260,8 @@ class DynamicStruct::Builder { template >() == Kind::STRUCT>> inline Builder(T&& value): Builder(toDynamic(value)) {} + inline operator AnyStruct::Builder() { return AnyStruct::Builder(builder); } + inline MessageSize totalSize() const { return asReader().totalSize(); } template @@ -257,12 +273,11 @@ class DynamicStruct::Builder { DynamicValue::Builder get(StructSchema::Field field); // Read the given field value. - inline bool has(StructSchema::Field field) { return asReader().has(field); } - // Tests whether the given field is set to its default value. For pointer values, this does - // not actually traverse the value comparing it with the default, but simply returns true if the - // pointer is non-null. For members of unions, has() returns whether the field is currently - // active and the union as a whole is non-default -- so, the only time has() will return false - // for an active union field is if it is the default active field and it has its default value. + inline bool has(StructSchema::Field field, HasMode mode = HasMode::NON_NULL) + { return asReader().has(field, mode); } + // Tests whether the given field is "present". If the field is a union member and is not the + // active member, this always returns false. Otherwise, the field's value is interpreted + // according to `mode`. kj::Maybe which(); // If the struct contains an (unnamed) union, and the currently-active field within that union @@ -288,7 +303,7 @@ class DynamicStruct::Builder { // field null. DynamicValue::Builder get(kj::StringPtr name); - bool has(kj::StringPtr name); + bool has(kj::StringPtr name, HasMode mode = HasMode::NON_NULL); void set(kj::StringPtr name, const DynamicValue::Reader& value); void set(kj::StringPtr name, std::initializer_list value); DynamicValue::Builder init(kj::StringPtr name); @@ -323,6 +338,7 @@ class DynamicStruct::Builder { friend class Orphan; friend class Orphan; friend class Orphan; + friend class AnyStruct::Builder; }; class DynamicStruct::Pipeline { @@ -364,6 +380,8 @@ class DynamicList::Reader { template >() == Kind::LIST>> inline Reader(T&& value): Reader(toDynamic(value)) {} + inline operator AnyList::Reader() const { return AnyList::Reader(reader); } + template typename T::Reader as() const; // Try to convert to any List, Data, or Text. Throws an exception if the underlying data @@ -407,6 +425,8 @@ class DynamicList::Builder { template >() == Kind::LIST>> inline Builder(T&& value): Builder(toDynamic(value)) {} + inline operator AnyList::Builder() { return AnyList::Builder(builder); } + template typename T::Builder as(); // Try to convert to any List, Data, or Text. Throws an exception if the underlying data @@ -504,18 +524,28 @@ class DynamicCapability::Server: public Capability::Server { public: typedef DynamicCapability Serves; + struct Options { + bool allowCancellation = false; + // See the `allowCancellation` annotation defined in `c++.capnp`. + // + // This option applies to all calls made to this server object. The annotation in the schema + // is NOT used for dynamic servers. + }; + Server(InterfaceSchema schema): schema(schema) {} + Server(InterfaceSchema schema, Options options): schema(schema), options(options) {} virtual kj::Promise call(InterfaceSchema::Method method, CallContext context) = 0; - kj::Promise dispatchCall(uint64_t interfaceId, uint16_t methodId, - CallContext context) override final; + DispatchCallResult dispatchCall(uint64_t interfaceId, uint16_t methodId, + CallContext context) override final; inline InterfaceSchema getSchema() const { return schema; } private: InterfaceSchema schema; + Options options; }; template <> @@ -530,6 +560,10 @@ class Request: public DynamicStruct::Builder { RemotePromise send(); // Send the call and return a promise for the results. + kj::Promise sendStreaming(); + // Use when the caller is aware that the response type is StreamResult and wants to invoke + // streaming behavior. It is an error to call this if the response type is not StreamResult. + private: kj::Own hook; StructSchema resultSchema; @@ -560,7 +594,9 @@ class CallContext: public kj::DisallowConstCopy { Orphanage getResultsOrphanage(kj::Maybe sizeHint = nullptr); template kj::Promise tailCall(Request&& tailRequest); - void allowCancellation(); + + StructSchema getParamsType() const { return paramType; } + StructSchema getResultsType() const { return resultType; } private: CallContextHook* hook; @@ -1516,6 +1552,16 @@ inline AnyStruct::Builder DynamicStruct::Builder::as() { return AnyStruct::Builder(builder); } +template <> +inline DynamicStruct::Reader AnyStruct::Reader::as(StructSchema schema) const { + return DynamicStruct::Reader(schema, _reader); +} + +template <> +inline DynamicStruct::Builder AnyStruct::Builder::as(StructSchema schema) { + return DynamicStruct::Builder(schema, _builder); +} + template typename T::Pipeline DynamicStruct::Pipeline::releaseAs() { static_assert(kind() == Kind::STRUCT, @@ -1621,9 +1667,6 @@ inline kj::Promise CallContext::tailCall( Request&& tailRequest) { return hook->tailCall(kj::mv(tailRequest.hook)); } -inline void CallContext::allowCancellation() { - hook->allowCancellation(); -} template <> inline DynamicCapability::Client Capability::Client::castAs( @@ -1631,6 +1674,14 @@ inline DynamicCapability::Client Capability::Client::castAs( return DynamicCapability::Client(schema, hook->addRef()); } +template <> +inline DynamicCapability::Client CapabilityServerSet::add( + kj::Own&& server) { + void* ptr = reinterpret_cast(server.get()); + auto schema = server->getSchema(); + return addInternal(kj::mv(server), ptr).castAs(schema); +} + // ------------------------------------------------------------------- template @@ -1640,4 +1691,4 @@ ReaderFor ConstSchema::as() const { } // namespace capnp -#endif // CAPNP_DYNAMIC_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/encoding-test.c++ b/c++/src/capnp/encoding-test.c++ index ce7fabc266..8487501505 100644 --- a/c++/src/capnp/encoding-test.c++ +++ b/c++/src/capnp/encoding-test.c++ @@ -1471,6 +1471,57 @@ TEST(Encoding, ListSetters) { dst.set(0, src[0]); dst.set(1, src[1]); } + + checkTestMessage(root2); + + // Now let's do some adopting and disowning. + auto adopter = builder2.getOrphanage().newOrphan(); + auto disowner = root2.disownLists(); + + adopter.get().adoptList0(disowner.get().disownList0()); + adopter.get().adoptList1(disowner.get().disownList1()); + adopter.get().adoptList8(disowner.get().disownList8()); + adopter.get().adoptList16(disowner.get().disownList16()); + adopter.get().adoptList32(disowner.get().disownList32()); + adopter.get().adoptList64(disowner.get().disownList64()); + adopter.get().adoptListP(disowner.get().disownListP()); + + { + auto dst = adopter.get().initInt32ListList(3); + auto src = disowner.get().getInt32ListList(); + + auto orphan = src.disown(0); + checkList(orphan.getReader(), {1, 2, 3}); + dst.adopt(0, kj::mv(orphan)); + dst.adopt(1, src.disown(1)); + dst.adopt(2, src.disown(2)); + } + + { + auto dst = adopter.get().initTextListList(3); + auto src = disowner.get().getTextListList(); + + auto orphan = src.disown(0); + checkList(orphan.getReader(), {"foo", "bar"}); + dst.adopt(0, kj::mv(orphan)); + dst.adopt(1, src.disown(1)); + dst.adopt(2, src.disown(2)); + } + + { + auto dst = adopter.get().initStructListList(2); + auto src = disowner.get().getStructListList(); + + auto orphan = src.disown(0); + KJ_EXPECT(orphan.getReader()[0].getInt32Field() == 123); + KJ_EXPECT(orphan.getReader()[1].getInt32Field() == 456); + dst.adopt(0, kj::mv(orphan)); + dst.adopt(1, src.disown(1)); + } + + root2.adoptLists(kj::mv(adopter)); + + checkTestMessage(root2); } } @@ -1687,6 +1738,14 @@ TEST(Encoding, GlobalConstants) { EXPECT_EQ("structlist 2", listReader[1].getTextField()); EXPECT_EQ("structlist 3", listReader[2].getTextField()); } + + kj::StringPtr expected = + "foo bar baz\n" + "\"qux\" `corge` \'grault\'\n" + "regular\"quoted\"line" + "garply\\nwaldo\\tfred\\\"plugh\\\"xyzzy\\\'thud\n"; + + EXPECT_EQ(expected, test::BLOCK_TEXT); } TEST(Encoding, Embeds) { @@ -1929,6 +1988,120 @@ TEST(Encoding, DefaultListBuilder) { List::Builder(nullptr); } +TEST(Encoding, ListSize) { + MallocMessageBuilder builder; + auto root = builder.initRoot(); + initTestMessage(root); + + auto lists = root.asReader().getLists(); + + auto listSizes = + lists.getList0().totalSize() + + lists.getList1().totalSize() + + lists.getList8().totalSize() + + lists.getList16().totalSize() + + lists.getList32().totalSize() + + lists.getList64().totalSize() + + lists.getListP().totalSize() + + lists.getInt32ListList().totalSize() + + lists.getTextListList().totalSize() + + lists.getStructListList().totalSize(); + + auto structSize = lists.totalSize(); + + auto shallowSize = unbound(capnp::_::structSize().total() / WORDS); + + EXPECT_EQ(structSize.wordCount - shallowSize, listSizes.wordCount); +} + +KJ_TEST("list.setWithCaveats(i, list[i]) doesn't corrupt contents") { + MallocMessageBuilder builder; + auto root = builder.initRoot(); + auto list = root.initStructList(2); + initTestMessage(list[0]); + list.setWithCaveats(0, list[0]); + checkTestMessage(list[0]); + checkTestMessageAllZero(list[1]); + list.setWithCaveats(1, list[0]); + checkTestMessage(list[0]); + checkTestMessage(list[1]); +} + +KJ_TEST("Downgrade pointer-list from struct-list") { + // Test that downgrading a list-of-structs to a list-of-pointers (where the relevant pointer is + // the struct's first pointer) works as advertised. + + MallocMessageBuilder builder; + auto root = builder.initRoot(); + + { + auto list = root.getAnyPointerField().initAs>(2); + initTestMessage(list[0]); + list[1].setTextField("hello"); + } + + { + auto list = root.asReader().getAnyPointerField().getAs>(); + KJ_ASSERT(list.size() == 2); + KJ_EXPECT(list[0] == "foo"); + KJ_EXPECT(list[1] == "hello"); + } +} + +KJ_TEST("Copying ListList downgraded from ListStruct does not get corrupted") { + // Test written by David Renshaw to demonstrate CVE-??? + + AlignedData<10> data = {{ + // struct, 1 pointer + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + // list, inline composite. 4 words. + 0x01, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00, + + // one element, 3 data words, 1 pointer. + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x00, + + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, // data section + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, // data section + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, // data section + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // null struct pointer + + // bad bytes that shouldn't be visible from the root of the message + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, + + // bug can cause this word to be read as the list element struct pointer + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + }}; + + kj::ArrayPtr segments[1] = { + // Only take the first 7 words. The last three words above should not be accessible + // from these segments. + kj::arrayPtr(data.words, 7) + }; + + SegmentArrayMessageReader reader(kj::arrayPtr(segments, 1)); + auto readerRoot = reader.getRoot(); + auto listList = readerRoot.getAnyPointerField().getAs>>(); + EXPECT_EQ(listList.size(), 1); + + MallocMessageBuilder builder; + auto root = builder.initRoot(); + + root.getAnyPointerField().setAs>>(listList); + + auto outputSegments = builder.getSegmentsForOutput(); + ASSERT_EQ(outputSegments.size(), 1); + + auto inputBytes = segments[0].asBytes(); + auto outputBytes = outputSegments[0].asBytes(); + + ASSERT_EQ(outputBytes, inputBytes); + // Should be equal. Instead, we see that outputBytes includes the (copied) + // out-of-bounds 0xbb bytes from `data` above, which should be impossible. +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/c++/src/capnp/endian.h b/c++/src/capnp/endian.h index c5a6e63c5a..c0e3d75840 100644 --- a/c++/src/capnp/endian.h +++ b/c++/src/capnp/endian.h @@ -19,17 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_ENDIAN_H_ -#define CAPNP_ENDIAN_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "common.h" #include #include // memcpy +CAPNP_BEGIN_HEADER + namespace capnp { namespace _ { // private @@ -306,4 +303,4 @@ using WireValue = ShiftingWireValue; } // namespace _ (private) } // namespace capnp -#endif // CAPNP_ENDIAN_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/ez-rpc-test.c++ b/c++/src/capnp/ez-rpc-test.c++ index 35d40e5f02..0cd2fdbb5e 100644 --- a/c++/src/capnp/ez-rpc-test.c++ +++ b/c++/src/capnp/ez-rpc-test.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define CAPNP_TESTING_CAPNP 1 + #include "ez-rpc.h" #include "test-util.h" #include diff --git a/c++/src/capnp/ez-rpc.c++ b/c++/src/capnp/ez-rpc.c++ index bfd6880c7e..5871c77bf4 100644 --- a/c++/src/capnp/ez-rpc.c++ +++ b/c++/src/capnp/ez-rpc.c++ @@ -177,10 +177,10 @@ Capability::Client EzRpcClient::importCap(kj::StringPtr name) { KJ_IF_MAYBE(client, impl->clientContext) { return client->get()->restore(name); } else { - return impl->setupPromise.addBranch().then(kj::mvCapture(kj::heapString(name), - [this](kj::String&& name) { + return impl->setupPromise.addBranch().then( + [this,name=kj::heapString(name)]() { return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name); - })); + }); } } @@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() { // ======================================================================================= +namespace { + +class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter { +public: + bool shouldAllow(const struct sockaddr* addr, uint addrlen) override { + return true; + } +}; + +static DummyFilter DUMMY_FILTER; + +} // namespace + struct EzRpcServer::Impl final: public SturdyRefRestorer, public kj::TaskSet::ErrorHandler { Capability::Client mainInterface; @@ -247,13 +260,11 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer, portPromise = paf.promise.fork(); tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort) - .then(kj::mvCapture(paf.fulfiller, - [this, readerOpts](kj::Own>&& portFulfiller, - kj::Own&& addr) { + .then([this, portFulfiller=kj::mv(paf.fulfiller), readerOpts](kj::Own&& addr) mutable { auto listener = addr->listen(); portFulfiller->fulfill(listener->getPort()); acceptLoop(kj::mv(listener), readerOpts); - }))); + })); } Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, @@ -271,14 +282,13 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer, context(EzRpcContext::getThreadLocal()), portPromise(kj::Promise(port).fork()), tasks(*this) { - acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd), readerOpts); + acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER), + readerOpts); } void acceptLoop(kj::Own&& listener, ReaderOptions readerOpts) { auto ptr = listener.get(); - tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener), - [this, readerOpts](kj::Own&& listener, - kj::Own&& connection) { + tasks.add(ptr->accept().then([this, listener=kj::mv(listener), readerOpts](kj::Own&& connection) mutable { acceptLoop(kj::mv(listener), readerOpts); auto server = kj::heap(kj::mv(connection), *this, readerOpts); @@ -286,7 +296,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer, // Arrange to destroy the server context when all references are gone, or when the // EzRpcServer is destroyed (which will destroy the TaskSet). tasks.add(server->network.onDisconnect().attach(kj::mv(server))); - }))); + })); } Capability::Client restore(AnyPointer::Reader objectId) override { diff --git a/c++/src/capnp/ez-rpc.h b/c++/src/capnp/ez-rpc.h index fba5ace582..ef2649239a 100644 --- a/c++/src/capnp/ez-rpc.h +++ b/c++/src/capnp/ez-rpc.h @@ -19,16 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_EZ_RPC_H_ -#define CAPNP_EZ_RPC_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "rpc.h" #include "message.h" +CAPNP_BEGIN_HEADER + struct sockaddr; namespace kj { class AsyncIoProvider; class LowLevelAsyncIoProvider; } @@ -130,10 +127,10 @@ class EzRpcClient { // Get the server's main (aka "bootstrap") interface. template - typename Type::Client importCap(kj::StringPtr name) - KJ_DEPRECATED("Change your server to export a main interface, then use getMain() instead."); - Capability::Client importCap(kj::StringPtr name) - KJ_DEPRECATED("Change your server to export a main interface, then use getMain() instead."); + typename Type::Client importCap(kj::StringPtr name) CAPNP_DEPRECATED( + "Change your server to export a main interface, then use getMain() instead."); + Capability::Client importCap(kj::StringPtr name) CAPNP_DEPRECATED( + "Change your server to export a main interface, then use getMain() instead."); // ** DEPRECATED ** // // Ask the sever for the capability with the given name. You may specify a type to automatically @@ -198,12 +195,12 @@ class EzRpcServer { explicit EzRpcServer(kj::StringPtr bindAddress, uint defaultPort = 0, ReaderOptions readerOpts = ReaderOptions()) - KJ_DEPRECATED("Please specify a main interface for your server."); + CAPNP_DEPRECATED("Please specify a main interface for your server."); EzRpcServer(struct sockaddr* bindAddress, uint addrSize, ReaderOptions readerOpts = ReaderOptions()) - KJ_DEPRECATED("Please specify a main interface for your server."); + CAPNP_DEPRECATED("Please specify a main interface for your server."); EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts = ReaderOptions()) - KJ_DEPRECATED("Please specify a main interface for your server."); + CAPNP_DEPRECATED("Please specify a main interface for your server."); ~EzRpcServer() noexcept(false); @@ -251,4 +248,4 @@ inline typename Type::Client EzRpcClient::importCap(kj::StringPtr name) { } // namespace capnp -#endif // CAPNP_EZ_RPC_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/fuzz-test.c++ b/c++/src/capnp/fuzz-test.c++ index 6c1e2af796..b114ceb425 100644 --- a/c++/src/capnp/fuzz-test.c++ +++ b/c++/src/capnp/fuzz-test.c++ @@ -19,6 +19,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include #include #include "message.h" diff --git a/c++/src/capnp/generated-header-support.h b/c++/src/capnp/generated-header-support.h index 51b6dd7c11..b103d4bc8f 100644 --- a/c++/src/capnp/generated-header-support.h +++ b/c++/src/capnp/generated-header-support.h @@ -21,12 +21,7 @@ // This file is included from all generated headers. -#ifndef CAPNP_GENERATED_HEADER_SUPPORT_H_ -#define CAPNP_GENERATED_HEADER_SUPPORT_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "raw-schema.h" #include "layout.h" @@ -36,6 +31,9 @@ #include "any.h" #include #include +#include + +CAPNP_BEGIN_HEADER namespace capnp { @@ -139,7 +137,7 @@ struct BrandBindingFor_, Kind::LIST> { template struct BrandBindingFor_ { static constexpr RawBrandedSchema::Binding get(uint16_t listDepth) { - return { 15, listDepth, nullptr }; + return { 15, listDepth, &rawSchema().defaultBrand }; } }; @@ -211,7 +209,7 @@ template class ConstStruct { public: ConstStruct() = delete; - KJ_DISALLOW_COPY(ConstStruct); + KJ_DISALLOW_COPY_AND_MOVE(ConstStruct); inline explicit constexpr ConstStruct(const word* ptr): ptr(ptr) {} inline typename T::Reader get() const { @@ -230,7 +228,7 @@ template class ConstList { public: ConstList() = delete; - KJ_DISALLOW_COPY(ConstList); + KJ_DISALLOW_COPY_AND_MOVE(ConstList); inline explicit constexpr ConstList(const word* ptr): ptr(ptr) {} inline typename List::Reader get() const { @@ -249,7 +247,7 @@ template class ConstText { public: ConstText() = delete; - KJ_DISALLOW_COPY(ConstText); + KJ_DISALLOW_COPY_AND_MOVE(ConstText); inline explicit constexpr ConstText(const word* ptr): ptr(ptr) {} inline Text::Reader get() const { @@ -277,7 +275,7 @@ template class ConstData { public: ConstData() = delete; - KJ_DISALLOW_COPY(ConstData); + KJ_DISALLOW_COPY_AND_MOVE(ConstData); inline explicit constexpr ConstData(const word* ptr): ptr(ptr) {} inline Data::Reader get() const { @@ -317,7 +315,7 @@ inline constexpr uint sizeInWords() { } // namespace capnp -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) // MSVC doesn't understand floating-point constexpr yet. // // TODO(msvc): Remove this hack when MSVC is fixed. @@ -328,7 +326,7 @@ inline constexpr uint sizeInWords() { #define CAPNP_NON_INT_CONSTEXPR_DEF_INIT(value) #endif -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) // TODO(msvc): A little hack to allow MSVC to use C++14 return type deduction in cases where the // explicit type exposes bugs in the compiler. #define CAPNP_AUTO_IF_MSVC(...) auto @@ -336,6 +334,13 @@ inline constexpr uint sizeInWords() { #define CAPNP_AUTO_IF_MSVC(...) __VA_ARGS__ #endif +// TODO(msvc): MSVC does not even expect constexprs to have definitions below C++17. +#if (__cplusplus < 201703L) && !(defined(_MSC_VER) && !defined(__clang__)) +#define CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL 1 +#else +#define CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL 0 +#endif + #if CAPNP_LITE #define CAPNP_DECLARE_SCHEMA(id) \ @@ -351,12 +356,11 @@ inline constexpr uint sizeInWords() { static inline ::capnp::word const* encodedSchema() { return bp_##id; } \ } -#if _MSC_VER -// TODO(msvc): MSVC dosen't expect constexprs to have definitions. -#define CAPNP_DEFINE_ENUM(type, id) -#else +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #define CAPNP_DEFINE_ENUM(type, id) \ constexpr uint64_t EnumInfo::typeId +#else +#define CAPNP_DEFINE_ENUM(type, id) #endif #define CAPNP_DECLARE_STRUCT_HEADER(id, dataWordSize_, pointerCount_) \ @@ -382,9 +386,14 @@ inline constexpr uint sizeInWords() { static inline ::capnp::word const* encodedSchema() { return bp_##id; } \ static constexpr ::capnp::_::RawSchema const* schema = &s_##id; \ } + +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #define CAPNP_DEFINE_ENUM(type, id) \ constexpr uint64_t EnumInfo::typeId; \ constexpr ::capnp::_::RawSchema const* EnumInfo::schema +#else +#define CAPNP_DEFINE_ENUM(type, id) +#endif #define CAPNP_DECLARE_STRUCT_HEADER(id, dataWordSize_, pointerCount_) \ struct IsStruct; \ @@ -404,4 +413,12 @@ inline constexpr uint sizeInWords() { #endif // CAPNP_LITE, else -#endif // CAPNP_GENERATED_HEADER_SUPPORT_H_ +namespace capnp { +namespace schemas { +CAPNP_DECLARE_SCHEMA(995f9a3377c0b16e); +// HACK: Forward-declare the RawSchema for StreamResult, from stream.capnp. This allows capnp +// files which declare streaming methods to avoid including stream.capnp.h. +} +} + +CAPNP_END_HEADER diff --git a/c++/src/capnp/layout.c++ b/c++/src/capnp/layout.c++ index 8ad4bb7952..e2fa84df64 100644 --- a/c++/src/capnp/layout.c++ +++ b/c++/src/capnp/layout.c++ @@ -34,25 +34,36 @@ namespace capnp { namespace _ { // private #if !CAPNP_LITE -static BrokenCapFactory* brokenCapFactory = nullptr; +static BrokenCapFactory* globalBrokenCapFactory = nullptr; // Horrible hack: We need to be able to construct broken caps without any capability context, // but we can't have a link-time dependency on libcapnp-rpc. void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) { // Called from capability.c++ when the capability API is used, to make sure that layout.c++ // is ready for it. May be called multiple times but always with the same value. -#if __GNUC__ - __atomic_store_n(&brokenCapFactory, &factory, __ATOMIC_RELAXED); +#if __GNUC__ || defined(__clang__) + __atomic_store_n(&globalBrokenCapFactory, &factory, __ATOMIC_RELAXED); #elif _MSC_VER - *static_cast(&brokenCapFactory) = &factory; + *static_cast(&globalBrokenCapFactory) = &factory; #else #error "Platform not supported" #endif } +static BrokenCapFactory* readGlobalBrokenCapFactoryForLayoutCpp() { +#if __GNUC__ || defined(__clang__) + // Thread-sanitizer doesn't have the right information to know this is safe without doing an + // atomic read. https://groups.google.com/g/capnproto/c/634juhn5ap0/m/pyRiwWl1AAAJ + return __atomic_load_n(&globalBrokenCapFactory, __ATOMIC_RELAXED); +#else + return globalBrokenCapFactory; +#endif +} + } // namespace _ (private) const uint ClientHook::NULL_CAPABILITY_BRAND = 0; +const uint ClientHook::BROKEN_CAPABILITY_BRAND = 0; // Defined here rather than capability.c++ so that we can safely call isNull() in this file. namespace _ { // private @@ -67,6 +78,16 @@ namespace _ { // private // ======================================================================================= +#if __GNUC__ >= 8 && !__clang__ +// GCC 8 introduced a warning which complains whenever we try to memset() or memcpy() a +// WirePointer, because we deleted the regular copy constructor / assignment operator. Weirdly, if +// I remove those deletions, GCC *still* complains that WirePointer is non-trivial. I don't +// understand why -- maybe because WireValue has private members? We don't want to make WireValue's +// member public, but memset() and memcpy() on it are certainly valid and desirable, so we'll just +// have to disable the warning I guess. +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + struct WirePointer { // A pointer, in exactly the format in which it appears on the wire. @@ -153,7 +174,7 @@ struct WirePointer { reinterpret_cast(segment->getStartPtr())); KJ_DREQUIRE(reinterpret_cast(target) <= reinterpret_cast(segment->getStartPtr() + segment->getSize())); - offsetAndKind.set(((target - reinterpret_cast(this) - 1) << 2) | kind); + offsetAndKind.set((static_cast(target - reinterpret_cast(this) - 1) << 2) | kind); } KJ_ALWAYS_INLINE(void setKindWithZeroOffset(Kind kind)) { offsetAndKind.set(kind); @@ -300,6 +321,13 @@ static_assert(unboundAs(POINTERS * BITS_PER_POINTER / BITS_PER_BYTE / BY sizeof(WirePointer), "BITS_PER_POINTER is wrong."); +#define OUT_OF_BOUNDS_ERROR_DETAIL \ + "This usually indicates that " \ + "the input data was corrupted, used a different encoding than specified (e.g. " \ + "packed vs. non-packed), or was not a Cap'n Proto message to begin with. Note " \ + "that this error is NOT due to a schema mismatch; the input is invalid " \ + "regardless of schema." + namespace { static const union { @@ -366,15 +394,15 @@ struct WireHelpers { #endif static KJ_ALWAYS_INLINE(void zeroMemory(byte* ptr, ByteCount32 count)) { - memset(ptr, 0, unbound(count / BYTES)); + if (count != ZERO * BYTES) memset(ptr, 0, unbound(count / BYTES)); } static KJ_ALWAYS_INLINE(void zeroMemory(word* ptr, WordCountN<29> count)) { - memset(ptr, 0, unbound(count * BYTES_PER_WORD / BYTES)); + if (count != ZERO * WORDS) memset(ptr, 0, unbound(count * BYTES_PER_WORD / BYTES)); } static KJ_ALWAYS_INLINE(void zeroMemory(WirePointer* ptr, WirePointerCountN<29> count)) { - memset(ptr, 0, unbound(count * BYTES_PER_POINTER / BYTES)); + if (count != ZERO * POINTERS) memset(ptr, 0, unbound(count * BYTES_PER_POINTER / BYTES)); } static KJ_ALWAYS_INLINE(void zeroMemory(WirePointer* ptr)) { @@ -383,20 +411,20 @@ struct WireHelpers { template static inline void zeroMemory(kj::ArrayPtr array) { - memset(array.begin(), 0, array.size() * sizeof(array[0])); + if (array.size() != 0u) memset(array.begin(), 0, array.size() * sizeof(array[0])); } static KJ_ALWAYS_INLINE(void copyMemory(byte* to, const byte* from, ByteCount32 count)) { - memcpy(to, from, unbound(count / BYTES)); + if (count != ZERO * BYTES) memcpy(to, from, unbound(count / BYTES)); } static KJ_ALWAYS_INLINE(void copyMemory(word* to, const word* from, WordCountN<29> count)) { - memcpy(to, from, unbound(count * BYTES_PER_WORD / BYTES)); + if (count != ZERO * WORDS) memcpy(to, from, unbound(count * BYTES_PER_WORD / BYTES)); } static KJ_ALWAYS_INLINE(void copyMemory(WirePointer* to, const WirePointer* from, WirePointerCountN<29> count)) { - memcpy(to, from, unbound(count * BYTES_PER_POINTER / BYTES)); + if (count != ZERO * POINTERS) memcpy(to, from, unbound(count * BYTES_PER_POINTER / BYTES)); } template @@ -407,14 +435,14 @@ struct WireHelpers { // TODO(cleanup): Turn these into a .copyTo() method of ArrayPtr? template static inline void copyMemory(T* to, kj::ArrayPtr from) { - memcpy(to, from.begin(), from.size() * sizeof(from[0])); + if (from.size() != 0u) memcpy(to, from.begin(), from.size() * sizeof(from[0])); } template static inline void copyMemory(T* to, kj::ArrayPtr from) { - memcpy(to, from.begin(), from.size() * sizeof(from[0])); + if (from.size() != 0u) memcpy(to, from.begin(), from.size() * sizeof(from[0])); } static KJ_ALWAYS_INLINE(void copyMemory(char* to, kj::StringPtr from)) { - memcpy(to, from.begin(), from.size() * sizeof(from[0])); + if (from.size() != 0u) memcpy(to, from.begin(), from.size() * sizeof(from[0])); } static KJ_ALWAYS_INLINE(bool boundsCheck( @@ -431,7 +459,9 @@ struct WireHelpers { static KJ_ALWAYS_INLINE(word* allocate( WirePointer*& ref, SegmentBuilder*& segment, CapTableBuilder* capTable, SegmentWordCount amount, WirePointer::Kind kind, BuilderArena* orphanArena)) { - // Allocate space in the message for a new object, creating far pointers if necessary. + // Allocate space in the message for a new object, creating far pointers if necessary. The + // space is guaranteed to be zero'd (because MessageBuilder implementations are required to + // return zero'd memory). // // * `ref` starts out being a reference to the pointer which shall be assigned to point at the // new object. On return, `ref` points to a pointer which needs to be initialized with @@ -461,6 +491,7 @@ struct WireHelpers { return reinterpret_cast(ref); } + KJ_ASSUME(segment != nullptr); word* ptr = segment->allocate(amount); if (ptr == nullptr) { @@ -554,7 +585,8 @@ struct WireHelpers { const word* ptr = ref->farTarget(segment); auto padWords = (ONE + bounded(ref->isDoubleFar())) * POINTER_SIZE_IN_WORDS; KJ_REQUIRE(boundsCheck(segment, ptr, padWords), - "Message contains out-of-bounds far pointer.") { + "Message contains out-of-bounds far pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return nullptr; } @@ -766,7 +798,8 @@ struct WireHelpers { switch (ref->kind()) { case WirePointer::STRUCT: { KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()), - "Message contained out-of-bounds struct pointer.") { + "Message contained out-of-bounds struct pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } result.addWords(ref->structRef.wordSize()); @@ -792,7 +825,8 @@ struct WireHelpers { upgradeBound(ref->listRef.elementCount()) * dataBitsPerElement(ref->listRef.elementSize())); KJ_REQUIRE(boundsCheck(segment, ptr, totalWords), - "Message contained out-of-bounds list pointer.") { + "Message contained out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } result.addWords(totalWords); @@ -802,7 +836,8 @@ struct WireHelpers { auto count = ref->listRef.elementCount() * (POINTERS / ELEMENTS); KJ_REQUIRE(boundsCheck(segment, ptr, count * WORDS_PER_POINTER), - "Message contained out-of-bounds list pointer.") { + "Message contained out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } @@ -817,7 +852,8 @@ struct WireHelpers { case ElementSize::INLINE_COMPOSITE: { auto wordCount = ref->listRef.inlineCompositeWordCount(); KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS), - "Message contained out-of-bounds list pointer.") { + "Message contained out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } @@ -832,13 +868,14 @@ struct WireHelpers { auto actualSize = elementTag->structRef.wordSize() / ELEMENTS * upgradeBound(count); KJ_REQUIRE(actualSize <= wordCount, - "Struct list pointer's elements overran size.") { + "Struct list pointer's elements overran size. " + OUT_OF_BOUNDS_ERROR_DETAIL) { return result; } // We count the actual size rather than the claimed word count because that's what // we'll end up with if we make a copy. - result.addWords(wordCount + POINTER_SIZE_IN_WORDS); + result.addWords(actualSize + POINTER_SIZE_IN_WORDS); WordCount dataSize = elementTag->structRef.dataSize.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); @@ -1107,7 +1144,7 @@ struct WireHelpers { word* oldPtr = followFars(oldRef, refTarget, oldSegment); KJ_REQUIRE(oldRef->kind() == WirePointer::STRUCT, - "Message contains non-struct pointer where struct pointer was expected.") { + "Schema mismatch: Message contains non-struct pointer where struct pointer was expected.") { goto useDefault; } @@ -1249,7 +1286,7 @@ struct WireHelpers { word* ptr = followFars(ref, origRefTarget, segment); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getWritableListPointer() but existing pointer is not a list.") { + "Schema mismatch: Called getWritableListPointer() but existing pointer is not a list.") { goto useDefault; } @@ -1277,8 +1314,8 @@ struct WireHelpers { case ElementSize::BIT: KJ_FAIL_REQUIRE( - "Found struct list where bit list was expected; upgrading boolean lists to structs " - "is no longer supported.") { + "Schema mismatch: Found struct list where bit list was expected; upgrading boolean " + "lists to structs is no longer supported.") { goto useDefault; } break; @@ -1288,14 +1325,14 @@ struct WireHelpers { case ElementSize::FOUR_BYTES: case ElementSize::EIGHT_BYTES: KJ_REQUIRE(dataSize >= ONE * WORDS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } break; case ElementSize::POINTER: KJ_REQUIRE(pointerCount >= ONE * POINTERS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } // Adjust the pointer to point at the reference segment. @@ -1318,20 +1355,20 @@ struct WireHelpers { if (elementSize == ElementSize::BIT) { KJ_REQUIRE(oldSize == ElementSize::BIT, - "Found non-bit list where bit list was expected.") { + "Schema mismatch: Found non-bit list where bit list was expected.") { goto useDefault; } } else { KJ_REQUIRE(oldSize != ElementSize::BIT, - "Found bit list where non-bit list was expected.") { + "Schema mismatch: Found bit list where non-bit list was expected.") { goto useDefault; } KJ_REQUIRE(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } KJ_REQUIRE(pointerCount >= pointersPerElement(elementSize) * ELEMENTS, - "Existing list value is incompatible with expected type.") { + "Schema mismatch: Existing list value is incompatible with expected type.") { goto useDefault; } } @@ -1369,7 +1406,8 @@ struct WireHelpers { word* ptr = followFars(ref, origRefTarget, segment); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getWritableListPointerAnySize() but existing pointer is not a list.") { + "Schema mismatch: Called getWritableListPointerAnySize() but existing pointer is not a " + "list.") { goto useDefault; } @@ -1425,7 +1463,8 @@ struct WireHelpers { word* oldPtr = followFars(oldRef, origRefTarget, oldSegment); KJ_REQUIRE(oldRef->kind() == WirePointer::LIST, - "Called getList{Field,Element}() but existing pointer is not a list.") { + "Schema mismatch: Called getList{Field,Element}() but existing pointer is not a " + "list.") { goto useDefault; } @@ -1520,8 +1559,8 @@ struct WireHelpers { // Upgrading to an inline composite list. KJ_REQUIRE(oldSize != ElementSize::BIT, - "Found bit list where struct list was expected; upgrading boolean lists to structs " - "is no longer supported.") { + "Schema mismatch: Found bit list where struct list was expected; upgrading boolean " + "lists to structs is no longer supported.") { goto useDefault; } @@ -1599,7 +1638,8 @@ struct WireHelpers { // Initialize the pointer. ref->listRef.set(ElementSize::BYTE, byteSize * (ONE * ELEMENTS / BYTES)); - // Build the Text::Builder. This will initialize the NUL terminator. + // Build the Text::Builder. Note that since allocate()ed memory is pre-zero'd, we don't need + // to initialize the NUL terminator. return { segment, Text::Builder(reinterpret_cast(ptr), unbound(size / BYTES)) }; } @@ -1638,11 +1678,12 @@ struct WireHelpers { byte* bptr = reinterpret_cast(ptr); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getText{Field,Element}() but existing pointer is not a list.") { + "Schema mismatch: Called getText{Field,Element}() but existing pointer is not a list.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Called getText{Field,Element}() but existing list pointer is not byte-sized.") { + "Schema mismatch: Called getText{Field,Element}() but existing list pointer is not " + "byte-sized.") { goto useDefault; } @@ -1709,11 +1750,12 @@ struct WireHelpers { word* ptr = followFars(ref, refTarget, segment); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Called getData{Field,Element}() but existing pointer is not a list.") { + "Schema mismatch: Called getData{Field,Element}() but existing pointer is not a list.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Called getData{Field,Element}() but existing list pointer is not byte-sized.") { + "Schema mismatch: Called getData{Field,Element}() but existing list pointer is not " + "byte-sized.") { goto useDefault; } @@ -1940,7 +1982,8 @@ struct WireHelpers { } KJ_REQUIRE(boundsCheck(srcSegment, ptr, src->structRef.wordSize()), - "Message contained out-of-bounds struct pointer.") { + "Message contained out-of-bounds struct pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } return setStructPointer(dstSegment, dstCapTable, dst, @@ -1964,7 +2007,8 @@ struct WireHelpers { const WirePointer* tag = reinterpret_cast(ptr); KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount + POINTER_SIZE_IN_WORDS), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2007,7 +2051,8 @@ struct WireHelpers { auto wordCount = roundBitsUpToWords(upgradeBound(elementCount) * step); KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2151,12 +2196,14 @@ struct WireHelpers { } KJ_REQUIRE(ref->kind() == WirePointer::STRUCT, - "Message contains non-struct pointer where struct pointer was expected.") { + "Schema mismatch: Message contains non-struct pointer where struct pointer" + "was expected.") { goto useDefault; } KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()), - "Message contained out-of-bounds struct pointer.") { + "Message contained out-of-bounds struct pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2174,6 +2221,8 @@ struct WireHelpers { const WirePointer* ref, int nestingLimit)) { kj::Maybe> maybeCap; + auto brokenCapFactory = readGlobalBrokenCapFactoryForLayoutCpp(); + KJ_REQUIRE(brokenCapFactory != nullptr, "Trying to read capabilities without ever having created a capability context. " "To read capabilities from a message, you must imbue it with CapReaderContext, or " @@ -2183,7 +2232,8 @@ struct WireHelpers { return brokenCapFactory->newNullCap(); } else if (!ref->isCapability()) { KJ_FAIL_REQUIRE( - "Message contains non-capability pointer where capability pointer was expected.") { + "Schema mismatch: Message contains non-capability pointer where capability pointer was " + "expected.") { break; } return brokenCapFactory->newBrokenCap( @@ -2237,7 +2287,8 @@ struct WireHelpers { } KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Message contains non-list pointer where list pointer was expected.") { + "Schema mismatch: Message contains non-list pointer where list pointer was " + "expected.") { goto useDefault; } @@ -2249,7 +2300,8 @@ struct WireHelpers { const WirePointer* tag = reinterpret_cast(ptr); KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2301,18 +2353,16 @@ struct WireHelpers { case ElementSize::FOUR_BYTES: case ElementSize::EIGHT_BYTES: KJ_REQUIRE(tag->structRef.dataSize.get() > ZERO * WORDS, - "Expected a primitive list, but got a list of pointer-only structs.") { + "Schema mismatch: Expected a primitive list, but got a list of pointer-only " + "structs.") { goto useDefault; } break; case ElementSize::POINTER: - // We expected a list of pointers but got a list of structs. Assuming the first field - // in the struct is the pointer we were looking for, we want to munge the pointer to - // point at the first element's pointer section. - ptr += tag->structRef.dataSize.get(); KJ_REQUIRE(tag->structRef.ptrCount.get() > ZERO * POINTERS, - "Expected a pointer list, but got a list of data-only structs.") { + "Schema mismatch: Expected a pointer list, but got a list of data-only " + "structs.") { goto useDefault; } break; @@ -2338,7 +2388,8 @@ struct WireHelpers { auto wordCount = roundBitsUpToWords(upgradeBound(elementCount) * step); KJ_REQUIRE(boundsCheck(segment, ptr, wordCount), - "Message contains out-of-bounds list pointer.") { + "Message contains out-of-bounds list pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2371,11 +2422,11 @@ struct WireHelpers { pointersPerElement(expectedElementSize) * ELEMENTS; KJ_REQUIRE(expectedDataBitsPerElement <= dataSize, - "Message contained list with incompatible element type.") { + "Schema mismatch: Message contained list with incompatible element type.") { goto useDefault; } KJ_REQUIRE(expectedPointersPerElement <= pointerCount, - "Message contained list with incompatible element type.") { + "Schema mismatch: Message contained list with incompatible element type.") { goto useDefault; } } @@ -2410,17 +2461,19 @@ struct WireHelpers { auto size = ref->listRef.elementCount() * (ONE * BYTES / ELEMENTS); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Message contains non-list pointer where text was expected.") { + "Schema mismatch: Message contains non-list pointer where text was expected.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Message contains list pointer of non-bytes where text was expected.") { + "Schema mismatch: Message contains list pointer of non-bytes where text was " + "expected.") { goto useDefault; } KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)), - "Message contained out-of-bounds text pointer.") { + "Message contained out-of-bounds text pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2468,17 +2521,19 @@ struct WireHelpers { auto size = ref->listRef.elementCount() * (ONE * BYTES / ELEMENTS); KJ_REQUIRE(ref->kind() == WirePointer::LIST, - "Message contains non-list pointer where data was expected.") { + "Schema mismatch: Message contains non-list pointer where data was expected.") { goto useDefault; } KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, - "Message contains list pointer of non-bytes where data was expected.") { + "Schema mismatch: Message contains list pointer of non-bytes where data was " + "expected.") { goto useDefault; } KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)), - "Message contained out-of-bounds data pointer.") { + "Message contained out-of-bounds data pointer. " + OUT_OF_BOUNDS_ERROR_DETAIL) { goto useDefault; } @@ -2758,12 +2813,23 @@ bool PointerReader::isCanonical(const word **readHead) { // The pointer is null, we are canonical and do not read return true; case PointerType::STRUCT: { - bool dataTrunc, ptrTrunc; + bool dataTrunc = false, ptrTrunc = false; auto structReader = this->getStruct(nullptr); if (structReader.getDataSectionSize() == ZERO * BITS && structReader.getPointerSectionSize() == ZERO * POINTERS) { return reinterpret_cast(this->pointer) == structReader.getLocation(); } else { + // Fun fact: Once this call to isCanonical() returns, Clang may re-order the evaluation of + // the && operators. In theory this is wrong because && is short-circuiting, but Clang + // apparently sees that there are no side effects to the right of &&, so decides it is + // safe to skip short-circuiting. It turns out, though, this is observable under + // valgrind: if we don't initialize `dataTrunc` when declaring it above, then valgrind + // reports "Conditional jump or move depends on uninitialised value(s)". Specifically + // this happens in cases where structReader.isCanonical() returns false -- it is allowed + // to skip initializing `dataTrunc` in that case. The short-circuiting && should mean + // that we don't read `dataTrunc` in that case, except Clang's optimizations. Ultimately + // the uninitialized read is fine because eventually the whole expression evaluates false + // either way. But, to make valgrind happy, we initialize the bools above... return structReader.isCanonical(readHead, readHead, &dataTrunc, &ptrTrunc) && dataTrunc && ptrTrunc; } } @@ -2838,6 +2904,17 @@ void StructBuilder::transferContentFrom(StructBuilder other) { void StructBuilder::copyContentFrom(StructReader other) { // Determine the amount of data the builders have in common. auto sharedDataSize = kj::min(dataSize, other.dataSize); + auto sharedPointerCount = kj::min(pointerCount, other.pointerCount); + + if ((sharedDataSize > ZERO * BITS && other.data == data) || + (sharedPointerCount > ZERO * POINTERS && other.pointers == pointers)) { + // At least one of the section pointers is pointing to ourself. Verify that the other is two + // (but ignore empty sections). + KJ_ASSERT((sharedDataSize == ZERO * BITS || other.data == data) && + (sharedPointerCount == ZERO * POINTERS || other.pointers == pointers)); + // So `other` appears to be a reader for this same struct. No coping is needed. + return; + } if (dataSize > sharedDataSize) { // Since the target is larger than the source, make sure to zero out the extra bits that the @@ -2867,7 +2944,6 @@ void StructBuilder::copyContentFrom(StructReader other) { WireHelpers::zeroMemory(pointers, pointerCount); // Copy the pointers. - auto sharedPointerCount = kj::min(pointerCount, other.pointerCount); for (auto i: kj::zeroTo(sharedPointerCount)) { WireHelpers::copyPointer(segment, capTable, pointers + i, other.segment, other.capTable, other.pointers + i, other.nestingLimit); @@ -3045,7 +3121,7 @@ ListBuilder ListBuilder::imbue(CapTableBuilder* capTable) { Text::Reader ListReader::asText() { KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS, - "Expected Text, got list of non-bytes.") { + "Schema mismatch: Expected Text, got list of non-bytes.") { return Text::Reader(); } @@ -3067,16 +3143,16 @@ Text::Reader ListReader::asText() { Data::Reader ListReader::asData() { KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS, - "Expected Text, got list of non-bytes.") { + "Schema mismatch: Expected Text, got list of non-bytes.") { return Data::Reader(); } return Data::Reader(reinterpret_cast(ptr), unbound(elementCount / ELEMENTS)); } -kj::ArrayPtr ListReader::asRawBytes() { +kj::ArrayPtr ListReader::asRawBytes() const { KJ_REQUIRE(structPointerCount == ZERO * POINTERS, - "Expected data only, got pointers.") { + "Schema mismatch: Expected data only, got pointers.") { return kj::ArrayPtr(); } @@ -3096,11 +3172,6 @@ StructReader ListReader::getStructElement(ElementCount index) const { const WirePointer* structPointers = reinterpret_cast(structData + structDataSize / BITS_PER_BYTE); - // This check should pass if there are no bugs in the list pointer validation code. - KJ_DASSERT(structPointerCount == ZERO * POINTERS || - (uintptr_t)structPointers % sizeof(void*) == 0, - "Pointer section of struct list element not aligned."); - KJ_DASSERT(indexBit % BITS_PER_BYTE == ZERO * BITS); return StructReader( segment, capTable, structData, structPointers, @@ -3108,6 +3179,64 @@ StructReader ListReader::getStructElement(ElementCount index) const { nestingLimit - 1); } +MessageSizeCounts ListReader::totalSize() const { + // TODO(cleanup): This is kind of a lot of logic duplicated from WireHelpers::totalSize(), but + // it's unclear how to share it effectively. + + MessageSizeCounts result = { ZERO * WORDS, 0 }; + + switch (elementSize) { + case ElementSize::VOID: + // Nothing. + break; + case ElementSize::BIT: + case ElementSize::BYTE: + case ElementSize::TWO_BYTES: + case ElementSize::FOUR_BYTES: + case ElementSize::EIGHT_BYTES: + result.addWords(WireHelpers::roundBitsUpToWords( + upgradeBound(elementCount) * dataBitsPerElement(elementSize))); + break; + case ElementSize::POINTER: { + auto count = elementCount * (POINTERS / ELEMENTS); + result.addWords(count * WORDS_PER_POINTER); + + for (auto i: kj::zeroTo(count)) { + result += WireHelpers::totalSize(segment, reinterpret_cast(ptr) + i, + nestingLimit); + } + break; + } + case ElementSize::INLINE_COMPOSITE: { + // Don't forget to count the tag word. + auto wordSize = upgradeBound(elementCount) * step / BITS_PER_WORD; + result.addWords(wordSize + POINTER_SIZE_IN_WORDS); + + if (structPointerCount > ZERO * POINTERS) { + const word* pos = reinterpret_cast(ptr); + for (auto i KJ_UNUSED: kj::zeroTo(elementCount)) { + pos += structDataSize / BITS_PER_WORD; + + for (auto j KJ_UNUSED: kj::zeroTo(structPointerCount)) { + result += WireHelpers::totalSize(segment, reinterpret_cast(pos), + nestingLimit); + pos += POINTER_SIZE_IN_WORDS; + } + } + } + break; + } + } + + if (segment != nullptr) { + // This traversal should not count against the read limit, because it's highly likely that + // the caller is going to traverse the object again, e.g. to copy it. + segment->unread(result.wordCount); + } + + return result; +} + CapTableReader* ListReader::getCapTable() { return capTable; } @@ -3561,7 +3690,7 @@ bool OrphanBuilder::truncate(ElementCount uncheckedSize, bool isText) { return size == ZERO * ELEMENTS; } - KJ_REQUIRE(ref->kind() == WirePointer::LIST, "Can't truncate non-list.") { + KJ_REQUIRE(ref->kind() == WirePointer::LIST, "Schema mismatch: Can't truncate non-list.") { return false; } diff --git a/c++/src/capnp/layout.h b/c++/src/capnp/layout.h index 99dc533b2b..7a27f68a1f 100644 --- a/c++/src/capnp/layout.h +++ b/c++/src/capnp/layout.h @@ -26,18 +26,16 @@ // as does other parts of the Cap'n proto library which provide a higher-level interface for // dynamic introspection. -#ifndef CAPNP_LAYOUT_H_ -#define CAPNP_LAYOUT_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include #include "common.h" #include "blob.h" #include "endian.h" +#include // work-around macro conflict with `VOID` + +CAPNP_BEGIN_HEADER #if (defined(__mips__) || defined(__hppa__)) && !defined(CAPNP_CANONICALIZE_NAN) #define CAPNP_CANONICALIZE_NAN 1 @@ -60,9 +58,7 @@ namespace capnp { -#if !CAPNP_LITE class ClientHook; -#endif // !CAPNP_LITE namespace _ { // private @@ -314,15 +310,12 @@ inline double unmask(uint64_t value, uint64_t mask) { class CapTableReader { public: -#if !CAPNP_LITE virtual kj::Maybe> extractCap(uint index) = 0; // Extract the capability at the given index. If the index is invalid, returns null. -#endif // !CAPNP_LITE }; class CapTableBuilder: public CapTableReader { public: -#if !CAPNP_LITE virtual uint injectCap(kj::Own&& cap) = 0; // Add the capability to the message and return its index. If the same ClientHook is injected // twice, this may return the same index both times, but in this case dropCap() needs to be @@ -330,7 +323,6 @@ class CapTableBuilder: public CapTableReader { virtual void dropCap(uint index) = 0; // Remove a capability injected earlier. Called when the pointer is overwritten or zero'd out. -#endif // !CAPNP_LITE }; // ------------------------------------------------------------------- @@ -598,8 +590,8 @@ class StructReader { inline StructDataBitCount getDataSectionSize() const { return dataSize; } inline StructPointerCount getPointerSectionSize() const { return pointerCount; } - inline kj::ArrayPtr getDataSectionAsBlob(); - inline _::ListReader getPointerSectionAsList(); + inline kj::ArrayPtr getDataSectionAsBlob() const; + inline _::ListReader getPointerSectionAsList() const; kj::Array canonicalize(); @@ -625,9 +617,7 @@ class StructReader { MessageSizeCounts totalSize() const; // Return the total size of the struct and everything to which it points. Does not count far // pointer overhead. This is useful for deciding how much space is needed to copy the struct - // into a flat array. However, the caller is advised NOT to treat this value as secure. Instead, - // use the result as a hint for allocating the first segment, do the copy, and then throw an - // exception if it overruns. + // into a flat array. CapTableReader* getCapTable(); // Gets the capability context in which this object is operating. @@ -783,7 +773,7 @@ class ListReader { Data::Reader asData(); // Reinterpret the list as a blob. Throws an exception if the elements are not byte-sized. - kj::ArrayPtr asRawBytes(); + kj::ArrayPtr asRawBytes() const; template KJ_ALWAYS_INLINE(T getDataElement(ElementCount index) const); @@ -793,6 +783,9 @@ class ListReader { StructReader getStructElement(ElementCount index) const; + MessageSizeCounts totalSize() const; + // Like StructReader::totalSize(). Note that for struct lists, the size includes the list tag. + CapTableReader* getCapTable(); // Gets the capability context in which this object is operating. @@ -1080,12 +1073,12 @@ inline PointerBuilder StructBuilder::getPointerField(StructPointerOffset ptrInde // ------------------------------------------------------------------- -inline kj::ArrayPtr StructReader::getDataSectionAsBlob() { +inline kj::ArrayPtr StructReader::getDataSectionAsBlob() const { return kj::ArrayPtr(reinterpret_cast(data), unbound(dataSize / BITS_PER_BYTE / BYTES)); } -inline _::ListReader StructReader::getPointerSectionAsList() { +inline _::ListReader StructReader::getPointerSectionAsList() const { return _::ListReader(segment, capTable, pointers, pointerCount * (ONE * ELEMENTS / POINTERS), ONE * POINTERS * BITS_PER_POINTER / ELEMENTS, ZERO * BITS, ONE * POINTERS, ElementSize::POINTER, nestingLimit); @@ -1234,8 +1227,12 @@ inline Void ListReader::getDataElement(ElementCount index) const { } inline PointerReader ListReader::getPointerElement(ElementCount index) const { + // If the list elements have data sections we need to skip those. Note that for pointers to be + // present at all (which already must be true if we get here), then `structDataSize` must be a + // whole number of words, so we don't have to worry about unaligned reads here. + auto offset = structDataSize / BITS_PER_BYTE; return PointerReader(segment, capTable, reinterpret_cast( - ptr + upgradeBound(index) * step / BITS_PER_BYTE), nestingLimit); + ptr + offset + upgradeBound(index) * step / BITS_PER_BYTE), nestingLimit); } // ------------------------------------------------------------------- @@ -1271,4 +1268,4 @@ inline OrphanBuilder& OrphanBuilder::operator=(OrphanBuilder&& other) { } // namespace _ (private) } // namespace capnp -#endif // CAPNP_LAYOUT_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/list.h b/c++/src/capnp/list.h index 23e5e6c10e..2c777f8179 100644 --- a/c++/src/capnp/list.h +++ b/c++/src/capnp/list.h @@ -19,19 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_LIST_H_ -#define CAPNP_LIST_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "layout.h" #include "orphan.h" #include -#ifdef KJ_STD_COMPAT -#include -#endif // KJ_STD_COMPAT + +CAPNP_BEGIN_HEADER namespace capnp { namespace _ { // private @@ -52,6 +46,8 @@ class TemporaryPointer { T value; }; +// By default this isn't compatible with STL algorithms. To add STL support either define +// KJ_STD_COMPAT at the top of your compilation unit or include capnp/compat/std-iterator.h. template class IndexingIterator { public: @@ -124,6 +120,10 @@ struct List { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -220,6 +220,10 @@ struct List { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -343,6 +347,10 @@ struct List, Kind::LIST> { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -388,13 +396,13 @@ struct List, Kind::LIST> { l.set(i++, element); } } - inline void adopt(uint index, Orphan&& value) { + inline void adopt(uint index, Orphan>&& value) { KJ_IREQUIRE(index < size()); builder.getPointerElement(bounded(index) * ELEMENTS).adopt(kj::mv(value.builder)); } - inline Orphan disown(uint index) { + inline Orphan> disown(uint index) { KJ_IREQUIRE(index < size()); - return Orphan(builder.getPointerElement(bounded(index) * ELEMENTS).disown()); + return Orphan>(builder.getPointerElement(bounded(index) * ELEMENTS).disown()); } typedef _::IndexingIterator::Builder> Iterator; @@ -452,6 +460,10 @@ struct List { inline Iterator begin() const { return Iterator(this, 0); } inline Iterator end() const { return Iterator(this, size()); } + inline MessageSize totalSize() const { + return reader.totalSize().asPublic(); + } + private: _::ListReader reader; template @@ -534,13 +546,7 @@ struct List { } // namespace capnp #ifdef KJ_STD_COMPAT -namespace std { - -template -struct iterator_traits> - : public std::iterator {}; - -} // namespace std +#include "compat/std-iterator.h" #endif // KJ_STD_COMPAT -#endif // CAPNP_LIST_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/llvm-fuzzer-testcase.c++ b/c++/src/capnp/llvm-fuzzer-testcase.c++ new file mode 100644 index 0000000000..1d47e276c7 --- /dev/null +++ b/c++/src/capnp/llvm-fuzzer-testcase.c++ @@ -0,0 +1,26 @@ +#include "test-util.h" +#include +#include "serialize.h" +#include +#include + +/* This is the entry point of a fuzz target to be used with libFuzzer + * or another fuzz driver. + * Such a fuzz driver is used by the autotools target capnp-llvm-fuzzer-testcase + * when the environment variable LIB_FUZZING_ENGINE is defined + * for instance LIB_FUZZING_ENGINE=-fsanitize=fuzzer for libFuzzer + */ +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) { + kj::ArrayPtr array(Data, Size); + kj::ArrayInputStream ais(array); + + KJ_IF_MAYBE(e, kj::runCatchingExceptions([&]() { + capnp::InputStreamMessageReader reader(ais); + capnp::_::checkTestMessage(reader.getRoot()); + capnp::_::checkDynamicTestMessage(reader.getRoot(capnp::Schema::from())); + kj::str(reader.getRoot()); + })) { + KJ_LOG(ERROR, "threw"); + } + return 0; +} diff --git a/c++/src/capnp/membrane-test.c++ b/c++/src/capnp/membrane-test.c++ index e0ee3c13f2..9f3bc21583 100644 --- a/c++/src/capnp/membrane-test.c++ +++ b/c++/src/capnp/membrane-test.c++ @@ -88,10 +88,18 @@ protected: context.getResults().setThing(context.getParams().getThing()); return kj::READY_NOW; } + + kj::Promise waitForever(WaitForeverContext context) override { + return kj::NEVER_DONE; + } }; class MembranePolicyImpl: public MembranePolicy, public kj::Refcounted { public: + MembranePolicyImpl() = default; + MembranePolicyImpl(kj::Maybe> revokePromise) + : revokePromise(revokePromise.map([](kj::Promise& p) { return p.fork(); })) {} + kj::Maybe inboundCall(uint64_t interfaceId, uint16_t methodId, Capability::Client target) override { if (interfaceId == capnp::typeId() && methodId == 1) { @@ -113,6 +121,17 @@ public: kj::Own addRef() override { return kj::addRef(*this); } + + kj::Maybe> onRevoked() override { + return revokePromise.map([](kj::ForkedPromise& fork) { + return fork.addBranch(); + }); + } + + bool shouldResolveBeforeRedirecting() override { return true; } + +private: + kj::Maybe> revokePromise; }; void testThingImpl(kj::WaitScope& waitScope, test::TestMembrane::Client membraned, @@ -258,26 +277,55 @@ KJ_TEST("apply membrane using copyOutOfMembrane() on AnyPointer") { }, "inside", "inbound", "inside", "inside"); } +KJ_TEST("MembraneHook::whenMoreResolved returns same value even when called concurrently.") { + TestEnv env; + + auto paf = kj::newPromiseAndFulfiller(); + test::TestMembrane::Client promCap(kj::mv(paf.promise)); + + auto prom = promCap.whenResolved(); + prom = prom.then([promCap = kj::mv(promCap), &env]() mutable { + auto membraned = membrane(kj::mv(promCap), env.policy->addRef()); + auto hook = ClientHook::from(membraned); + + auto arr = kj::heapArrayBuilder>>(2); + arr.add(KJ_ASSERT_NONNULL(hook->whenMoreResolved())); + arr.add(KJ_ASSERT_NONNULL(hook->whenMoreResolved())); + + return kj::joinPromises(arr.finish()).attach(kj::mv(hook)); + }).then([](kj::Vector> hooks) { + auto first = hooks[0].get(); + auto second = hooks[1].get(); + KJ_ASSERT(first == second); + }).eagerlyEvaluate(nullptr); + + auto newClient = kj::heap(); + paf.fulfiller->fulfill(kj::mv(newClient)); + prom.wait(env.waitScope); +} + struct TestRpcEnv { - kj::AsyncIoContext io; + kj::EventLoop loop; + kj::WaitScope waitScope; kj::TwoWayPipe pipe; TwoPartyClient client; TwoPartyClient server; test::TestMembrane::Client membraned; - TestRpcEnv() - : io(kj::setupAsyncIo()), - pipe(io.provider->newTwoWayPipe()), + TestRpcEnv(kj::Maybe> revokePromise = nullptr) + : waitScope(loop), + pipe(kj::newTwoWayPipe()), client(*pipe.ends[0]), server(*pipe.ends[1], - membrane(kj::heap(), kj::refcounted()), + membrane(kj::heap(), + kj::refcounted(kj::mv(revokePromise))), rpc::twoparty::Side::SERVER), membraned(client.bootstrap().castAs()) {} void testThing(kj::Function makeThing, kj::StringPtr localPassThrough, kj::StringPtr localIntercept, kj::StringPtr remotePassThrough, kj::StringPtr remoteIntercept) { - testThingImpl(io.waitScope, membraned, kj::mv(makeThing), + testThingImpl(waitScope, membraned, kj::mv(makeThing), localPassThrough, localIntercept, remotePassThrough, remoteIntercept); } }; @@ -285,7 +333,7 @@ struct TestRpcEnv { KJ_TEST("call remote object inside membrane") { TestRpcEnv env; env.testThing([&]() { - return env.membraned.makeThingRequest().send().wait(env.io.waitScope).getThing(); + return env.membraned.makeThingRequest().send().wait(env.waitScope).getThing(); }, "inside", "inbound", "inside", "inside"); } @@ -317,7 +365,7 @@ KJ_TEST("call remote capability that has passed into and back out of membrane") env.testThing([&]() { auto req = env.membraned.loopbackRequest(); req.setThing(kj::heap("outside")); - return req.send().wait(env.io.waitScope).getThing(); + return req.send().wait(env.waitScope).getThing(); }, "outside", "outside", "outside", "outbound"); } @@ -330,6 +378,35 @@ KJ_TEST("call remote promise pointing into membrane that eventually resolves to }, "outside", "outside", "outside", "outbound"); } +KJ_TEST("revoke membrane") { + auto paf = kj::newPromiseAndFulfiller(); + + TestRpcEnv env(kj::mv(paf.promise)); + + auto thing = env.membraned.makeThingRequest().send().wait(env.waitScope).getThing(); + + auto callPromise = env.membraned.waitForeverRequest().send(); + + KJ_EXPECT(!callPromise.poll(env.waitScope)); + + paf.fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "foobar")); + + // TRICKY: We need to use .ignoreResult().wait() below because when compiling with + // -fno-exceptions, void waits throw recoverable exceptions while non-void waits necessarily + // throw fatal exceptions... but testing for fatal exceptions when exceptions are disabled + // involves fork()ing the process to run the code so if it has side effects on file descriptors + // then we'll get in a bad state... + + KJ_ASSERT(callPromise.poll(env.waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("foobar", callPromise.ignoreResult().wait(env.waitScope)); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("foobar", + env.membraned.makeThingRequest().send().ignoreResult().wait(env.waitScope)); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("foobar", + thing.passThroughRequest().send().ignoreResult().wait(env.waitScope)); +} + } // namespace } // namespace _ } // namespace capnp diff --git a/c++/src/capnp/membrane.c++ b/c++/src/capnp/membrane.c++ index 049b8988d6..845ff95489 100644 --- a/c++/src/capnp/membrane.c++ +++ b/c++/src/capnp/membrane.c++ @@ -198,19 +198,44 @@ public: auto newPipeline = AnyPointer::Pipeline(kj::refcounted( PipelineHook::from(kj::mv(promise)), policy->addRef(), reverse)); + auto onRevoked = policy->onRevoked(); + bool reverse = this->reverse; // for capture - auto newPromise = promise.then(kj::mvCapture(policy, - [reverse](kj::Own&& policy, Response&& response) { + auto newPromise = promise.then( + [reverse,policy=kj::mv(policy)](Response&& response) mutable { AnyPointer::Reader reader = response; auto newRespHook = kj::heap( ResponseHook::from(kj::mv(response)), policy->addRef(), reverse); reader = newRespHook->imbue(reader); return Response(reader, kj::mv(newRespHook)); - })); + }); + + KJ_IF_MAYBE(r, kj::mv(onRevoked)) { + newPromise = newPromise.exclusiveJoin(r->then([]() -> Response { + KJ_FAIL_REQUIRE("onRevoked() promise resolved; it should only reject"); + })); + } return RemotePromise(kj::mv(newPromise), kj::mv(newPipeline)); } + kj::Promise sendStreaming() override { + auto promise = inner->sendStreaming(); + + KJ_IF_MAYBE(r, policy->onRevoked()) { + promise = promise.exclusiveJoin(r->then([]() { + KJ_FAIL_REQUIRE("onRevoked() promise resolved; it should only reject"); + })); + } + + return promise; + } + + AnyPointer::Pipeline sendForPipeline() override { + return AnyPointer::Pipeline(kj::refcounted( + PipelineHook::from(inner->sendForPipeline()), policy->addRef(), reverse)); + } + const void* getBrand() override { return MEMBRANE_BRAND; } @@ -242,7 +267,7 @@ public: } void releaseParams() override { - KJ_REQUIRE(!releasedParams); + // Note that releaseParams() is idempotent -- it can be called multiple times. releasedParams = true; inner->releaseParams(); } @@ -257,12 +282,13 @@ public: } } - kj::Promise tailCall(kj::Own&& request) override { - return inner->tailCall(MembraneRequestHook::wrap(kj::mv(request), *policy, !reverse)); + void setPipeline(kj::Own&& pipeline) override { + inner->setPipeline(kj::refcounted( + kj::mv(pipeline), policy->addRef(), !reverse)); } - void allowCancellation() override { - inner->allowCancellation(); + kj::Promise tailCall(kj::Own&& request) override { + return inner->tailCall(MembraneRequestHook::wrap(kj::mv(request), *policy, !reverse)); } kj::Promise onTailCall() override { @@ -299,87 +325,150 @@ private: kj::Maybe results; }; +} // namespace + class MembraneHook final: public ClientHook, public kj::Refcounted { public: - MembraneHook(kj::Own&& inner, kj::Own&& policy, bool reverse) - : inner(kj::mv(inner)), policy(kj::mv(policy)), reverse(reverse) {} + MembraneHook(kj::Own&& inner, kj::Own&& policyParam, bool reverse) + : inner(kj::mv(inner)), policy(kj::mv(policyParam)), reverse(reverse) { + KJ_IF_MAYBE(r, policy->onRevoked()) { + revocationTask = r->eagerlyEvaluate([this](kj::Exception&& exception) { + this->inner = newBrokenCap(kj::mv(exception)); + }); + } + } + + ~MembraneHook() noexcept(false) { + auto& map = reverse ? policy->reverseWrappers : policy->wrappers; + map.erase(inner.get()); + } static kj::Own wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) { if (cap.getBrand() == MEMBRANE_BRAND) { auto& otherMembrane = kj::downcast(cap); - if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) { + auto& rootPolicy = policy.rootPolicy(); + if (&otherMembrane.policy->rootPolicy() == &rootPolicy && + otherMembrane.reverse == !reverse) { // Capability that passed across the membrane one way is now passing back the other way. // Unwrap it rather than double-wrap it. - return otherMembrane.inner->addRef(); + Capability::Client unwrapped(otherMembrane.inner->addRef()); + return ClientHook::from( + reverse ? rootPolicy.importInternal(kj::mv(unwrapped), *otherMembrane.policy, policy) + : rootPolicy.exportExternal(kj::mv(unwrapped), *otherMembrane.policy, policy)); } } - return kj::refcounted(cap.addRef(), policy.addRef(), reverse); + auto& map = reverse ? policy.reverseWrappers : policy.wrappers; + ClientHook*& slot = map.findOrCreate(&cap, [&]() -> kj::Decay::Entry { + return { &cap, nullptr }; + }); + if (slot == nullptr) { + auto result = ClientHook::from( + reverse ? policy.importExternal(Capability::Client(cap.addRef())) + : policy.exportInternal(Capability::Client(cap.addRef()))); + slot = result; + return result; + } else { + return slot->addRef(); + } } static kj::Own wrap(kj::Own cap, MembranePolicy& policy, bool reverse) { if (cap->getBrand() == MEMBRANE_BRAND) { auto& otherMembrane = kj::downcast(*cap); - if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) { + auto& rootPolicy = policy.rootPolicy(); + if (&otherMembrane.policy->rootPolicy() == &rootPolicy && + otherMembrane.reverse == !reverse) { // Capability that passed across the membrane one way is now passing back the other way. // Unwrap it rather than double-wrap it. - return otherMembrane.inner->addRef(); + Capability::Client unwrapped(otherMembrane.inner->addRef()); + return ClientHook::from( + reverse ? rootPolicy.importInternal(kj::mv(unwrapped), *otherMembrane.policy, policy) + : rootPolicy.exportExternal(kj::mv(unwrapped), *otherMembrane.policy, policy)); } } - return kj::refcounted(kj::mv(cap), policy.addRef(), reverse); + auto& map = reverse ? policy.reverseWrappers : policy.wrappers; + ClientHook*& slot = map.findOrCreate(cap.get(), [&]() -> kj::Decay::Entry { + return { cap.get(), nullptr }; + }); + if (slot == nullptr) { + auto result = ClientHook::from( + reverse ? policy.importExternal(Capability::Client(kj::mv(cap))) + : policy.exportInternal(Capability::Client(kj::mv(cap)))); + slot = result; + return result; + } else { + return slot->addRef(); + } } Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { KJ_IF_MAYBE(r, resolved) { - return r->get()->newCall(interfaceId, methodId, sizeHint); + return r->get()->newCall(interfaceId, methodId, sizeHint, hints); } auto redirect = reverse ? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef())) : policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef())); KJ_IF_MAYBE(r, redirect) { - // The policy says that *if* this capability points into the membrane, then we want to - // redirect the call. However, if this capability is a promise, then it could resolve to - // something outside the membrane later. We have to wait before we actually redirect, - // otherwise behavior will differ depending on whether the promise is resolved. - KJ_IF_MAYBE(p, whenMoreResolved()) { - return newLocalPromiseClient(kj::mv(*p))->newCall(interfaceId, methodId, sizeHint); + if (policy->shouldResolveBeforeRedirecting()) { + // The policy says that *if* this capability points into the membrane, then we want to + // redirect the call. However, if this capability is a promise, then it could resolve to + // something outside the membrane later. We have to wait before we actually redirect, + // otherwise behavior will differ depending on whether the promise is resolved. + KJ_IF_MAYBE(p, whenMoreResolved()) { + return newLocalPromiseClient(p->attach(addRef())) + ->newCall(interfaceId, methodId, sizeHint, hints); + } } - return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint); + return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint, hints); } else { // For pass-through calls, we don't worry about promises, because if the capability resolves // to something outside the membrane, then the call will pass back out of the membrane too. return MembraneRequestHook::wrap( - inner->newCall(interfaceId, methodId, sizeHint), *policy, reverse); + inner->newCall(interfaceId, methodId, sizeHint, hints), *policy, reverse); } } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { + kj::Own&& context, + CallHints hints) override { KJ_IF_MAYBE(r, resolved) { - return r->get()->call(interfaceId, methodId, kj::mv(context)); + return r->get()->call(interfaceId, methodId, kj::mv(context), hints); } auto redirect = reverse ? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef())) : policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef())); KJ_IF_MAYBE(r, redirect) { - // The policy says that *if* this capability points into the membrane, then we want to - // redirect the call. However, if this capability is a promise, then it could resolve to - // something outside the membrane later. We have to wait before we actually redirect, - // otherwise behavior will differ depending on whether the promise is resolved. - KJ_IF_MAYBE(p, whenMoreResolved()) { - return newLocalPromiseClient(kj::mv(*p))->call(interfaceId, methodId, kj::mv(context)); + if (policy->shouldResolveBeforeRedirecting()) { + // The policy says that *if* this capability points into the membrane, then we want to + // redirect the call. However, if this capability is a promise, then it could resolve to + // something outside the membrane later. We have to wait before we actually redirect, + // otherwise behavior will differ depending on whether the promise is resolved. + KJ_IF_MAYBE(p, whenMoreResolved()) { + return newLocalPromiseClient(p->attach(addRef())) + ->call(interfaceId, methodId, kj::mv(context), hints); + } } - return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context)); + return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context), hints); } else { // !reverse because calls to the CallContext go in the opposite direction. auto result = inner->call(interfaceId, methodId, - kj::refcounted(kj::mv(context), policy->addRef(), !reverse)); + kj::refcounted(kj::mv(context), policy->addRef(), !reverse), + hints); + + if (hints.onlyPromisePipeline) { + // Just in case the called capability returned a valid promise, replace it here. + result.promise = kj::NEVER_DONE; + } else KJ_IF_MAYBE(r, policy->onRevoked()) { + result.promise = result.promise.exclusiveJoin(kj::mv(*r)); + } return { kj::mv(result.promise), @@ -409,12 +498,21 @@ public: } KJ_IF_MAYBE(promise, inner->whenMoreResolved()) { + KJ_IF_MAYBE(r, policy->onRevoked()) { + *promise = promise->exclusiveJoin(r->then([]() -> kj::Own { + KJ_FAIL_REQUIRE("onRevoked() promise resolved; it should only reject"); + })); + } + return promise->then([this](kj::Own&& newInner) { - kj::Own newResolved = wrap(*newInner, *policy, reverse); - if (resolved == nullptr) { - resolved = newResolved->addRef(); + // There's a chance resolved was set by getResolved() or a concurrent whenMoreResolved() + // while we yielded the event loop. If the inner ClientHook is maintaining the contract, + // then resolved would already be set to newInner after wrapping in a MembraneHook. + KJ_IF_MAYBE(r, resolved) { + return (*r)->addRef(); + } else { + return resolved.emplace(wrap(*newInner, *policy, reverse))->addRef(); } - return newResolved; }); } else { return nullptr; @@ -429,19 +527,51 @@ public: return MEMBRANE_BRAND; } + kj::Maybe getFd() override { + KJ_IF_MAYBE(f, inner->getFd()) { + if (policy->allowFdPassthrough()) { + return *f; + } + } + return nullptr; + } + private: kj::Own inner; kj::Own policy; bool reverse; kj::Maybe> resolved; + kj::Promise revocationTask = nullptr; }; +namespace { + kj::Own membrane(kj::Own inner, MembranePolicy& policy, bool reverse) { return MembraneHook::wrap(kj::mv(inner), policy, reverse); } } // namespace +Capability::Client MembranePolicy::importExternal(Capability::Client external) { + return Capability::Client(kj::refcounted( + ClientHook::from(kj::mv(external)), addRef(), true)); +} + +Capability::Client MembranePolicy::exportInternal(Capability::Client internal) { + return Capability::Client(kj::refcounted( + ClientHook::from(kj::mv(internal)), addRef(), false)); +} + +Capability::Client MembranePolicy::importInternal( + Capability::Client internal, MembranePolicy& exportPolicy, MembranePolicy& importPolicy) { + return kj::mv(internal); +} + +Capability::Client MembranePolicy::exportExternal( + Capability::Client external, MembranePolicy& importPolicy, MembranePolicy& exportPolicy) { + return kj::mv(external); +} + Capability::Client membrane(Capability::Client inner, kj::Own policy) { return Capability::Client(membrane( ClientHook::from(kj::mv(inner)), *policy, false)); diff --git a/c++/src/capnp/membrane.h b/c++/src/capnp/membrane.h index 6fa8a1335d..60629cb4dd 100644 --- a/c++/src/capnp/membrane.h +++ b/c++/src/capnp/membrane.h @@ -19,8 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_MEMBRANE_H_ -#define CAPNP_MEMBRANE_H_ +#pragma once // In capability theory, a "membrane" is a wrapper around a capability which (usually) forwards // calls but recursively wraps capabilities in those calls in the same membrane. The purpose of a // membrane is to enforce a barrier between two capabilities that cannot be bypassed by merely @@ -49,6 +48,9 @@ // Mark Miller on membranes: http://www.eros-os.org/pipermail/e-lang/2003-January/008434.html #include "capability.h" +#include + +CAPNP_BEGIN_HEADER namespace capnp { @@ -104,6 +106,101 @@ class MembranePolicy { // object actually to be the *same* membrane. This is relevant when an object passes into the // membrane and then back out (or out and then back in): instead of double-wrapping the object, // the wrapping will be removed. + + virtual kj::Maybe> onRevoked() { return nullptr; } + // If this returns non-null, then it is a promise that will reject (throw an exception) when the + // membrane should be revoked. On revocation, all capabilities pointing across the membrane will + // be dropped and all outstanding calls canceled. The exception thrown by the promise will be + // propagated to all these calls. It is an error for the promise to resolve without throwing. + // + // After the revocation promise has rejected, inboundCall() and outboundCall() will still be + // invoked for new calls, but the `target` passed to them will be a capability that always + // rethrows the revocation exception. + + virtual bool shouldResolveBeforeRedirecting() { return false; } + // If this returns true, then when inboundCall() or outboundCall() returns a redirect, but the + // original target is a promise, then the membrane will discard the redirect and instead wait + // for the promise to become more resolved and try again. + // + // This behavior is important in particular when implementing a membrane that wants to intercept + // calls that would otherwise terminate inside the membrane, but needs to be careful not to + // intercept calls that might be reflected back out of the membrane. If the promise eventually + // resolves to a capability outside the membrane, then the call will be forwarded to that + // capability without applying the policy at all. + // + // However, some membranes don't need this behavior, and may be negatively impacted by the + // unnecessary waiting. Such membranes can keep this disabled. + // + // TODO(cleanup): Consider a backwards-incompatible revamp of the MembranePolicy API with a + // better design here. Maybe we should more carefully distinguish between MembranePolicies + // which are reversible vs. those which are one-way? + + virtual bool allowFdPassthrough() { return false; } + // Should file descriptors be allowed to pass through this membrane? + // + // A MembranePolicy obviously cannot mediate nor revoke access to a file descriptor once it has + // passed through, so this must be used with caution. If you only want to allow file descriptors + // on certain methods, you could do so by implementing inboundCall()/outboundCall() to + // special-case those methods. + + // --------------------------------------------------------------------------- + // Control over importing and exporting. + // + // Most membranes should not override these methods. The default behavior is that a capability + // that crosses the membrane is wrapped in it, and if the wrapped version crosses back the other + // way, it is unwrapped. + + virtual Capability::Client importExternal(Capability::Client external); + // An external capability is crossing into the membrane. Returns the capability that should + // substitute for it when called from the inside. + // + // The default implementation creates a capability that invokes this MembranePolicy. E.g. all + // calls will invoke outboundCall(). + // + // Note that reverseMembrane(cap, policy) normally calls policy->importExternal(cap), unless + // `cap` itself was originally returned by the default implementation of exportInternal(), in + // which case importInternal() is called instead. + + virtual Capability::Client exportInternal(Capability::Client internal); + // An internal capability is crossing out of the membrane. Returns the capability that should + // substitute for it when called from the outside. + // + // The default implementation creates a capability that invokes this MembranePolicy. E.g. all + // calls will invoke inboundCall(). + // + // Note that membrane(cap, policy) normally calls policy->exportInternal(cap), unless `cap` + // itself was originally returned by the default implementation of exportInternal(), in which + // case importInternal() is called instead. + + virtual MembranePolicy& rootPolicy() { return *this; } + // If two policies return the same value for rootPolicy(), then a capability imported through + // one can be exported through the other, and vice versa. `importInternal()` and + // `exportExternal()` will always be called on the root policy, passing the two child policies + // as parameters. If you don't override rootPolicy(), then the policy references passed to + // importInternal() and exportExternal() will always be references to *this. + + virtual Capability::Client importInternal( + Capability::Client internal, MembranePolicy& exportPolicy, MembranePolicy& importPolicy); + // An internal capability which was previously exported is now being re-imported, i.e. a + // capability passed out of the membrane and then back in. + // + // The default implementation simply returns `internal`. + + virtual Capability::Client exportExternal( + Capability::Client external, MembranePolicy& importPolicy, MembranePolicy& exportPolicy); + // An external capability which was previously imported is now being re-exported, i.e. a + // capability passed into the membrane and then back out. + // + // The default implementation simply returns `external`. + +private: + kj::HashMap wrappers; + kj::HashMap reverseWrappers; + // Tracks capabilities that already have wrappers instantiated. The maps map from pointer to + // inner capability to pointer to wrapper. When a wrapper is destroyed it removes itself from + // the map. + + friend class MembraneHook; }; Capability::Client membrane(Capability::Client inner, kj::Own policy); @@ -199,4 +296,4 @@ Orphan::Reads> copyOutOfMembrane( } // namespace capnp -#endif // CAPNP_MEMBRANE_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/message-test.c++ b/c++/src/capnp/message-test.c++ index e7a65ebf7a..6545b17a6f 100644 --- a/c++/src/capnp/message-test.c++ +++ b/c++/src/capnp/message-test.c++ @@ -168,6 +168,51 @@ TEST(Message, ReadWriteDataStruct) { checkTestMessageAllZero(defaultValue()); } +KJ_TEST("clone()") { + MallocMessageBuilder builder(2048); + initTestMessage(builder.getRoot()); + + auto copy = clone(builder.getRoot().asReader()); + checkTestMessage(*copy); +} + +#if !CAPNP_ALLOW_UNALIGNED +KJ_TEST("disallow unaligned") { + union { + char buffer[16]; + word align; + }; + memset(buffer, 0, sizeof(buffer)); + + auto unaligned = kj::arrayPtr(reinterpret_cast(buffer + 1), 1); + + kj::ArrayPtr segments[1] = {unaligned}; + SegmentArrayMessageReader message(segments); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("unaligned", message.getRoot()); +} +#endif + +KJ_TEST("MessageBuilder::sizeInWords()") { + capnp::MallocMessageBuilder builder; + auto root = builder.initRoot(); + initTestMessage(root); + + size_t expected = root.totalSize().wordCount + 1; + + KJ_EXPECT(builder.sizeInWords() == expected); + + auto segments = builder.getSegmentsForOutput(); + size_t total = 0; + for (auto& segment: segments) { + total += segment.size(); + } + KJ_EXPECT(total == expected); + + capnp::SegmentArrayMessageReader reader(segments); + checkTestMessage(reader.getRoot()); + KJ_EXPECT(reader.sizeInWords() == expected); +} + // TODO(test): More tests. } // namespace diff --git a/c++/src/capnp/message.c++ b/c++/src/capnp/message.c++ index 9446b57aa6..90a66cca1f 100644 --- a/c++/src/capnp/message.c++ +++ b/c++/src/capnp/message.c++ @@ -25,9 +25,6 @@ #include "arena.h" #include "orphan.h" #include -#include -#include -#include #include namespace capnp { @@ -36,11 +33,13 @@ namespace { class DummyCapTableReader: public _::CapTableReader { public: -#if !CAPNP_LITE kj::Maybe> extractCap(uint index) override { +#if CAPNP_LITE + KJ_UNIMPLEMENTED("no cap tables in lite mode"); +#else return nullptr; - } #endif + } }; static KJ_CONSTEXPR(const) DummyCapTableReader dummyCapTableReader = DummyCapTableReader(); @@ -58,7 +57,7 @@ bool MessageReader::isCanonical() { static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace), "arenaSpace is too small to hold a ReaderArena. Please increase it. This will break " "ABI compatibility."); - new(arena()) _::ReaderArena(this); + kj::ctor(*arena(), this); allocatedArena = true; } @@ -83,13 +82,16 @@ bool MessageReader::isCanonical() { return rootIsCanonical && allWordsConsumed; } +size_t MessageReader::sizeInWords() { + return arena()->sizeInWords(); +} AnyPointer::Reader MessageReader::getRootInternal() { if (!allocatedArena) { static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace), "arenaSpace is too small to hold a ReaderArena. Please increase it. This will break " "ABI compatibility."); - new(arena()) _::ReaderArena(this); + kj::ctor(*arena(), this); allocatedArena = true; } @@ -181,6 +183,14 @@ bool MessageBuilder::isCanonical() { .isCanonical(&readHead); } +size_t MessageBuilder::sizeInWords() { + return arena()->sizeInWords(); +} + +kj::Own<_::CapTableBuilder> MessageBuilder::releaseBuiltinCapTable() { + return arena()->releaseLocalCapTable(); +} + // ======================================================================================= SegmentArrayMessageReader::SegmentArrayMessageReader( @@ -199,10 +209,6 @@ kj::ArrayPtr SegmentArrayMessageReader::getSegment(uint id) { // ------------------------------------------------------------------- -struct MallocMessageBuilder::MoreSegments { - std::vector segments; -}; - MallocMessageBuilder::MallocMessageBuilder( uint firstSegmentWords, AllocationStrategy allocationStrategy) : nextSize(firstSegmentWords), allocationStrategy(allocationStrategy), @@ -233,10 +239,8 @@ MallocMessageBuilder::~MallocMessageBuilder() noexcept(false) { } } - KJ_IF_MAYBE(s, moreSegments) { - for (void* ptr: s->get()->segments) { - free(ptr); - } + for (void* ptr: moreSegments) { + free(ptr); } } } @@ -273,15 +277,7 @@ kj::ArrayPtr MallocMessageBuilder::allocateSegment(uint minimumSize) { // After the first segment, we want nextSize to equal the total size allocated so far. if (allocationStrategy == AllocationStrategy::GROW_HEURISTICALLY) nextSize = size; } else { - MoreSegments* segments; - KJ_IF_MAYBE(s, moreSegments) { - segments = *s; - } else { - auto newSegments = kj::heap(); - segments = newSegments; - moreSegments = mv(newSegments); - } - segments->segments.push_back(result); + moreSegments.add(result); if (allocationStrategy == AllocationStrategy::GROW_HEURISTICALLY) { // set nextSize = min(nextSize+size, MAX_SEGMENT_WORDS) // while protecting against possible overflow of (nextSize+size) diff --git a/c++/src/capnp/message.h b/c++/src/capnp/message.h index 9a2b4853b6..983761dfb3 100644 --- a/c++/src/capnp/message.h +++ b/c++/src/capnp/message.h @@ -19,26 +19,25 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#pragma once + #include #include #include #include +#include #include "common.h" #include "layout.h" #include "any.h" -#ifndef CAPNP_MESSAGE_H_ -#define CAPNP_MESSAGE_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +CAPNP_BEGIN_HEADER namespace capnp { namespace _ { // private class ReaderArena; class BuilderArena; + struct CloneImpl; } class StructSchema; @@ -122,14 +121,23 @@ class MessageReader { bool isCanonical(); // Returns whether the message encoded in the reader is in canonical form. + size_t sizeInWords(); + // Add up the size of all segments. + private: ReaderOptions options; +#if defined(__EMSCRIPTEN__) || (defined(__APPLE__) && (defined(__ppc__) || defined(__i386__))) + static constexpr size_t arenaSpacePadding = 19; +#else + static constexpr size_t arenaSpacePadding = 18; +#endif + // Space in which we can construct a ReaderArena. We don't use ReaderArena directly here // because we don't want clients to have to #include arena.h, which itself includes a bunch of - // big STL headers. We don't use a pointer to a ReaderArena because that would require an + // other headers. We don't use a pointer to a ReaderArena because that would require an // extra malloc on every message which could be expensive when processing small messages. - alignas(8) void* arenaSpace[15 + sizeof(kj::MutexGuarded) / sizeof(void*)]; + alignas(8) void* arenaSpace[arenaSpacePadding + sizeof(kj::MutexGuarded) / sizeof(void*)]; bool allocatedArena; _::ReaderArena* arena() { return reinterpret_cast<_::ReaderArena*>(arenaSpace); } @@ -151,7 +159,7 @@ class MessageBuilder { public: MessageBuilder(); virtual ~MessageBuilder() noexcept(false); - KJ_DISALLOW_COPY(MessageBuilder); + KJ_DISALLOW_COPY_AND_MOVE(MessageBuilder); struct SegmentInit { kj::ArrayPtr space; @@ -183,13 +191,15 @@ class MessageBuilder { // new objects in this message. virtual kj::ArrayPtr allocateSegment(uint minimumSize) = 0; - // Allocates an array of at least the given number of words, throwing an exception or crashing if - // this is not possible. It is expected that this method will usually return more space than - // requested, and the caller should use that extra space as much as possible before allocating - // more. The returned space remains valid at least until the MessageBuilder is destroyed. + // Allocates an array of at least the given number of zero'd words, throwing an exception or + // crashing if this is not possible. It is expected that this method will usually return more + // space than requested, and the caller should use that extra space as much as possible before + // allocating more. The returned space remains valid at least until the MessageBuilder is + // destroyed. // - // Cap'n Proto will only call this once at a time, so the subclass need not worry about - // thread-safety. + // allocateSegment() is responsible for zeroing the memory before returning. This is required + // because otherwise the Cap'n Proto implementation would have to zero the memory anyway, and + // many allocators are able to provide already-zero'd memory more efficiently. template typename RootType::Builder initRoot(); @@ -227,6 +237,9 @@ class MessageBuilder { bool isCanonical(); // Check whether the message builder is in canonical form + size_t sizeInWords(); + // Add up the allocated space from all segments. + private: alignas(8) void* arenaSpace[22]; // Space in which we can construct a BuilderArena. We don't use BuilderArena directly here @@ -244,6 +257,14 @@ class MessageBuilder { _::BuilderArena* arena() { return reinterpret_cast<_::BuilderArena*>(arenaSpace); } _::SegmentBuilder* getRootSegment(); AnyPointer::Builder getRootInternal(); + + kj::Own<_::CapTableBuilder> releaseBuiltinCapTable(); + // Hack for clone() to extract the cap table. + + friend struct _::CloneImpl; + // We can't declare clone() as a friend directly because old versions of GCC incorrectly demand + // that the first declaration (even if it is a friend declaration) specify the default type args, + // whereas correct compilers do not permit default type args to be specified on a friend decl. }; template @@ -305,6 +326,10 @@ static typename Type::Reader defaultValue(); // // TODO(cleanup): Find a better home for this function? +template > +kj::Own> clone(Reader&& reader); +// Make a deep copy of the given Reader on the heap, producing an owned pointer. + // ======================================================================================= class SegmentArrayMessageReader: public MessageReader { @@ -318,7 +343,7 @@ class SegmentArrayMessageReader: public MessageReader { // Creates a message pointing at the given segment array, without taking ownership of the // segments. All arrays passed in must remain valid until the MessageReader is destroyed. - KJ_DISALLOW_COPY(SegmentArrayMessageReader); + KJ_DISALLOW_COPY_AND_MOVE(SegmentArrayMessageReader); ~SegmentArrayMessageReader() noexcept(false); virtual kj::ArrayPtr getSegment(uint id) override; @@ -375,7 +400,7 @@ class MallocMessageBuilder: public MessageBuilder { // firstSegment MUST be zero-initialized. MallocMessageBuilder's destructor will write new zeros // over any space that was used so that it can be reused. - KJ_DISALLOW_COPY(MallocMessageBuilder); + KJ_DISALLOW_COPY_AND_MOVE(MallocMessageBuilder); virtual ~MallocMessageBuilder() noexcept(false); virtual kj::ArrayPtr allocateSegment(uint minimumSize) override; @@ -388,9 +413,7 @@ class MallocMessageBuilder: public MessageBuilder { bool returnedFirstSegment; void* firstSegment; - - struct MoreSegments; - kj::Maybe> moreSegments; + kj::Vector moreSegments; }; class FlatMessageBuilder: public MessageBuilder { @@ -408,7 +431,7 @@ class FlatMessageBuilder: public MessageBuilder { public: explicit FlatMessageBuilder(kj::ArrayPtr array); - KJ_DISALLOW_COPY(FlatMessageBuilder); + KJ_DISALLOW_COPY_AND_MOVE(FlatMessageBuilder); virtual ~FlatMessageBuilder() noexcept(false); void requireFilled(); @@ -498,6 +521,33 @@ static typename Type::Reader defaultValue() { return typename Type::Reader(_::StructReader()); } +namespace _ { + struct CloneImpl { + static inline kj::Own<_::CapTableBuilder> releaseBuiltinCapTable(MessageBuilder& message) { + return message.releaseBuiltinCapTable(); + } + }; +}; + +template +kj::Own> clone(Reader&& reader) { + auto size = reader.totalSize(); + auto buffer = kj::heapArray(size.wordCount + 1); + memset(buffer.asBytes().begin(), 0, buffer.asBytes().size()); + if (size.capCount == 0) { + copyToUnchecked(reader, buffer); + auto result = readMessageUnchecked>(buffer.begin()); + return kj::attachVal(result, kj::mv(buffer)); + } else { + FlatMessageBuilder builder(buffer); + builder.setRoot(kj::fwd(reader)); + builder.requireFilled(); + auto capTable = _::CloneImpl::releaseBuiltinCapTable(builder); + AnyPointer::Reader raw(_::PointerReader::getRootUnchecked(buffer.begin()).imbue(capTable)); + return kj::attachVal(raw.getAs>(), kj::mv(buffer), kj::mv(capTable)); + } +} + template kj::Array canonicalize(T&& reader) { return _::PointerHelpers>::getInternalReader(reader).canonicalize(); @@ -505,4 +555,4 @@ kj::Array canonicalize(T&& reader) { } // namespace capnp -#endif // CAPNP_MESSAGE_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/orphan.h b/c++/src/capnp/orphan.h index 8c8b9a6054..0ef4a671c8 100644 --- a/c++/src/capnp/orphan.h +++ b/c++/src/capnp/orphan.h @@ -19,15 +19,12 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_ORPHAN_H_ -#define CAPNP_ORPHAN_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "layout.h" +CAPNP_BEGIN_HEADER + namespace capnp { class StructSchema; @@ -74,7 +71,7 @@ class Orphan { // If the new size is less than the original, the remaining elements will be discarded. The // list is never moved in this case. If the list happens to be located at the end of its segment // (which is always true if the list was the last thing allocated), the removed memory will be - // reclaimed (reducing the messag size), otherwise it is simply zeroed. The reclaiming behavior + // reclaimed (reducing the message size), otherwise it is simply zeroed. The reclaiming behavior // is particularly useful for allocating buffer space when you aren't sure how much space you // actually need: you can pre-allocate, say, a 4k byte array, read() from a file into it, and // then truncate it back to the amount of space actually used. @@ -437,4 +434,4 @@ inline Orphan Orphanage::referenceExternalData(Data::Reader data) const { } // namespace capnp -#endif // CAPNP_ORPHAN_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/persistent.capnp b/c++/src/capnp/persistent.capnp index a13b47168a..fefd188aef 100644 --- a/c++/src/capnp/persistent.capnp +++ b/c++/src/capnp/persistent.capnp @@ -108,23 +108,6 @@ interface Persistent@0xc8cb212fcd9f5691(SturdyRef, Owner) { } } -interface RealmGateway(InternalRef, ExternalRef, InternalOwner, ExternalOwner) { - # Interface invoked when a SturdyRef is about to cross realms. The RPC system supports providing - # a RealmGateway as a callback hook when setting up RPC over some VatNetwork. - - import @0 (cap :Persistent(ExternalRef, ExternalOwner), - params :Persistent(InternalRef, InternalOwner).SaveParams) - -> Persistent(InternalRef, InternalOwner).SaveResults; - # Given an external capability, save it and return an internal reference. Used when someone - # inside the realm tries to save a capability from outside the realm. - - export @1 (cap :Persistent(InternalRef, InternalOwner), - params :Persistent(ExternalRef, ExternalOwner).SaveParams) - -> Persistent(ExternalRef, ExternalOwner).SaveResults; - # Given an internal capability, save it and return an external reference. Used when someone - # outside the realm tries to save a capability from inside the realm. -} - annotation persistent(interface, field) :Void; # Apply this annotation to interfaces for objects that will always be persistent, instead of # extending the Persistent capability, since the correct type parameters to Persistent depend on diff --git a/c++/src/capnp/persistent.capnp.c++ b/c++/src/capnp/persistent.capnp.c++ index 028aefe788..195c71549b 100644 --- a/c++/src/capnp/persistent.capnp.c++ +++ b/c++/src/capnp/persistent.capnp.c++ @@ -74,7 +74,7 @@ KJ_CONSTEXPR(const) ::capnp::_::RawBrandedSchema::Dependency bd_c8cb212fcd9f5691 }; const ::capnp::_::RawSchema s_c8cb212fcd9f5691 = { 0xc8cb212fcd9f5691, b_c8cb212fcd9f5691.words, 54, d_c8cb212fcd9f5691, m_c8cb212fcd9f5691, - 2, 1, nullptr, nullptr, nullptr, { &s_c8cb212fcd9f5691, nullptr, bd_c8cb212fcd9f5691, 0, sizeof(bd_c8cb212fcd9f5691) / sizeof(bd_c8cb212fcd9f5691[0]), nullptr } + 2, 1, nullptr, nullptr, nullptr, { &s_c8cb212fcd9f5691, nullptr, bd_c8cb212fcd9f5691, 0, sizeof(bd_c8cb212fcd9f5691) / sizeof(bd_c8cb212fcd9f5691[0]), nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<35> b_f76fba59183073a5 = { @@ -120,7 +120,7 @@ static const uint16_t m_f76fba59183073a5[] = {0}; static const uint16_t i_f76fba59183073a5[] = {0}; const ::capnp::_::RawSchema s_f76fba59183073a5 = { 0xf76fba59183073a5, b_f76fba59183073a5.words, 35, nullptr, m_f76fba59183073a5, - 0, 1, i_f76fba59183073a5, nullptr, nullptr, { &s_f76fba59183073a5, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_f76fba59183073a5, nullptr, nullptr, { &s_f76fba59183073a5, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<36> b_b76848c18c40efbf = { @@ -167,337 +167,7 @@ static const uint16_t m_b76848c18c40efbf[] = {0}; static const uint16_t i_b76848c18c40efbf[] = {0}; const ::capnp::_::RawSchema s_b76848c18c40efbf = { 0xb76848c18c40efbf, b_b76848c18c40efbf.words, 36, nullptr, m_b76848c18c40efbf, - 0, 1, i_b76848c18c40efbf, nullptr, nullptr, { &s_b76848c18c40efbf, nullptr, nullptr, 0, 0, nullptr } -}; -#endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<99> b_84ff286cd00a3ed4 = { - { 0, 0, 0, 0, 5, 0, 6, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 23, 0, 0, 0, 3, 0, 0, 0, - 215, 238, 63, 152, 54, 8, 99, 184, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 0, 0, 0, - 21, 0, 0, 0, 34, 1, 0, 0, - 37, 0, 0, 0, 7, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 33, 0, 0, 0, 135, 0, 0, 0, - 41, 1, 0, 0, 7, 0, 0, 0, - 41, 1, 0, 0, 39, 0, 0, 0, - 99, 97, 112, 110, 112, 47, 112, 101, - 114, 115, 105, 115, 116, 101, 110, 116, - 46, 99, 97, 112, 110, 112, 58, 82, - 101, 97, 108, 109, 71, 97, 116, 101, - 119, 97, 121, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 0, 1, 0, - 8, 0, 0, 0, 3, 0, 5, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 77, 87, 9, 57, 29, 204, 194, 240, - 191, 239, 64, 140, 193, 72, 104, 183, - 49, 0, 0, 0, 58, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 44, 0, 0, 0, 0, 0, 1, 0, - 60, 0, 0, 0, 0, 0, 1, 0, - 129, 0, 0, 0, 7, 0, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 170, 163, 45, 72, 139, 161, 175, 236, - 191, 239, 64, 140, 193, 72, 104, 183, - 117, 0, 0, 0, 58, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 0, 0, 0, 0, 0, 1, 0, - 128, 0, 0, 0, 0, 0, 1, 0, - 197, 0, 0, 0, 7, 0, 0, 0, - 105, 109, 112, 111, 114, 116, 0, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 1, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 39, 0, 0, 0, - 8, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 1, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 2, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 0, - 101, 120, 112, 111, 114, 116, 0, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 1, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 39, 0, 0, 0, - 8, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 1, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 1, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 3, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 0, 1, 0, 1, 0, - 16, 0, 0, 0, 0, 0, 1, 0, - 13, 0, 0, 0, 98, 0, 0, 0, - 17, 0, 0, 0, 98, 0, 0, 0, - 21, 0, 0, 0, 114, 0, 0, 0, - 25, 0, 0, 0, 114, 0, 0, 0, - 73, 110, 116, 101, 114, 110, 97, 108, - 82, 101, 102, 0, 0, 0, 0, 0, - 69, 120, 116, 101, 114, 110, 97, 108, - 82, 101, 102, 0, 0, 0, 0, 0, - 73, 110, 116, 101, 114, 110, 97, 108, - 79, 119, 110, 101, 114, 0, 0, 0, - 69, 120, 116, 101, 114, 110, 97, 108, - 79, 119, 110, 101, 114, 0, 0, 0, } -}; -::capnp::word const* const bp_84ff286cd00a3ed4 = b_84ff286cd00a3ed4.words; -#if !CAPNP_LITE -static const ::capnp::_::RawSchema* const d_84ff286cd00a3ed4[] = { - &s_b76848c18c40efbf, - &s_ecafa18b482da3aa, - &s_f0c2cc1d3909574d, -}; -static const uint16_t m_84ff286cd00a3ed4[] = {1, 0}; -KJ_CONSTEXPR(const) ::capnp::_::RawBrandedSchema::Dependency bd_84ff286cd00a3ed4[] = { - { 33554432, ::capnp::RealmGateway< ::capnp::AnyPointer, ::capnp::AnyPointer, ::capnp::AnyPointer, ::capnp::AnyPointer>::ImportParams::_capnpPrivate::brand() }, - { 33554433, ::capnp::RealmGateway< ::capnp::AnyPointer, ::capnp::AnyPointer, ::capnp::AnyPointer, ::capnp::AnyPointer>::ExportParams::_capnpPrivate::brand() }, - { 50331648, ::capnp::Persistent< ::capnp::AnyPointer, ::capnp::AnyPointer>::SaveResults::_capnpPrivate::brand() }, - { 50331649, ::capnp::Persistent< ::capnp::AnyPointer, ::capnp::AnyPointer>::SaveResults::_capnpPrivate::brand() }, -}; -const ::capnp::_::RawSchema s_84ff286cd00a3ed4 = { - 0x84ff286cd00a3ed4, b_84ff286cd00a3ed4.words, 99, d_84ff286cd00a3ed4, m_84ff286cd00a3ed4, - 3, 2, nullptr, nullptr, nullptr, { &s_84ff286cd00a3ed4, nullptr, bd_84ff286cd00a3ed4, 0, sizeof(bd_84ff286cd00a3ed4) / sizeof(bd_84ff286cd00a3ed4[0]), nullptr } -}; -#endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<86> b_f0c2cc1d3909574d = { - { 0, 0, 0, 0, 5, 0, 6, 0, - 77, 87, 9, 57, 29, 204, 194, 240, - 36, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 2, 0, 7, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 0, 0, 0, - 21, 0, 0, 0, 146, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 37, 0, 0, 0, 119, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 99, 97, 112, 110, 112, 47, 112, 101, - 114, 115, 105, 115, 116, 101, 110, 116, - 46, 99, 97, 112, 110, 112, 58, 82, - 101, 97, 108, 109, 71, 97, 116, 101, - 119, 97, 121, 46, 105, 109, 112, 111, - 114, 116, 36, 80, 97, 114, 97, 109, - 115, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 4, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 41, 0, 0, 0, 34, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 36, 0, 0, 0, 3, 0, 1, 0, - 120, 0, 0, 0, 2, 0, 1, 0, - 1, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 1, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 117, 0, 0, 0, 58, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 0, 0, 0, 3, 0, 1, 0, - 196, 0, 0, 0, 2, 0, 1, 0, - 99, 97, 112, 0, 0, 0, 0, 0, - 17, 0, 0, 0, 0, 0, 0, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 39, 0, 0, 0, - 8, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 1, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 1, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 3, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 17, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 97, 114, 97, 109, 115, 0, 0, - 16, 0, 0, 0, 0, 0, 0, 0, - 165, 115, 48, 24, 89, 186, 111, 247, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 39, 0, 0, 0, - 8, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 1, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 2, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, } -}; -::capnp::word const* const bp_f0c2cc1d3909574d = b_f0c2cc1d3909574d.words; -#if !CAPNP_LITE -static const ::capnp::_::RawSchema* const d_f0c2cc1d3909574d[] = { - &s_c8cb212fcd9f5691, - &s_f76fba59183073a5, -}; -static const uint16_t m_f0c2cc1d3909574d[] = {0, 1}; -static const uint16_t i_f0c2cc1d3909574d[] = {0, 1}; -KJ_CONSTEXPR(const) ::capnp::_::RawBrandedSchema::Dependency bd_f0c2cc1d3909574d[] = { - { 16777216, ::capnp::Persistent< ::capnp::AnyPointer, ::capnp::AnyPointer>::_capnpPrivate::brand() }, - { 16777217, ::capnp::Persistent< ::capnp::AnyPointer, ::capnp::AnyPointer>::SaveParams::_capnpPrivate::brand() }, -}; -const ::capnp::_::RawSchema s_f0c2cc1d3909574d = { - 0xf0c2cc1d3909574d, b_f0c2cc1d3909574d.words, 86, d_f0c2cc1d3909574d, m_f0c2cc1d3909574d, - 2, 2, i_f0c2cc1d3909574d, nullptr, nullptr, { &s_f0c2cc1d3909574d, nullptr, bd_f0c2cc1d3909574d, 0, sizeof(bd_f0c2cc1d3909574d) / sizeof(bd_f0c2cc1d3909574d[0]), nullptr } -}; -#endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<86> b_ecafa18b482da3aa = { - { 0, 0, 0, 0, 5, 0, 6, 0, - 170, 163, 45, 72, 139, 161, 175, 236, - 36, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 2, 0, 7, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 0, 0, 0, - 21, 0, 0, 0, 146, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 37, 0, 0, 0, 119, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 99, 97, 112, 110, 112, 47, 112, 101, - 114, 115, 105, 115, 116, 101, 110, 116, - 46, 99, 97, 112, 110, 112, 58, 82, - 101, 97, 108, 109, 71, 97, 116, 101, - 119, 97, 121, 46, 101, 120, 112, 111, - 114, 116, 36, 80, 97, 114, 97, 109, - 115, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 4, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 41, 0, 0, 0, 34, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 36, 0, 0, 0, 3, 0, 1, 0, - 120, 0, 0, 0, 2, 0, 1, 0, - 1, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 1, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 117, 0, 0, 0, 58, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 0, 0, 0, 3, 0, 1, 0, - 196, 0, 0, 0, 2, 0, 1, 0, - 99, 97, 112, 0, 0, 0, 0, 0, - 17, 0, 0, 0, 0, 0, 0, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 39, 0, 0, 0, - 8, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 1, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 2, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 17, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 97, 114, 97, 109, 115, 0, 0, - 16, 0, 0, 0, 0, 0, 0, 0, - 165, 115, 48, 24, 89, 186, 111, 247, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 0, - 1, 0, 0, 0, 31, 0, 0, 0, - 4, 0, 0, 0, 2, 0, 1, 0, - 145, 86, 159, 205, 47, 33, 203, 200, - 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 39, 0, 0, 0, - 8, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 8, 0, 0, 0, 3, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 1, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 1, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 18, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 3, 0, 0, 0, 0, 0, - 212, 62, 10, 208, 108, 40, 255, 132, - 0, 0, 0, 0, 0, 0, 0, 0, - 16, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, } -}; -::capnp::word const* const bp_ecafa18b482da3aa = b_ecafa18b482da3aa.words; -#if !CAPNP_LITE -static const ::capnp::_::RawSchema* const d_ecafa18b482da3aa[] = { - &s_c8cb212fcd9f5691, - &s_f76fba59183073a5, -}; -static const uint16_t m_ecafa18b482da3aa[] = {0, 1}; -static const uint16_t i_ecafa18b482da3aa[] = {0, 1}; -KJ_CONSTEXPR(const) ::capnp::_::RawBrandedSchema::Dependency bd_ecafa18b482da3aa[] = { - { 16777216, ::capnp::Persistent< ::capnp::AnyPointer, ::capnp::AnyPointer>::_capnpPrivate::brand() }, - { 16777217, ::capnp::Persistent< ::capnp::AnyPointer, ::capnp::AnyPointer>::SaveParams::_capnpPrivate::brand() }, -}; -const ::capnp::_::RawSchema s_ecafa18b482da3aa = { - 0xecafa18b482da3aa, b_ecafa18b482da3aa.words, 86, d_ecafa18b482da3aa, m_ecafa18b482da3aa, - 2, 2, i_ecafa18b482da3aa, nullptr, nullptr, { &s_ecafa18b482da3aa, nullptr, bd_ecafa18b482da3aa, 0, sizeof(bd_ecafa18b482da3aa) / sizeof(bd_ecafa18b482da3aa[0]), nullptr } + 0, 1, i_b76848c18c40efbf, nullptr, nullptr, { &s_b76848c18c40efbf, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<22> b_f622595091cafb67 = { @@ -528,7 +198,7 @@ static const ::capnp::_::AlignedData<22> b_f622595091cafb67 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_f622595091cafb67 = { 0xf622595091cafb67, b_f622595091cafb67.words, 22, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_f622595091cafb67, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_f622595091cafb67, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas diff --git a/c++/src/capnp/persistent.capnp.h b/c++/src/capnp/persistent.capnp.h index f9b443220a..cd65e86674 100644 --- a/c++/src/capnp/persistent.capnp.h +++ b/c++/src/capnp/persistent.capnp.h @@ -1,28 +1,29 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: persistent.capnp -#ifndef CAPNP_INCLUDED_b8630836983feed7_ -#define CAPNP_INCLUDED_b8630836983feed7_ +#pragma once #include +#include #if !CAPNP_LITE #include #endif // !CAPNP_LITE -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { CAPNP_DECLARE_SCHEMA(c8cb212fcd9f5691); CAPNP_DECLARE_SCHEMA(f76fba59183073a5); CAPNP_DECLARE_SCHEMA(b76848c18c40efbf); -CAPNP_DECLARE_SCHEMA(84ff286cd00a3ed4); -CAPNP_DECLARE_SCHEMA(f0c2cc1d3909574d); -CAPNP_DECLARE_SCHEMA(ecafa18b482da3aa); CAPNP_DECLARE_SCHEMA(f622595091cafb67); } // namespace schemas @@ -92,70 +93,6 @@ struct Persistent::SaveResults { }; }; -template -struct RealmGateway { - RealmGateway() = delete; - -#if !CAPNP_LITE - class Client; - class Server; -#endif // !CAPNP_LITE - - struct ImportParams; - struct ExportParams; - - #if !CAPNP_LITE - struct _capnpPrivate { - CAPNP_DECLARE_INTERFACE_HEADER(84ff286cd00a3ed4) - static const ::capnp::_::RawBrandedSchema::Scope brandScopes[]; - static const ::capnp::_::RawBrandedSchema::Binding brandBindings[]; - static const ::capnp::_::RawBrandedSchema::Dependency brandDependencies[]; - static const ::capnp::_::RawBrandedSchema specificBrand; - static constexpr ::capnp::_::RawBrandedSchema const* brand() { return ::capnp::_::ChooseBrand<_capnpPrivate, InternalRef, ExternalRef, InternalOwner, ExternalOwner>::brand(); } - }; - #endif // !CAPNP_LITE -}; - -template -struct RealmGateway::ImportParams { - ImportParams() = delete; - - class Reader; - class Builder; - class Pipeline; - - struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(f0c2cc1d3909574d, 0, 2) - #if !CAPNP_LITE - static const ::capnp::_::RawBrandedSchema::Scope brandScopes[]; - static const ::capnp::_::RawBrandedSchema::Binding brandBindings[]; - static const ::capnp::_::RawBrandedSchema::Dependency brandDependencies[]; - static const ::capnp::_::RawBrandedSchema specificBrand; - static constexpr ::capnp::_::RawBrandedSchema const* brand() { return ::capnp::_::ChooseBrand<_capnpPrivate, InternalRef, ExternalRef, InternalOwner, ExternalOwner>::brand(); } - #endif // !CAPNP_LITE - }; -}; - -template -struct RealmGateway::ExportParams { - ExportParams() = delete; - - class Reader; - class Builder; - class Pipeline; - - struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(ecafa18b482da3aa, 0, 2) - #if !CAPNP_LITE - static const ::capnp::_::RawBrandedSchema::Scope brandScopes[]; - static const ::capnp::_::RawBrandedSchema::Binding brandBindings[]; - static const ::capnp::_::RawBrandedSchema::Dependency brandDependencies[]; - static const ::capnp::_::RawBrandedSchema specificBrand; - static constexpr ::capnp::_::RawBrandedSchema const* brand() { return ::capnp::_::ChooseBrand<_capnpPrivate, InternalRef, ExternalRef, InternalOwner, ExternalOwner>::brand(); } - #endif // !CAPNP_LITE - }; -}; - // ======================================================================================= #if !CAPNP_LITE @@ -196,7 +133,8 @@ class Persistent::Server public: typedef Persistent Serves; - ::kj::Promise dispatchCall(uint64_t interfaceId, uint16_t methodId, + ::capnp::Capability::Server::DispatchCallResult dispatchCall( + uint64_t interfaceId, uint16_t methodId, ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) override; @@ -209,7 +147,8 @@ class Persistent::Server .template castAs< ::capnp::Persistent>(); } - ::kj::Promise dispatchCallInternal(uint16_t methodId, + ::capnp::Capability::Server::DispatchCallResult dispatchCallInternal( + uint16_t methodId, ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context); }; #endif // !CAPNP_LITE @@ -406,288 +345,6 @@ class Persistent::SaveResults::Pipeline { }; #endif // !CAPNP_LITE -#if !CAPNP_LITE -template -class RealmGateway::Client - : public virtual ::capnp::Capability::Client { -public: - typedef RealmGateway Calls; - typedef RealmGateway Reads; - - Client(decltype(nullptr)); - explicit Client(::kj::Own< ::capnp::ClientHook>&& hook); - template ()>> - Client(::kj::Own<_t>&& server); - template ()>> - Client(::kj::Promise<_t>&& promise); - Client(::kj::Exception&& exception); - Client(Client&) = default; - Client(Client&&) = default; - Client& operator=(Client& other); - Client& operator=(Client&& other); - - template - typename RealmGateway::Client asGeneric() { - return castAs>(); - } - - CAPNP_AUTO_IF_MSVC(::capnp::Request::ImportParams, typename ::capnp::Persistent::SaveResults>) importRequest( - ::kj::Maybe< ::capnp::MessageSize> sizeHint = nullptr); - CAPNP_AUTO_IF_MSVC(::capnp::Request::ExportParams, typename ::capnp::Persistent::SaveResults>) exportRequest( - ::kj::Maybe< ::capnp::MessageSize> sizeHint = nullptr); - -protected: - Client() = default; -}; - -template -class RealmGateway::Server - : public virtual ::capnp::Capability::Server { -public: - typedef RealmGateway Serves; - - ::kj::Promise dispatchCall(uint64_t interfaceId, uint16_t methodId, - ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) - override; - -protected: - typedef typename ::capnp::RealmGateway::ImportParams ImportParams; - typedef ::capnp::CallContext::SaveResults> ImportContext; - virtual ::kj::Promise import(ImportContext context); - typedef typename ::capnp::RealmGateway::ExportParams ExportParams; - typedef ::capnp::CallContext::SaveResults> ExportContext; - virtual ::kj::Promise export_(ExportContext context); - - inline typename ::capnp::RealmGateway::Client thisCap() { - return ::capnp::Capability::Server::thisCap() - .template castAs< ::capnp::RealmGateway>(); - } - - ::kj::Promise dispatchCallInternal(uint16_t methodId, - ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context); -}; -#endif // !CAPNP_LITE - -template -class RealmGateway::ImportParams::Reader { -public: - typedef ImportParams Reads; - - Reader() = default; - inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} - - inline ::capnp::MessageSize totalSize() const { - return _reader.totalSize().asPublic(); - } - -#if !CAPNP_LITE - inline ::kj::StringTree toString() const { - return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); - } -#endif // !CAPNP_LITE - - template - typename RealmGateway::ImportParams::Reader asRealmGatewayGeneric() { - return typename RealmGateway::ImportParams::Reader(_reader); - } - - inline bool hasCap() const; -#if !CAPNP_LITE - inline typename ::capnp::Persistent::Client getCap() const; -#endif // !CAPNP_LITE - - inline bool hasParams() const; - inline typename ::capnp::Persistent::SaveParams::Reader getParams() const; - -private: - ::capnp::_::StructReader _reader; - template - friend struct ::capnp::ToDynamic_; - template - friend struct ::capnp::_::PointerHelpers; - template - friend struct ::capnp::List; - friend class ::capnp::MessageBuilder; - friend class ::capnp::Orphanage; -}; - -template -class RealmGateway::ImportParams::Builder { -public: - typedef ImportParams Builds; - - Builder() = delete; // Deleted to discourage incorrect usage. - // You can explicitly initialize to nullptr instead. - inline Builder(decltype(nullptr)) {} - inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} - inline operator Reader() const { return Reader(_builder.asReader()); } - inline Reader asReader() const { return *this; } - - inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } -#if !CAPNP_LITE - inline ::kj::StringTree toString() const { return asReader().toString(); } -#endif // !CAPNP_LITE - - template - typename RealmGateway::ImportParams::Builder asRealmGatewayGeneric() { - return typename RealmGateway::ImportParams::Builder(_builder); - } - - inline bool hasCap(); -#if !CAPNP_LITE - inline typename ::capnp::Persistent::Client getCap(); - inline void setCap(typename ::capnp::Persistent::Client&& value); - inline void setCap(typename ::capnp::Persistent::Client& value); - inline void adoptCap(::capnp::Orphan< ::capnp::Persistent>&& value); - inline ::capnp::Orphan< ::capnp::Persistent> disownCap(); -#endif // !CAPNP_LITE - - inline bool hasParams(); - inline typename ::capnp::Persistent::SaveParams::Builder getParams(); - inline void setParams(typename ::capnp::Persistent::SaveParams::Reader value); - inline typename ::capnp::Persistent::SaveParams::Builder initParams(); - inline void adoptParams(::capnp::Orphan::SaveParams>&& value); - inline ::capnp::Orphan::SaveParams> disownParams(); - -private: - ::capnp::_::StructBuilder _builder; - template - friend struct ::capnp::ToDynamic_; - friend class ::capnp::Orphanage; - template - friend struct ::capnp::_::PointerHelpers; -}; - -#if !CAPNP_LITE -template -class RealmGateway::ImportParams::Pipeline { -public: - typedef ImportParams Pipelines; - - inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} - inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) - : _typeless(kj::mv(typeless)) {} - - inline typename ::capnp::Persistent::Client getCap(); - inline typename ::capnp::Persistent::SaveParams::Pipeline getParams(); -private: - ::capnp::AnyPointer::Pipeline _typeless; - friend class ::capnp::PipelineHook; - template - friend struct ::capnp::ToDynamic_; -}; -#endif // !CAPNP_LITE - -template -class RealmGateway::ExportParams::Reader { -public: - typedef ExportParams Reads; - - Reader() = default; - inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} - - inline ::capnp::MessageSize totalSize() const { - return _reader.totalSize().asPublic(); - } - -#if !CAPNP_LITE - inline ::kj::StringTree toString() const { - return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); - } -#endif // !CAPNP_LITE - - template - typename RealmGateway::ExportParams::Reader asRealmGatewayGeneric() { - return typename RealmGateway::ExportParams::Reader(_reader); - } - - inline bool hasCap() const; -#if !CAPNP_LITE - inline typename ::capnp::Persistent::Client getCap() const; -#endif // !CAPNP_LITE - - inline bool hasParams() const; - inline typename ::capnp::Persistent::SaveParams::Reader getParams() const; - -private: - ::capnp::_::StructReader _reader; - template - friend struct ::capnp::ToDynamic_; - template - friend struct ::capnp::_::PointerHelpers; - template - friend struct ::capnp::List; - friend class ::capnp::MessageBuilder; - friend class ::capnp::Orphanage; -}; - -template -class RealmGateway::ExportParams::Builder { -public: - typedef ExportParams Builds; - - Builder() = delete; // Deleted to discourage incorrect usage. - // You can explicitly initialize to nullptr instead. - inline Builder(decltype(nullptr)) {} - inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} - inline operator Reader() const { return Reader(_builder.asReader()); } - inline Reader asReader() const { return *this; } - - inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } -#if !CAPNP_LITE - inline ::kj::StringTree toString() const { return asReader().toString(); } -#endif // !CAPNP_LITE - - template - typename RealmGateway::ExportParams::Builder asRealmGatewayGeneric() { - return typename RealmGateway::ExportParams::Builder(_builder); - } - - inline bool hasCap(); -#if !CAPNP_LITE - inline typename ::capnp::Persistent::Client getCap(); - inline void setCap(typename ::capnp::Persistent::Client&& value); - inline void setCap(typename ::capnp::Persistent::Client& value); - inline void adoptCap(::capnp::Orphan< ::capnp::Persistent>&& value); - inline ::capnp::Orphan< ::capnp::Persistent> disownCap(); -#endif // !CAPNP_LITE - - inline bool hasParams(); - inline typename ::capnp::Persistent::SaveParams::Builder getParams(); - inline void setParams(typename ::capnp::Persistent::SaveParams::Reader value); - inline typename ::capnp::Persistent::SaveParams::Builder initParams(); - inline void adoptParams(::capnp::Orphan::SaveParams>&& value); - inline ::capnp::Orphan::SaveParams> disownParams(); - -private: - ::capnp::_::StructBuilder _builder; - template - friend struct ::capnp::ToDynamic_; - friend class ::capnp::Orphanage; - template - friend struct ::capnp::_::PointerHelpers; -}; - -#if !CAPNP_LITE -template -class RealmGateway::ExportParams::Pipeline { -public: - typedef ExportParams Pipelines; - - inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} - inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) - : _typeless(kj::mv(typeless)) {} - - inline typename ::capnp::Persistent::Client getCap(); - inline typename ::capnp::Persistent::SaveParams::Pipeline getParams(); -private: - ::capnp::AnyPointer::Pipeline _typeless; - friend class ::capnp::PipelineHook; - template - friend struct ::capnp::ToDynamic_; -}; -#endif // !CAPNP_LITE - // ======================================================================================= #if !CAPNP_LITE @@ -775,15 +432,19 @@ inline ::capnp::Orphan Persistent::SaveParams::Builder: } // Persistent::SaveParams +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr uint16_t Persistent::SaveParams::_capnpPrivate::dataWordSize; template constexpr uint16_t Persistent::SaveParams::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr ::capnp::Kind Persistent::SaveParams::_capnpPrivate::kind; template constexpr ::capnp::_::RawSchema const* Persistent::SaveParams::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template const ::capnp::_::RawBrandedSchema::Scope Persistent::SaveParams::_capnpPrivate::brandScopes[] = { { 0xc8cb212fcd9f5691, brandBindings + 0, 2, false}, @@ -854,15 +515,19 @@ inline ::capnp::Orphan Persistent::SaveResults::Bui } // Persistent::SaveResults +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr uint16_t Persistent::SaveResults::_capnpPrivate::dataWordSize; template constexpr uint16_t Persistent::SaveResults::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr ::capnp::Kind Persistent::SaveResults::_capnpPrivate::kind; template constexpr ::capnp::_::RawSchema const* Persistent::SaveResults::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template const ::capnp::_::RawBrandedSchema::Scope Persistent::SaveResults::_capnpPrivate::brandScopes[] = { { 0xc8cb212fcd9f5691, brandBindings + 0, 2, false}, @@ -884,7 +549,7 @@ template CAPNP_AUTO_IF_MSVC(::capnp::Request::SaveParams, typename ::capnp::Persistent::SaveResults>) Persistent::Client::saveRequest(::kj::Maybe< ::capnp::MessageSize> sizeHint) { return newCall::SaveParams, typename ::capnp::Persistent::SaveResults>( - 0xc8cb212fcd9f5691ull, 0, sizeHint); + 0xc8cb212fcd9f5691ull, 0, sizeHint, {false}); } template ::kj::Promise Persistent::Server::save(SaveContext) { @@ -893,7 +558,7 @@ ::kj::Promise Persistent::Server::save(SaveContext) { 0xc8cb212fcd9f5691ull, 0); } template -::kj::Promise Persistent::Server::dispatchCall( +::capnp::Capability::Server::DispatchCallResult Persistent::Server::dispatchCall( uint64_t interfaceId, uint16_t methodId, ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) { switch (interfaceId) { @@ -904,13 +569,17 @@ ::kj::Promise Persistent::Server::dispatchCall( } } template -::kj::Promise Persistent::Server::dispatchCallInternal( +::capnp::Capability::Server::DispatchCallResult Persistent::Server::dispatchCallInternal( uint16_t methodId, ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) { switch (methodId) { case 0: - return save(::capnp::Capability::Server::internalGetTypedContext< - typename ::capnp::Persistent::SaveParams, typename ::capnp::Persistent::SaveResults>(context)); + return { + save(::capnp::Capability::Server::internalGetTypedContext< + typename ::capnp::Persistent::SaveParams, typename ::capnp::Persistent::SaveResults>(context)), + false, + false + }; default: (void)context; return ::capnp::Capability::Server::internalUnimplemented( @@ -922,10 +591,12 @@ ::kj::Promise Persistent::Server::dispatchCallInternal( // Persistent #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template constexpr ::capnp::Kind Persistent::_capnpPrivate::kind; template constexpr ::capnp::_::RawSchema const* Persistent::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL template const ::capnp::_::RawBrandedSchema::Scope Persistent::_capnpPrivate::brandScopes[] = { { 0xc8cb212fcd9f5691, brandBindings + 0, 2, false}, @@ -947,382 +618,7 @@ const ::capnp::_::RawBrandedSchema Persistent::_capnpPrivate:: }; #endif // !CAPNP_LITE -#if !CAPNP_LITE -template -inline RealmGateway::Client::Client(decltype(nullptr)) - : ::capnp::Capability::Client(nullptr) {} -template -inline RealmGateway::Client::Client( - ::kj::Own< ::capnp::ClientHook>&& hook) - : ::capnp::Capability::Client(::kj::mv(hook)) {} -template -template -inline RealmGateway::Client::Client(::kj::Own<_t>&& server) - : ::capnp::Capability::Client(::kj::mv(server)) {} -template -template -inline RealmGateway::Client::Client(::kj::Promise<_t>&& promise) - : ::capnp::Capability::Client(::kj::mv(promise)) {} -template -inline RealmGateway::Client::Client(::kj::Exception&& exception) - : ::capnp::Capability::Client(::kj::mv(exception)) {} -template -inline typename ::capnp::RealmGateway::Client& RealmGateway::Client::operator=(Client& other) { - ::capnp::Capability::Client::operator=(other); - return *this; -} -template -inline typename ::capnp::RealmGateway::Client& RealmGateway::Client::operator=(Client&& other) { - ::capnp::Capability::Client::operator=(kj::mv(other)); - return *this; -} - -#endif // !CAPNP_LITE -template -inline bool RealmGateway::ImportParams::Reader::hasCap() const { - return !_reader.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); -} -template -inline bool RealmGateway::ImportParams::Builder::hasCap() { - return !_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); -} -#if !CAPNP_LITE -template -inline typename ::capnp::Persistent::Client RealmGateway::ImportParams::Reader::getCap() const { - return ::capnp::_::PointerHelpers< ::capnp::Persistent>::get(_reader.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS)); -} -template -inline typename ::capnp::Persistent::Client RealmGateway::ImportParams::Builder::getCap() { - return ::capnp::_::PointerHelpers< ::capnp::Persistent>::get(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS)); -} -template -inline typename ::capnp::Persistent::Client RealmGateway::ImportParams::Pipeline::getCap() { - return typename ::capnp::Persistent::Client(_typeless.getPointerField(0).asCap()); -} -template -inline void RealmGateway::ImportParams::Builder::setCap(typename ::capnp::Persistent::Client&& cap) { - ::capnp::_::PointerHelpers< ::capnp::Persistent>::set(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(cap)); -} -template -inline void RealmGateway::ImportParams::Builder::setCap(typename ::capnp::Persistent::Client& cap) { - ::capnp::_::PointerHelpers< ::capnp::Persistent>::set(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS), cap); -} -template -inline void RealmGateway::ImportParams::Builder::adoptCap( - ::capnp::Orphan< ::capnp::Persistent>&& value) { - ::capnp::_::PointerHelpers< ::capnp::Persistent>::adopt(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); -} -template -inline ::capnp::Orphan< ::capnp::Persistent> RealmGateway::ImportParams::Builder::disownCap() { - return ::capnp::_::PointerHelpers< ::capnp::Persistent>::disown(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS)); -} -#endif // !CAPNP_LITE - -template -inline bool RealmGateway::ImportParams::Reader::hasParams() const { - return !_reader.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); -} -template -inline bool RealmGateway::ImportParams::Builder::hasParams() { - return !_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); -} -template -inline typename ::capnp::Persistent::SaveParams::Reader RealmGateway::ImportParams::Reader::getParams() const { - return ::capnp::_::PointerHelpers::SaveParams>::get(_reader.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} -template -inline typename ::capnp::Persistent::SaveParams::Builder RealmGateway::ImportParams::Builder::getParams() { - return ::capnp::_::PointerHelpers::SaveParams>::get(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} -#if !CAPNP_LITE -template -inline typename ::capnp::Persistent::SaveParams::Pipeline RealmGateway::ImportParams::Pipeline::getParams() { - return typename ::capnp::Persistent::SaveParams::Pipeline(_typeless.getPointerField(1)); -} -#endif // !CAPNP_LITE -template -inline void RealmGateway::ImportParams::Builder::setParams(typename ::capnp::Persistent::SaveParams::Reader value) { - ::capnp::_::PointerHelpers::SaveParams>::set(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS), value); -} -template -inline typename ::capnp::Persistent::SaveParams::Builder RealmGateway::ImportParams::Builder::initParams() { - return ::capnp::_::PointerHelpers::SaveParams>::init(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} -template -inline void RealmGateway::ImportParams::Builder::adoptParams( - ::capnp::Orphan::SaveParams>&& value) { - ::capnp::_::PointerHelpers::SaveParams>::adopt(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); -} -template -inline ::capnp::Orphan::SaveParams> RealmGateway::ImportParams::Builder::disownParams() { - return ::capnp::_::PointerHelpers::SaveParams>::disown(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} - -// RealmGateway::ImportParams -template -constexpr uint16_t RealmGateway::ImportParams::_capnpPrivate::dataWordSize; -template -constexpr uint16_t RealmGateway::ImportParams::_capnpPrivate::pointerCount; -#if !CAPNP_LITE -template -constexpr ::capnp::Kind RealmGateway::ImportParams::_capnpPrivate::kind; -template -constexpr ::capnp::_::RawSchema const* RealmGateway::ImportParams::_capnpPrivate::schema; -template -const ::capnp::_::RawBrandedSchema::Scope RealmGateway::ImportParams::_capnpPrivate::brandScopes[] = { - { 0x84ff286cd00a3ed4, brandBindings + 0, 4, false}, -}; -template -const ::capnp::_::RawBrandedSchema::Binding RealmGateway::ImportParams::_capnpPrivate::brandBindings[] = { - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), -}; -template -const ::capnp::_::RawBrandedSchema::Dependency RealmGateway::ImportParams::_capnpPrivate::brandDependencies[] = { - { 16777216, ::capnp::Persistent::_capnpPrivate::brand() }, - { 16777217, ::capnp::Persistent::SaveParams::_capnpPrivate::brand() }, -}; -template -const ::capnp::_::RawBrandedSchema RealmGateway::ImportParams::_capnpPrivate::specificBrand = { - &::capnp::schemas::s_f0c2cc1d3909574d, brandScopes, brandDependencies, - 1, 2, nullptr -}; -#endif // !CAPNP_LITE - -template -inline bool RealmGateway::ExportParams::Reader::hasCap() const { - return !_reader.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); -} -template -inline bool RealmGateway::ExportParams::Builder::hasCap() { - return !_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); -} -#if !CAPNP_LITE -template -inline typename ::capnp::Persistent::Client RealmGateway::ExportParams::Reader::getCap() const { - return ::capnp::_::PointerHelpers< ::capnp::Persistent>::get(_reader.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS)); -} -template -inline typename ::capnp::Persistent::Client RealmGateway::ExportParams::Builder::getCap() { - return ::capnp::_::PointerHelpers< ::capnp::Persistent>::get(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS)); -} -template -inline typename ::capnp::Persistent::Client RealmGateway::ExportParams::Pipeline::getCap() { - return typename ::capnp::Persistent::Client(_typeless.getPointerField(0).asCap()); -} -template -inline void RealmGateway::ExportParams::Builder::setCap(typename ::capnp::Persistent::Client&& cap) { - ::capnp::_::PointerHelpers< ::capnp::Persistent>::set(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(cap)); -} -template -inline void RealmGateway::ExportParams::Builder::setCap(typename ::capnp::Persistent::Client& cap) { - ::capnp::_::PointerHelpers< ::capnp::Persistent>::set(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS), cap); -} -template -inline void RealmGateway::ExportParams::Builder::adoptCap( - ::capnp::Orphan< ::capnp::Persistent>&& value) { - ::capnp::_::PointerHelpers< ::capnp::Persistent>::adopt(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); -} -template -inline ::capnp::Orphan< ::capnp::Persistent> RealmGateway::ExportParams::Builder::disownCap() { - return ::capnp::_::PointerHelpers< ::capnp::Persistent>::disown(_builder.getPointerField( - ::capnp::bounded<0>() * ::capnp::POINTERS)); -} -#endif // !CAPNP_LITE - -template -inline bool RealmGateway::ExportParams::Reader::hasParams() const { - return !_reader.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); -} -template -inline bool RealmGateway::ExportParams::Builder::hasParams() { - return !_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); -} -template -inline typename ::capnp::Persistent::SaveParams::Reader RealmGateway::ExportParams::Reader::getParams() const { - return ::capnp::_::PointerHelpers::SaveParams>::get(_reader.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} -template -inline typename ::capnp::Persistent::SaveParams::Builder RealmGateway::ExportParams::Builder::getParams() { - return ::capnp::_::PointerHelpers::SaveParams>::get(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} -#if !CAPNP_LITE -template -inline typename ::capnp::Persistent::SaveParams::Pipeline RealmGateway::ExportParams::Pipeline::getParams() { - return typename ::capnp::Persistent::SaveParams::Pipeline(_typeless.getPointerField(1)); -} -#endif // !CAPNP_LITE -template -inline void RealmGateway::ExportParams::Builder::setParams(typename ::capnp::Persistent::SaveParams::Reader value) { - ::capnp::_::PointerHelpers::SaveParams>::set(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS), value); -} -template -inline typename ::capnp::Persistent::SaveParams::Builder RealmGateway::ExportParams::Builder::initParams() { - return ::capnp::_::PointerHelpers::SaveParams>::init(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} -template -inline void RealmGateway::ExportParams::Builder::adoptParams( - ::capnp::Orphan::SaveParams>&& value) { - ::capnp::_::PointerHelpers::SaveParams>::adopt(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); -} -template -inline ::capnp::Orphan::SaveParams> RealmGateway::ExportParams::Builder::disownParams() { - return ::capnp::_::PointerHelpers::SaveParams>::disown(_builder.getPointerField( - ::capnp::bounded<1>() * ::capnp::POINTERS)); -} - -// RealmGateway::ExportParams -template -constexpr uint16_t RealmGateway::ExportParams::_capnpPrivate::dataWordSize; -template -constexpr uint16_t RealmGateway::ExportParams::_capnpPrivate::pointerCount; -#if !CAPNP_LITE -template -constexpr ::capnp::Kind RealmGateway::ExportParams::_capnpPrivate::kind; -template -constexpr ::capnp::_::RawSchema const* RealmGateway::ExportParams::_capnpPrivate::schema; -template -const ::capnp::_::RawBrandedSchema::Scope RealmGateway::ExportParams::_capnpPrivate::brandScopes[] = { - { 0x84ff286cd00a3ed4, brandBindings + 0, 4, false}, -}; -template -const ::capnp::_::RawBrandedSchema::Binding RealmGateway::ExportParams::_capnpPrivate::brandBindings[] = { - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), -}; -template -const ::capnp::_::RawBrandedSchema::Dependency RealmGateway::ExportParams::_capnpPrivate::brandDependencies[] = { - { 16777216, ::capnp::Persistent::_capnpPrivate::brand() }, - { 16777217, ::capnp::Persistent::SaveParams::_capnpPrivate::brand() }, -}; -template -const ::capnp::_::RawBrandedSchema RealmGateway::ExportParams::_capnpPrivate::specificBrand = { - &::capnp::schemas::s_ecafa18b482da3aa, brandScopes, brandDependencies, - 1, 2, nullptr -}; -#endif // !CAPNP_LITE - -#if !CAPNP_LITE -template -CAPNP_AUTO_IF_MSVC(::capnp::Request::ImportParams, typename ::capnp::Persistent::SaveResults>) -RealmGateway::Client::importRequest(::kj::Maybe< ::capnp::MessageSize> sizeHint) { - return newCall::ImportParams, typename ::capnp::Persistent::SaveResults>( - 0x84ff286cd00a3ed4ull, 0, sizeHint); -} -template -::kj::Promise RealmGateway::Server::import(ImportContext) { - return ::capnp::Capability::Server::internalUnimplemented( - "capnp/persistent.capnp:RealmGateway", "import", - 0x84ff286cd00a3ed4ull, 0); -} -template -CAPNP_AUTO_IF_MSVC(::capnp::Request::ExportParams, typename ::capnp::Persistent::SaveResults>) -RealmGateway::Client::exportRequest(::kj::Maybe< ::capnp::MessageSize> sizeHint) { - return newCall::ExportParams, typename ::capnp::Persistent::SaveResults>( - 0x84ff286cd00a3ed4ull, 1, sizeHint); -} -template -::kj::Promise RealmGateway::Server::export_(ExportContext) { - return ::capnp::Capability::Server::internalUnimplemented( - "capnp/persistent.capnp:RealmGateway", "export", - 0x84ff286cd00a3ed4ull, 1); -} -template -::kj::Promise RealmGateway::Server::dispatchCall( - uint64_t interfaceId, uint16_t methodId, - ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) { - switch (interfaceId) { - case 0x84ff286cd00a3ed4ull: - return dispatchCallInternal(methodId, context); - default: - return internalUnimplemented("capnp/persistent.capnp:RealmGateway", interfaceId); - } -} -template -::kj::Promise RealmGateway::Server::dispatchCallInternal( - uint16_t methodId, - ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) { - switch (methodId) { - case 0: - return import(::capnp::Capability::Server::internalGetTypedContext< - typename ::capnp::RealmGateway::ImportParams, typename ::capnp::Persistent::SaveResults>(context)); - case 1: - return export_(::capnp::Capability::Server::internalGetTypedContext< - typename ::capnp::RealmGateway::ExportParams, typename ::capnp::Persistent::SaveResults>(context)); - default: - (void)context; - return ::capnp::Capability::Server::internalUnimplemented( - "capnp/persistent.capnp:RealmGateway", - 0x84ff286cd00a3ed4ull, methodId); - } -} -#endif // !CAPNP_LITE - -// RealmGateway -#if !CAPNP_LITE -template -constexpr ::capnp::Kind RealmGateway::_capnpPrivate::kind; -template -constexpr ::capnp::_::RawSchema const* RealmGateway::_capnpPrivate::schema; -template -const ::capnp::_::RawBrandedSchema::Scope RealmGateway::_capnpPrivate::brandScopes[] = { - { 0x84ff286cd00a3ed4, brandBindings + 0, 4, false}, -}; -template -const ::capnp::_::RawBrandedSchema::Binding RealmGateway::_capnpPrivate::brandBindings[] = { - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), - ::capnp::_::brandBindingFor(), -}; -template -const ::capnp::_::RawBrandedSchema::Dependency RealmGateway::_capnpPrivate::brandDependencies[] = { - { 33554432, ::capnp::RealmGateway::ImportParams::_capnpPrivate::brand() }, - { 33554433, ::capnp::RealmGateway::ExportParams::_capnpPrivate::brand() }, - { 50331648, ::capnp::Persistent::SaveResults::_capnpPrivate::brand() }, - { 50331649, ::capnp::Persistent::SaveResults::_capnpPrivate::brand() }, -}; -template -const ::capnp::_::RawBrandedSchema RealmGateway::_capnpPrivate::specificBrand = { - &::capnp::schemas::s_84ff286cd00a3ed4, brandScopes, brandDependencies, - 1, 4, nullptr -}; -#endif // !CAPNP_LITE - } // namespace -#endif // CAPNP_INCLUDED_b8630836983feed7_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/pointer-helpers.h b/c++/src/capnp/pointer-helpers.h index fe70e5036f..c5ce574527 100644 --- a/c++/src/capnp/pointer-helpers.h +++ b/c++/src/capnp/pointer-helpers.h @@ -19,16 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_POINTER_HELPERS_H_ -#define CAPNP_POINTER_HELPERS_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "layout.h" #include "list.h" +CAPNP_BEGIN_HEADER + namespace capnp { namespace _ { // private @@ -157,4 +154,4 @@ struct PointerHelpers { } // namespace _ (private) } // namespace capnp -#endif // CAPNP_POINTER_HELPERS_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/pretty-print.h b/c++/src/capnp/pretty-print.h index e6458bca49..f3c6ced82f 100644 --- a/c++/src/capnp/pretty-print.h +++ b/c++/src/capnp/pretty-print.h @@ -19,16 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_PRETTY_PRINT_H_ -#define CAPNP_PRETTY_PRINT_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "dynamic.h" #include +CAPNP_BEGIN_HEADER + namespace capnp { kj::StringTree prettyPrint(DynamicStruct::Reader value); @@ -44,4 +41,4 @@ kj::StringTree prettyPrint(DynamicList::Builder value); } // namespace capnp -#endif // PRETTY_PRINT_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/raw-schema.h b/c++/src/capnp/raw-schema.h index ed9425a624..44b696c5ca 100644 --- a/c++/src/capnp/raw-schema.h +++ b/c++/src/capnp/raw-schema.h @@ -19,19 +19,16 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_RAW_SCHEMA_H_ -#define CAPNP_RAW_SCHEMA_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "common.h" // for uint and friends -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #include #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace _ { // private @@ -148,7 +145,7 @@ struct RawBrandedSchema { // is required in particular when traversing the dependency list. RawSchemas for compiled-in // types are always initialized; only dynamically-loaded schemas may be lazy. -#if __GNUC__ +#if __GNUC__ || defined(__clang__) const Initializer* i = __atomic_load_n(&lazyInitializer, __ATOMIC_ACQUIRE); #elif _MSC_VER const Initializer* i = *static_cast(&lazyInitializer); @@ -214,7 +211,7 @@ struct RawSchema { // is required in particular when traversing the dependency list. RawSchemas for compiled-in // types are always initialized; only dynamically-loaded schemas may be lazy. -#if __GNUC__ +#if __GNUC__ || defined(__clang__) const Initializer* i = __atomic_load_n(&lazyInitializer, __ATOMIC_ACQUIRE); #elif _MSC_VER const Initializer* i = *static_cast(&lazyInitializer); @@ -229,6 +226,9 @@ struct RawSchema { // Specifies the brand to use for this schema if no generic parameters have been bound to // anything. Generally, in the default brand, all generic parameters are treated as if they were // bound to `AnyPointer`. + + bool mayContainCapabilities = true; + // See StructSchema::mayContainCapabilities. }; inline bool RawBrandedSchema::isUnbound() const { @@ -239,4 +239,4 @@ inline bool RawBrandedSchema::isUnbound() const { } // namespace _ (private) } // namespace capnp -#endif // CAPNP_RAW_SCHEMA_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/reconnect-test.c++ b/c++/src/capnp/reconnect-test.c++ new file mode 100644 index 0000000000..12ef59333d --- /dev/null +++ b/c++/src/capnp/reconnect-test.c++ @@ -0,0 +1,228 @@ +// Copyright (c) 2020 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "reconnect.h" +#include "test-util.h" +#include +#include +#include +#include "rpc-twoparty.h" + +namespace capnp { +namespace _ { +namespace { + +class TestInterfaceImpl final: public test::TestInterface::Server { +public: + TestInterfaceImpl(uint generation): generation(generation) {} + + void setError(kj::Exception e) { + error = kj::mv(e); + } + + kj::Own> block() { + auto paf = kj::newPromiseAndFulfiller(); + blocker = paf.promise.fork(); + return kj::mv(paf.fulfiller); + } + +protected: + kj::Promise foo(FooContext context) override { + KJ_IF_MAYBE(e, error) { + return kj::cp(*e); + } + auto params = context.getParams(); + context.initResults().setX(kj::str(params.getI(), ' ', params.getJ(), ' ', generation)); + return blocker.addBranch(); + } + +private: + uint generation; + kj::Maybe error; + kj::ForkedPromise blocker = kj::Promise(kj::READY_NOW).fork(); +}; + +void doAutoReconnectTest(kj::WaitScope& ws, + kj::Function wrapClient) { + TestInterfaceImpl* currentServer = nullptr; + uint connectCount = 0; + + test::TestInterface::Client client = wrapClient(autoReconnect([&]() { + auto server = kj::heap(connectCount++); + currentServer = server; + return test::TestInterface::Client(kj::mv(server)); + })); + + auto testPromise = [&](uint i, bool j) { + auto req = client.fooRequest(); + req.setI(i); + req.setJ(j); + return req.send(); + }; + + auto test = [&](uint i, bool j) { + return kj::str(testPromise(i, j).wait(ws).getX()); + }; + + KJ_EXPECT(test(123, true) == "123 true 0"); + + currentServer->setError(KJ_EXCEPTION(DISCONNECTED, "test1 disconnect")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test1 disconnect", + testPromise(456, true).ignoreResult().wait(ws)); + + KJ_EXPECT(test(789, false) == "789 false 1"); + KJ_EXPECT(test(21, true) == "21 true 1"); + + { + // We cause two disconnect promises to be thrown concurrently. This should only cause the + // reconnector to reconnect once, not twice. + auto fulfiller = currentServer->block(); + auto promise1 = testPromise(32, false); + auto promise2 = testPromise(43, true); + KJ_EXPECT(!promise1.poll(ws)); + KJ_EXPECT(!promise2.poll(ws)); + fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "test2 disconnect")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test2 disconnect", promise1.ignoreResult().wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test2 disconnect", promise2.ignoreResult().wait(ws)); + } + + KJ_EXPECT(test(43, false) == "43 false 2"); + + // Start a couple calls that will block at the server end, plus an unsent request. + auto fulfiller = currentServer->block(); + + auto promise1 = testPromise(1212, true); + auto promise2 = testPromise(3434, false); + auto req3 = client.fooRequest(); + req3.setI(5656); + req3.setJ(true); + KJ_EXPECT(!promise1.poll(ws)); + KJ_EXPECT(!promise2.poll(ws)); + + // Now force a reconnect. + currentServer->setError(KJ_EXCEPTION(DISCONNECTED, "test3 disconnect")); + + // Initiate a request that will fail with DISCONNECTED. + auto promise4 = testPromise(7878, false); + + // And throw away our capability entirely, just to make sure that anyone who needs it is holding + // onto their own ref. + client = nullptr; + + // Everything we initiated should still finish. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test3 disconnect", promise4.ignoreResult().wait(ws)); + + // Send the request which we created before the disconnect. There are two behaviors we accept + // as correct here: it may throw the disconnect exception, or it may automatically redirect to + // the newly-reconnected destination. + req3.send().then([](Response resp) { + KJ_EXPECT(resp.getX() == "5656 true 3"); + }, [](kj::Exception e) { + KJ_EXPECT(e.getDescription().endsWith("test3 disconnect")); + }).wait(ws); + + KJ_EXPECT(!promise1.poll(ws)); + KJ_EXPECT(!promise2.poll(ws)); + fulfiller->fulfill(); + KJ_EXPECT(promise1.wait(ws).getX() == "1212 true 2"); + KJ_EXPECT(promise2.wait(ws).getX() == "3434 false 2"); +} + +KJ_TEST("autoReconnect() direct call (exercises newCall() / RequestHook)") { + kj::EventLoop loop; + kj::WaitScope ws(loop); + + doAutoReconnectTest(ws, [](auto c) {return kj::mv(c);}); +} + +KJ_TEST("autoReconnect() through RPC (exercises call() / CallContextHook)") { + kj::EventLoop loop; + kj::WaitScope ws(loop); + + auto paf = kj::newPromiseAndFulfiller(); + + auto pipe = kj::newTwoWayPipe(); + TwoPartyClient client(*pipe.ends[0]); + TwoPartyClient server(*pipe.ends[1], kj::mv(paf.promise), rpc::twoparty::Side::SERVER); + + doAutoReconnectTest(ws, [&](test::TestInterface::Client c) { + paf.fulfiller->fulfill(kj::mv(c)); + return client.bootstrap().castAs(); + }); +} + +KJ_TEST("lazyAutoReconnect() direct call (exercises newCall() / RequestHook)") { + kj::EventLoop loop; + kj::WaitScope ws(loop); + + doAutoReconnectTest(ws, [](auto c) {return kj::mv(c);}); +} + +KJ_TEST("lazyAutoReconnect() initialies lazily") { + kj::EventLoop loop; + kj::WaitScope ws(loop); + + int connectCount = 0; + TestInterfaceImpl* currentServer = nullptr; + auto connectCounter = [&]() { + auto server = kj::heap(connectCount++); + currentServer = server; + return test::TestInterface::Client(kj::mv(server)); + }; + + test::TestInterface::Client client = autoReconnect(connectCounter); + + auto test = [&](uint i, bool j) { + auto req = client.fooRequest(); + req.setI(i); + req.setJ(j); + return kj::str(req.send().wait(ws).getX()); + }; + auto testIgnoreResult = [&](uint i, bool j) { + auto req = client.fooRequest(); + req.setI(i); + req.setJ(j); + req.send().ignoreResult().wait(ws); + }; + + KJ_EXPECT(connectCount == 1); + KJ_EXPECT(test(123, true) == "123 true 0"); + KJ_EXPECT(connectCount == 1); + + client = lazyAutoReconnect(connectCounter); + KJ_EXPECT(connectCount == 1); + KJ_EXPECT(test(123, true) == "123 true 1"); + KJ_EXPECT(connectCount == 2); + KJ_EXPECT(test(234, false) == "234 false 1"); + KJ_EXPECT(connectCount == 2); + + currentServer->setError(KJ_EXCEPTION(DISCONNECTED, "test1 disconnect")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test1 disconnect", testIgnoreResult(345, true)); + + // lazyAutoReconnect is only lazy on the first request, not on reconnects. + KJ_EXPECT(connectCount == 3); + KJ_EXPECT(test(456, false) == "456 false 2"); + KJ_EXPECT(connectCount == 3); +} + +} // namespace +} // namespace _ +} // namespace capnp diff --git a/c++/src/capnp/reconnect.c++ b/c++/src/capnp/reconnect.c++ new file mode 100644 index 0000000000..2a8c67f6c2 --- /dev/null +++ b/c++/src/capnp/reconnect.c++ @@ -0,0 +1,161 @@ +// Copyright (c) 2020 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "reconnect.h" + +namespace capnp { + +namespace { + +class ReconnectHook final: public ClientHook, public kj::Refcounted { +public: + ReconnectHook(kj::Function connectParam, bool lazy = false) + : connect(kj::mv(connectParam)), + current(lazy ? kj::Maybe>() : ClientHook::from(connect())) {} + + Request newCall( + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + auto result = getCurrent().newCall(interfaceId, methodId, sizeHint, hints); + AnyPointer::Builder builder = result; + auto hook = kj::heap(kj::addRef(*this), RequestHook::from(kj::mv(result))); + return { builder, kj::mv(hook) }; + } + + VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, + kj::Own&& context, CallHints hints) override { + auto result = getCurrent().call(interfaceId, methodId, kj::mv(context), hints); + if (hints.onlyPromisePipeline) { + // Just in case the callee didn't implement the hint, replace its promise. + result.promise = kj::NEVER_DONE; + + // TODO(bug): In this case we won't detect cancellation. This is essentially the same + // bug as described in `RequestImpl::send()` below, and will need the same solution. + } else { + wrap(result.promise); + } + return result; + } + + kj::Maybe getResolved() override { + // We can't let people resolve to the underlying capability because then we wouldn't be able + // to redirect them later. + return nullptr; + } + + kj::Maybe>> whenMoreResolved() override { + return nullptr; + } + + kj::Own addRef() override { + return kj::addRef(*this); + } + + const void* getBrand() override { + return nullptr; + } + + kj::Maybe getFd() override { + // It's not safe to return current->getFd() because normally callers wouldn't expect the FD to + // change or go away over time, but this one could whenever we reconnect. If there's a use + // case for being able to access the FD here, we'll need a different interface to do it. + return nullptr; + } + +private: + kj::Function connect; + kj::Maybe> current; + uint generation = 0; + + template + void wrap(kj::Promise& promise) { + promise = promise.catch_( + [self = kj::addRef(*this), startGeneration = generation] + (kj::Exception&& exception) mutable -> kj::Promise { + if (exception.getType() == kj::Exception::Type::DISCONNECTED && + self->generation == startGeneration) { + self->generation++; + KJ_IF_MAYBE(e2, kj::runCatchingExceptions([&]() { + self->current = ClientHook::from(self->connect()); + })) { + self->current = newBrokenCap(kj::mv(*e2)); + } + } + return kj::mv(exception); + }); + } + + ClientHook& getCurrent() { + KJ_IF_MAYBE(c, current) { + return **c; + } else { + return *current.emplace(ClientHook::from(connect())); + } + } + + class RequestImpl final: public RequestHook { + public: + RequestImpl(kj::Own parent, kj::Own inner) + : parent(kj::mv(parent)), inner(kj::mv(inner)) {} + + RemotePromise send() override { + auto result = inner->send(); + // TODO(bug): If the returned promise is dropped, e.g. because the caller only cares about + // pipelining, then the DISCONNECTED exception will not be noticed. I suppose we have to + // split the promise and hold one branch, but we don't want to prevent cancellation, so + // we only want to hold that branch as long as the PipelineHook or some pipelined + // capability obtained through it lives. So we need a bunch of custom wrappers for that. + // Ugh. + parent->wrap(result); + return result; + } + + kj::Promise sendStreaming() override { + auto result = inner->sendStreaming(); + parent->wrap(result); + return result; + } + + AnyPointer::Pipeline sendForPipeline() override { + // TODO(bug): This definitely fails to detect disconnects; see comment in send(). + return inner->sendForPipeline(); + } + + const void* getBrand() override { + return nullptr; + } + + private: + kj::Own parent; + kj::Own inner; + }; +}; + +} // namespace + +Capability::Client autoReconnect(kj::Function connect) { + return Capability::Client(kj::refcounted(kj::mv(connect))); +} + +Capability::Client lazyAutoReconnect(kj::Function connect) { + return Capability::Client(kj::refcounted(kj::mv(connect), true)); +} +} // namespace capnp diff --git a/c++/src/capnp/reconnect.h b/c++/src/capnp/reconnect.h new file mode 100644 index 0000000000..4e430951e9 --- /dev/null +++ b/c++/src/capnp/reconnect.h @@ -0,0 +1,80 @@ +// Copyright (c) 2020 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include + +CAPNP_BEGIN_HEADER + +namespace capnp { + +template +auto autoReconnect(ConnectFunc&& connect); +// Creates a capability that reconstructs itself every time it becomes disconnected. +// +// `connect()` is a function which is invoked to initially construct the capability, and then +// invoked again each time the capability is found to be disconnected. `connect()` may return +// any capability `Client` type. +// +// Example usage might look like: +// +// Foo::Client foo = autoReconnect([&rpcSystem, vatId]() { +// return rpcSystem.bootstrap(vatId).castAs().getFooRequest().send().getFoo(); +// }); +// +// The given function is initially called synchronously, and the returned `foo` is a wrapper +// around what the function returned. But any time this capability becomes disconnected, the +// function is invoked again, and future calls are directed to the new result. +// +// Any call that is in-flight when the capability becomes disconnected still fails with a +// DISCONNECTED exception. The caller should respond by retrying, as a retry will target the +// newly-reconnected capability. However, the caller should limit the number of times it retries, +// to avoid an infinite loop in the case that the DISCONNECTED exception actually represents a +// permanent problem. Consider using `kj::retryOnDisconnect()` to implement this behavior. + +template +auto lazyAutoReconnect(ConnectFunc&& connect); +// The same as autoReconnect, but doesn't call the provided connect function until the first +// time the capability is used. Note that only the initial connection is lazy -- upon +// disconnected errors this will still reconnect eagerly. + +// ======================================================================================= +// inline implementation details + +Capability::Client autoReconnect(kj::Function connect); +template +auto autoReconnect(ConnectFunc&& connect) { + return autoReconnect(kj::Function(kj::fwd(connect))) + .castAs>>(); +} + +Capability::Client lazyAutoReconnect(kj::Function connect); +template +auto lazyAutoReconnect(ConnectFunc&& connect) { + return lazyAutoReconnect(kj::Function(kj::fwd(connect))) + .castAs>>(); +} + +} // namespace capnp + +CAPNP_END_HEADER diff --git a/c++/src/capnp/rpc-prelude.h b/c++/src/capnp/rpc-prelude.h index 7d26e39de8..742aa868c9 100644 --- a/c++/src/capnp/rpc-prelude.h +++ b/c++/src/capnp/rpc-prelude.h @@ -22,20 +22,18 @@ // This file contains a bunch of internal declarations that must appear before rpc.h can start. // We don't define these directly in rpc.h because it makes the file hard to read. -#ifndef CAPNP_RPC_PRELUDE_H_ -#define CAPNP_RPC_PRELUDE_H_ +#pragma once -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif - -#include "capability.h" +#include #include "persistent.capnp.h" +CAPNP_BEGIN_HEADER + namespace capnp { class OutgoingRpcMessage; class IncomingRpcMessage; +class RpcFlowController; template class RpcSystem; @@ -60,6 +58,7 @@ class VatNetworkBase { virtual kj::Promise>> receiveIncomingMessage() = 0; virtual kj::Promise shutdown() = 0; virtual AnyStruct::Reader baseGetPeerVatId() = 0; + virtual kj::Own newStream() = 0; }; virtual kj::Maybe> baseConnect(AnyStruct::Reader vatId) = 0; virtual kj::Promise> baseAccept() = 0; @@ -80,14 +79,16 @@ class RpcSystemBase { // Non-template version of RpcSystem. Ignore this class; see RpcSystem in rpc.h. public: - RpcSystemBase(VatNetworkBase& network, kj::Maybe bootstrapInterface, - kj::Maybe::Client> gateway); - RpcSystemBase(VatNetworkBase& network, BootstrapFactoryBase& bootstrapFactory, - kj::Maybe::Client> gateway); + RpcSystemBase(VatNetworkBase& network, kj::Maybe bootstrapInterface); + RpcSystemBase(VatNetworkBase& network, BootstrapFactoryBase& bootstrapFactory); RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& restorer); RpcSystemBase(RpcSystemBase&& other) noexcept; ~RpcSystemBase() noexcept(false); + void setTraceEncoder(kj::Function func); + + kj::Promise run(); + private: class Impl; kj::Own impl; @@ -100,31 +101,7 @@ class RpcSystemBase { friend class capnp::RpcSystem; }; -template struct InternalRefFromRealmGateway_; -template -struct InternalRefFromRealmGateway_> { - typedef InternalRef Type; -}; -template -using InternalRefFromRealmGateway = typename InternalRefFromRealmGateway_::Type; -template -using InternalRefFromRealmGatewayClient = InternalRefFromRealmGateway; - -template struct ExternalRefFromRealmGateway_; -template -struct ExternalRefFromRealmGateway_> { - typedef ExternalRef Type; -}; -template -using ExternalRefFromRealmGateway = typename ExternalRefFromRealmGateway_::Type; -template -using ExternalRefFromRealmGatewayClient = ExternalRefFromRealmGateway; - } // namespace _ (private) } // namespace capnp -#endif // CAPNP_RPC_PRELUDE_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/rpc-test.c++ b/c++/src/capnp/rpc-test.c++ index e4c06a7e5e..da0d7abcbc 100644 --- a/c++/src/capnp/rpc-test.c++ +++ b/c++/src/capnp/rpc-test.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define CAPNP_TESTING_CAPNP 1 + #include "rpc.h" #include "test-util.h" #include "schema.h" @@ -196,7 +198,7 @@ typedef VatNetwork< class TestNetworkAdapter final: public TestNetworkAdapterBase { public: - TestNetworkAdapter(TestNetwork& network): network(network) {} + TestNetworkAdapter(TestNetwork& network, kj::StringPtr self): network(network), self(self) {} ~TestNetworkAdapter() { kj::Exception exception = KJ_EXCEPTION(FAILED, "Network was destroyed."); @@ -208,6 +210,12 @@ public: uint getSentCount() { return sent; } uint getReceivedCount() { return received; } + void onSend(kj::Function callback) { + // Invokes the given callback every time a message is sent. Callback can return false to cause + // send() to do nothing. + sendCallback = kj::mv(callback); + } + typedef TestNetworkAdapterBase::Connection Connection; class ConnectionImpl final @@ -244,6 +252,10 @@ public: return message.getRoot(); } + size_t sizeInWords() override { + return data.size(); + } + kj::Array data; FlatArrayMessageReader message; }; @@ -260,6 +272,8 @@ public: } void send() override { + if (!connection.network.sendCallback(message)) return; + if (connection.networkException != nullptr) { return; } @@ -274,8 +288,8 @@ public: auto incomingMessage = kj::heap(messageToFlatArray(message)); auto connectionPtr = &connection; - connection.tasks->add(kj::evalLater(kj::mvCapture(incomingMessage, - [connectionPtr](kj::Own&& message) { + connection.tasks->add(kj::evalLater( + [connectionPtr,message=kj::mv(incomingMessage)]() mutable { KJ_IF_MAYBE(p, connectionPtr->partner) { if (p->fulfillers.empty()) { p->messages.push(kj::mv(message)); @@ -286,7 +300,11 @@ public: p->fulfillers.pop(); } } - }))); + })); + } + + size_t sizeInWords() override { + return message.sizeInWords(); } private: @@ -352,6 +370,10 @@ public: }; kj::Maybe> connect(test::TestSturdyRefHostId::Reader hostId) override { + if (hostId.getHost() == self) { + return nullptr; + } + TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost())); auto iter = connections.find(&dst); @@ -390,18 +412,21 @@ public: private: TestNetwork& network; + kj::StringPtr self; uint sent = 0; uint received = 0; std::map> connections; std::queue>>> fulfillerQueue; std::queue> connectionQueue; + + kj::Function sendCallback = [](MessageBuilder&) { return true; }; }; TestNetwork::~TestNetwork() noexcept(false) {} TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) { - return *(map[name] = kj::heap(*this)); + return *(map[name] = kj::heap(*this, name)); } // ======================================================================================= @@ -446,21 +471,12 @@ struct TestContext { serverNetwork(network.add("server")), rpcClient(makeRpcClient(clientNetwork)), rpcServer(makeRpcServer(serverNetwork, restorer)) {} - TestContext(Capability::Client bootstrap, - RealmGateway::Client gateway) - : waitScope(loop), - clientNetwork(network.add("client")), - serverNetwork(network.add("server")), - rpcClient(makeRpcClient(clientNetwork, gateway)), - rpcServer(makeRpcServer(serverNetwork, bootstrap)) {} - TestContext(Capability::Client bootstrap, - RealmGateway::Client gateway, - bool) + TestContext(Capability::Client bootstrap) : waitScope(loop), clientNetwork(network.add("client")), serverNetwork(network.add("server")), rpcClient(makeRpcClient(clientNetwork)), - rpcServer(makeRpcServer(serverNetwork, bootstrap, gateway)) {} + rpcServer(makeRpcServer(serverNetwork, bootstrap)) {} Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) { MallocMessageBuilder refMessage(128); @@ -550,6 +566,71 @@ TEST(Rpc, Pipelining) { EXPECT_EQ(1, chainedCallCount); } +KJ_TEST("RPC sendForPipeline()") { + TestContext context; + + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_PIPELINE) + .castAs(); + + int chainedCallCount = 0; + + auto request = client.getCapRequest(); + request.setN(234); + request.setInCap(kj::heap(chainedCallCount)); + + auto pipeline = request.sendForPipeline(); + + auto pipelineRequest = pipeline.getOutBox().getCap().fooRequest(); + pipelineRequest.setI(321); + auto pipelinePromise = pipelineRequest.send(); + + auto pipelineRequest2 = pipeline.getOutBox().getCap().castAs().graultRequest(); + auto pipelinePromise2 = pipelineRequest2.send(); + + pipeline = nullptr; // Just to be annoying, drop the original pipeline. + + EXPECT_EQ(0, context.restorer.callCount); + EXPECT_EQ(0, chainedCallCount); + + auto response = pipelinePromise.wait(context.waitScope); + EXPECT_EQ("bar", response.getX()); + + auto response2 = pipelinePromise2.wait(context.waitScope); + checkTestMessage(response2); + + EXPECT_EQ(3, context.restorer.callCount); + EXPECT_EQ(1, chainedCallCount); +} + +KJ_TEST("RPC context.setPipeline") { + TestContext context; + + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_PIPELINE) + .castAs(); + + auto promise = client.getCapPipelineOnlyRequest().send(); + + auto pipelineRequest = promise.getOutBox().getCap().fooRequest(); + pipelineRequest.setI(321); + auto pipelinePromise = pipelineRequest.send(); + + auto pipelineRequest2 = promise.getOutBox().getCap().castAs().graultRequest(); + auto pipelinePromise2 = pipelineRequest2.send(); + + EXPECT_EQ(0, context.restorer.callCount); + + auto response = pipelinePromise.wait(context.waitScope); + EXPECT_EQ("bar", response.getX()); + + auto response2 = pipelinePromise2.wait(context.waitScope); + checkTestMessage(response2); + + EXPECT_EQ(3, context.restorer.callCount); + + // The original promise never completed. + KJ_EXPECT(!promise.poll(context.waitScope)); +} + TEST(Rpc, Release) { TestContext context; @@ -631,18 +712,129 @@ TEST(Rpc, TailCall) { auto dependentCall1 = promise.getC().getCallSequenceRequest().send(); - auto dependentCall2 = response.getC().getCallSequenceRequest().send(); - EXPECT_EQ(0, dependentCall0.wait(context.waitScope).getN()); EXPECT_EQ(1, dependentCall1.wait(context.waitScope).getN()); + + // TODO(someday): We used to initiate dependentCall2 here before waiting on the first two calls, + // and the ordering was still "correct". But this was apparently by accident. Calling getC() on + // the final response returns a different capability from calling getC() on the promise. There + // are no guarantees on the ordering of calls on the response capability vs. the earlier + // promise. When ordering matters, applications should take the original promise capability and + // keep using that. In theory the RPC system could create continuity here, but it would be + // annoying: for each capability that had been fetched on the promise, it would need to + // traverse to the same capability in the final response and swap it out in-place for the + // pipelined cap returned earlier. Maybe we'll determine later that that's really needed but + // for now I'm not gonna do it. + auto dependentCall2 = response.getC().getCallSequenceRequest().send(); + EXPECT_EQ(2, dependentCall2.wait(context.waitScope).getN()); EXPECT_EQ(1, calleeCallCount); EXPECT_EQ(1, context.restorer.callCount); } -TEST(Rpc, Cancelation) { - // Tests allowCancellation(). +class TestHangingTailCallee final: public test::TestTailCallee::Server { +public: + TestHangingTailCallee(int& callCount, int& cancelCount) + : callCount(callCount), cancelCount(cancelCount) {} + + kj::Promise foo(FooContext context) override { + ++callCount; + return kj::Promise(kj::NEVER_DONE) + .attach(kj::defer([&cancelCount = cancelCount]() { ++cancelCount; })); + } + +private: + int& callCount; + int& cancelCount; +}; + +class TestRacingTailCaller final: public test::TestTailCaller::Server { +public: + TestRacingTailCaller(kj::Promise unblock): unblock(kj::mv(unblock)) {} + + kj::Promise foo(FooContext context) override { + return unblock.then([context]() mutable { + auto tailRequest = context.getParams().getCallee().fooRequest(); + return context.tailCall(kj::mv(tailRequest)); + }); + } + +private: + kj::Promise unblock; +}; + +TEST(Rpc, TailCallCancel) { + TestContext context; + + auto caller = context.connect(test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER) + .castAs(); + + int callCount = 0, cancelCount = 0; + + test::TestTailCallee::Client callee(kj::heap(callCount, cancelCount)); + + { + auto request = caller.fooRequest(); + request.setCallee(callee); + + auto promise = request.send(); + + KJ_ASSERT(callCount == 0); + KJ_ASSERT(cancelCount == 0); + + KJ_ASSERT(!promise.poll(context.waitScope)); + + KJ_ASSERT(callCount == 1); + KJ_ASSERT(cancelCount == 0); + } + + kj::Promise(kj::NEVER_DONE).poll(context.waitScope); + + KJ_ASSERT(callCount == 1); + KJ_ASSERT(cancelCount == 1); +} + +TEST(Rpc, TailCallCancelRace) { + auto paf = kj::newPromiseAndFulfiller(); + TestContext context(kj::heap(kj::mv(paf.promise))); + + MallocMessageBuilder serverHostIdBuilder; + auto serverHostId = serverHostIdBuilder.getRoot(); + serverHostId.setHost("server"); + + auto caller = context.rpcClient.bootstrap(serverHostId).castAs(); + + int callCount = 0, cancelCount = 0; + + test::TestTailCallee::Client callee(kj::heap(callCount, cancelCount)); + + { + auto request = caller.fooRequest(); + request.setCallee(callee); + + auto promise = request.send(); + + KJ_ASSERT(callCount == 0); + KJ_ASSERT(cancelCount == 0); + + KJ_ASSERT(!promise.poll(context.waitScope)); + + KJ_ASSERT(callCount == 0); + KJ_ASSERT(cancelCount == 0); + + // Unblock the server and at the same time cancel the client. + paf.fulfiller->fulfill(); + } + + kj::Promise(kj::NEVER_DONE).poll(context.waitScope); + + KJ_ASSERT(callCount == 1); + KJ_ASSERT(cancelCount == 1); +} + +TEST(Rpc, Cancellation) { + // Tests cancellation. TestContext context; @@ -899,6 +1091,55 @@ TEST(Rpc, Embargo) { EXPECT_EQ(5, call5.wait(context.waitScope).getN()); } +TEST(Rpc, EmbargoUnwrap) { + // Test that embargos properly block unwraping a capability using CapabilityServerSet. + + TestContext context; + + capnp::CapabilityServerSet capSet; + + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); + + auto cap = capSet.add(kj::heap()); + + auto earlyCall = client.getCallSequenceRequest().send(); + + auto echoRequest = client.echoRequest(); + echoRequest.setCap(cap); + auto echo = echoRequest.send(); + + auto pipeline = echo.getCap(); + + auto unwrap = capSet.getLocalServer(pipeline) + .then([](kj::Maybe unwrapped) { + return kj::downcast(KJ_ASSERT_NONNULL(unwrapped)).getCount(); + }).eagerlyEvaluate(nullptr); + + auto call0 = getCallSequence(pipeline, 0); + auto call1 = getCallSequence(pipeline, 1); + + earlyCall.wait(context.waitScope); + + auto call2 = getCallSequence(pipeline, 2); + + auto resolved = echo.wait(context.waitScope).getCap(); + + auto call3 = getCallSequence(pipeline, 4); + auto call4 = getCallSequence(pipeline, 4); + auto call5 = getCallSequence(pipeline, 5); + + EXPECT_EQ(0, call0.wait(context.waitScope).getN()); + EXPECT_EQ(1, call1.wait(context.waitScope).getN()); + EXPECT_EQ(2, call2.wait(context.waitScope).getN()); + EXPECT_EQ(3, call3.wait(context.waitScope).getN()); + EXPECT_EQ(4, call4.wait(context.waitScope).getN()); + EXPECT_EQ(5, call5.wait(context.waitScope).getN()); + + uint unwrappedAt = unwrap.wait(context.waitScope); + KJ_EXPECT(unwrappedAt >= 3, unwrappedAt); +} + template void expectPromiseThrows(kj::Promise&& promise, kj::WaitScope& waitScope) { EXPECT_TRUE(promise.then([](T&&) { return false; }, [](kj::Exception&&) { return true; }) @@ -1062,198 +1303,278 @@ TEST(Rpc, Abort) { EXPECT_TRUE(conn->receiveIncomingMessage().wait(context.waitScope) == nullptr); } -// ======================================================================================= - -typedef RealmGateway TestRealmGateway; - -class TestGateway final: public TestRealmGateway::Server { -public: - kj::Promise import(ImportContext context) override { - auto cap = context.getParams().getCap(); - context.releaseParams(); - return cap.saveRequest().send() - .then([KJ_CPCAP(context)](Response::SaveResults> response) mutable { - context.getResults().initSturdyRef().getObjectId().setAs( - kj::str("imported-", response.getSturdyRef())); - }); - } +KJ_TEST("loopback bootstrap()") { + int callCount = 0; + test::TestInterface::Client bootstrap = kj::heap(callCount); - kj::Promise export_(ExportContext context) override { - auto cap = context.getParams().getCap(); - context.releaseParams(); - return cap.saveRequest().send() - .then([KJ_CPCAP(context)] - (Response::SaveResults> response) mutable { - context.getResults().setSturdyRef(kj::str("exported-", - response.getSturdyRef().getObjectId().getAs())); - }); - } -}; + MallocMessageBuilder hostIdBuilder; + auto hostId = hostIdBuilder.getRoot(); + hostId.setHost("server"); -class TestPersistent final: public Persistent::Server { -public: - TestPersistent(kj::StringPtr name): name(name) {} + TestContext context(bootstrap); + auto client = context.rpcServer.bootstrap(hostId).castAs(); - kj::Promise save(SaveContext context) override { - context.initResults().initSturdyRef().getObjectId().setAs(name); - return kj::READY_NOW; - } + auto request = client.fooRequest(); + request.setI(123); + request.setJ(true); + auto response = request.send().wait(context.waitScope); -private: - kj::StringPtr name; -}; + KJ_EXPECT(response.getX() == "foo"); + KJ_EXPECT(callCount == 1); +} -class TestPersistentText final: public Persistent::Server { -public: - TestPersistentText(kj::StringPtr name): name(name) {} +KJ_TEST("method throws exception") { + TestContext context; - kj::Promise save(SaveContext context) override { - context.initResults().setSturdyRef(name); - return kj::READY_NOW; - } + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); -private: - kj::StringPtr name; -}; + kj::Maybe maybeException; + client.throwExceptionRequest().send().ignoreResult() + .catch_([&](kj::Exception&& e) { + maybeException = kj::mv(e); + }).wait(context.waitScope); -TEST(Rpc, RealmGatewayImport) { - TestRealmGateway::Client gateway = kj::heap(); - Persistent::Client bootstrap = kj::heap("foo"); + auto exception = KJ_ASSERT_NONNULL(maybeException); + KJ_EXPECT(exception.getDescription() == "remote exception: test exception"); + KJ_EXPECT(exception.getRemoteTrace() == nullptr); +} - MallocMessageBuilder hostIdBuilder; - auto hostId = hostIdBuilder.getRoot(); - hostId.setHost("server"); +KJ_TEST("method throws exception won't redundantly add remote exception prefix") { + TestContext context; - TestContext context(bootstrap, gateway); - auto client = context.rpcClient.bootstrap(hostId).castAs>(); + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); - auto response = client.saveRequest().send().wait(context.waitScope); + kj::Maybe maybeException; + client.throwRemoteExceptionRequest().send().ignoreResult() + .catch_([&](kj::Exception&& e) { + maybeException = kj::mv(e); + }).wait(context.waitScope); - EXPECT_EQ("imported-foo", response.getSturdyRef().getObjectId().getAs()); + auto exception = KJ_ASSERT_NONNULL(maybeException); + KJ_EXPECT(exception.getDescription() == "remote exception: test exception"); + KJ_EXPECT(exception.getRemoteTrace() == nullptr); } -TEST(Rpc, RealmGatewayExport) { - TestRealmGateway::Client gateway = kj::heap(); - Persistent::Client bootstrap = kj::heap("foo"); +KJ_TEST("method throws exception with trace encoder") { + TestContext context; - MallocMessageBuilder hostIdBuilder; - auto hostId = hostIdBuilder.getRoot(); - hostId.setHost("server"); + context.rpcServer.setTraceEncoder([](const kj::Exception& e) { + return kj::str("trace for ", e.getDescription()); + }); - TestContext context(bootstrap, gateway, true); - auto client = context.rpcClient.bootstrap(hostId).castAs>(); + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); - auto response = client.saveRequest().send().wait(context.waitScope); + kj::Maybe maybeException; + client.throwExceptionRequest().send().ignoreResult() + .catch_([&](kj::Exception&& e) { + maybeException = kj::mv(e); + }).wait(context.waitScope); - EXPECT_EQ("exported-foo", response.getSturdyRef()); + auto exception = KJ_ASSERT_NONNULL(maybeException); + KJ_EXPECT(exception.getDescription() == "remote exception: test exception"); + KJ_EXPECT(exception.getRemoteTrace() == "trace for test exception"); } -TEST(Rpc, RealmGatewayImportExport) { - // Test that a save request which leaves the realm, bounces through a promise capability, and - // then comes back into the realm, does not actually get translated both ways. +KJ_TEST("when OutgoingRpcMessage::send() throws, we don't leak exports") { + // When OutgoingRpcMessage::send() throws an exception on a Call message, we need to clean up + // anything that had been added to the export table as part of the call. At one point this + // cleanup was missing, so exports would leak. - TestRealmGateway::Client gateway = kj::heap(); - Persistent::Client bootstrap = kj::heap("foo"); - - MallocMessageBuilder serverHostIdBuilder; - auto serverHostId = serverHostIdBuilder.getRoot(); - serverHostId.setHost("server"); + TestContext context; - MallocMessageBuilder clientHostIdBuilder; - auto clientHostId = clientHostIdBuilder.getRoot(); - clientHostId.setHost("client"); + uint32_t expectedExportNumber = 0; + uint interceptCount = 0; + bool shouldThrowFromSend = false; + context.clientNetwork.onSend([&](MessageBuilder& builder) { + auto message = builder.getRoot().asReader(); + if (message.isCall()) { + auto call = message.getCall(); + if (call.getInterfaceId() == capnp::typeId() && + call.getMethodId() == 0) { + // callFoo() request, expect a capability in the param caps. Specifically we expect a + // promise, because that's what we send below. + auto capTable = call.getParams().getCapTable(); + KJ_ASSERT(capTable.size() == 1); + auto desc = capTable[0]; + KJ_ASSERT(desc.isSenderPromise()); + KJ_ASSERT(desc.getSenderPromise() == expectedExportNumber); + + ++interceptCount; + if (shouldThrowFromSend) { + kj::throwRecoverableException(KJ_EXCEPTION(FAILED, "intercepted")); + return false; // only matters when -fno-exceptions + } + } + } + return true; + }); - kj::EventLoop loop; - kj::WaitScope waitScope(loop); - TestNetwork network; - TestRestorer restorer; - TestNetworkAdapter& clientNetwork = network.add("client"); - TestNetworkAdapter& serverNetwork = network.add("server"); - RpcSystem rpcClient = - makeRpcServer(clientNetwork, bootstrap, gateway); - auto paf = kj::newPromiseAndFulfiller(); - RpcSystem rpcServer = - makeRpcServer(serverNetwork, kj::mv(paf.promise)); - - auto client = rpcClient.bootstrap(serverHostId).castAs>(); - - bool responseReady = false; - auto responsePromise = client.saveRequest().send() - .then([&](Response::SaveResults>&& response) { - responseReady = true; - return kj::mv(response); - }).eagerlyEvaluate(nullptr); + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); - // Crank the event loop to give the message time to reach the server and block on the promise - // resolution. - kj::evalLater([]() {}).wait(waitScope); - kj::evalLater([]() {}).wait(waitScope); - kj::evalLater([]() {}).wait(waitScope); - kj::evalLater([]() {}).wait(waitScope); + { + shouldThrowFromSend = true; + auto req = client.callFooRequest(); + req.setCap(kj::Promise(kj::NEVER_DONE)); + req.send().then([](auto&&) { + KJ_FAIL_ASSERT("should have thrown"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getDescription() == "intercepted", e); + }).wait(context.waitScope); + } - EXPECT_FALSE(responseReady); + KJ_EXPECT(interceptCount == 1); - paf.fulfiller->fulfill(rpcServer.bootstrap(clientHostId)); + // Sending again should use the same export number, because the export table entry should have + // been released when send() threw. (At one point, this was a bug...) + { + shouldThrowFromSend = true; + auto req = client.callFooRequest(); + req.setCap(kj::Promise(kj::NEVER_DONE)); + req.send().then([](auto&&) { + KJ_FAIL_ASSERT("should have thrown"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getDescription() == "intercepted", e); + }).wait(context.waitScope); + } - auto response = responsePromise.wait(waitScope); + KJ_EXPECT(interceptCount == 2); - // Should have the original value. If it went through export and re-import, though, then this - // will be "imported-exported-foo", which is wrong. - EXPECT_EQ("foo", response.getSturdyRef().getObjectId().getAs()); -} + // Now lets start a call that doesn't throw. The export number should still be zero because + // the previous exports were released. + { + shouldThrowFromSend = false; + auto req = client.callFooRequest(); + req.setCap(kj::Promise(kj::NEVER_DONE)); + auto promise = req.send(); + KJ_EXPECT(!promise.poll(context.waitScope)); -TEST(Rpc, RealmGatewayImportExport) { - // Test that a save request which enters the realm, bounces through a promise capability, and - // then goes back out of the realm, does not actually get translated both ways. + KJ_EXPECT(interceptCount == 3); + } - TestRealmGateway::Client gateway = kj::heap(); - Persistent::Client bootstrap = kj::heap("foo"); + // We canceled the previous call, BUT the exported capability is still present until the other + // side drops it, which it won't because the call isn't marked cancelable and never completes. + // Now, let's send another call. This time, we expect a new export number will actually be + // allocated. + { + shouldThrowFromSend = false; + expectedExportNumber = 1; + auto req = client.callFooRequest(); + auto paf = kj::newPromiseAndFulfiller(); + req.setCap(kj::mv(paf.promise)); + auto promise = req.send(); + KJ_EXPECT(!promise.poll(context.waitScope)); + + KJ_EXPECT(interceptCount == 4); + + // Now let's actually let the RPC complete so we can verify the RPC system isn't broken or + // anything. + int callCount = 0; + paf.fulfiller->fulfill(kj::heap(callCount)); + auto resp = promise.wait(context.waitScope); + KJ_EXPECT(resp.getS() == "bar"); + KJ_EXPECT(callCount == 1); + } - MallocMessageBuilder serverHostIdBuilder; - auto serverHostId = serverHostIdBuilder.getRoot(); - serverHostId.setHost("server"); + // Now if we do yet another call, it'll reuse export number 1. + { + shouldThrowFromSend = false; + expectedExportNumber = 1; + auto req = client.callFooRequest(); + req.setCap(kj::Promise(kj::NEVER_DONE)); + auto promise = req.send(); + KJ_EXPECT(!promise.poll(context.waitScope)); + + KJ_EXPECT(interceptCount == 5); + } +} - MallocMessageBuilder clientHostIdBuilder; - auto clientHostId = clientHostIdBuilder.getRoot(); - clientHostId.setHost("client"); +KJ_TEST("export the same promise twice") { + TestContext context; - kj::EventLoop loop; - kj::WaitScope waitScope(loop); - TestNetwork network; - TestRestorer restorer; - TestNetworkAdapter& clientNetwork = network.add("client"); - TestNetworkAdapter& serverNetwork = network.add("server"); - RpcSystem rpcClient = - makeRpcServer(clientNetwork, bootstrap); - auto paf = kj::newPromiseAndFulfiller(); - RpcSystem rpcServer = - makeRpcServer(serverNetwork, kj::mv(paf.promise), gateway); - - auto client = rpcClient.bootstrap(serverHostId).castAs>(); - - bool responseReady = false; - auto responsePromise = client.saveRequest().send() - .then([&](Response::SaveResults>&& response) { - responseReady = true; - return kj::mv(response); - }).eagerlyEvaluate(nullptr); + bool exportIsPromise; + uint32_t expectedExportNumber; + uint interceptCount = 0; + context.clientNetwork.onSend([&](MessageBuilder& builder) { + auto message = builder.getRoot().asReader(); + if (message.isCall()) { + auto call = message.getCall(); + if (call.getInterfaceId() == capnp::typeId() && + call.getMethodId() == 0) { + // callFoo() request, expect a capability in the param caps. Specifically we expect a + // promise, because that's what we send below. + auto capTable = call.getParams().getCapTable(); + KJ_ASSERT(capTable.size() == 1); + auto desc = capTable[0]; + if (exportIsPromise) { + KJ_ASSERT(desc.isSenderPromise()); + KJ_ASSERT(desc.getSenderPromise() == expectedExportNumber); + } else { + KJ_ASSERT(desc.isSenderHosted()); + KJ_ASSERT(desc.getSenderHosted() == expectedExportNumber); + } - // Crank the event loop to give the message time to reach the server and block on the promise - // resolution. - kj::evalLater([]() {}).wait(waitScope); - kj::evalLater([]() {}).wait(waitScope); - kj::evalLater([]() {}).wait(waitScope); - kj::evalLater([]() {}).wait(waitScope); + ++interceptCount; + } + } + return true; + }); - EXPECT_FALSE(responseReady); + auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) + .castAs(); - paf.fulfiller->fulfill(rpcServer.bootstrap(clientHostId)); + auto sendReq = [&](test::TestInterface::Client cap) { + auto req = client.callFooRequest(); + req.setCap(kj::mv(cap)); + return req.send(); + }; - auto response = responsePromise.wait(waitScope); + auto expectNeverDone = [&](auto& promise) { + if (promise.poll(context.waitScope)) { + promise.wait(context.waitScope); // let it throw if it's going to + KJ_FAIL_ASSERT("promise finished without throwing"); + } + }; - // Should have the original value. If it went through import and re-export, though, then this - // will be "exported-imported-foo", which is wrong. - EXPECT_EQ("foo", response.getSturdyRef()); + int callCount = 0; + test::TestInterface::Client normalCap = kj::heap(callCount); + test::TestInterface::Client promiseCap = kj::Promise(kj::NEVER_DONE); + + // Send request with a promise capability in the params. + exportIsPromise = true; + expectedExportNumber = 0; + auto promise1 = sendReq(promiseCap); + expectNeverDone(promise1); + KJ_EXPECT(interceptCount == 1); + + // Send a second request with the same promise should use the same export table entry. + auto promise2 = sendReq(promiseCap); + expectNeverDone(promise2); + KJ_EXPECT(interceptCount == 2); + + // Sending a request with a different promise should use a different export table entry. + expectedExportNumber = 1; + auto promise3 = sendReq(kj::Promise(kj::NEVER_DONE)); + expectNeverDone(promise3); + KJ_EXPECT(interceptCount == 3); + + // Now try sending a non-promise cap. We'll send all these requests at once before waiting on + // any of them since these will actually complete. + exportIsPromise = false; + expectedExportNumber = 2; + auto promise4 = sendReq(normalCap); + auto promise5 = sendReq(normalCap); + expectedExportNumber = 3; + auto promise6 = sendReq(kj::heap(callCount)); + KJ_EXPECT(interceptCount == 6); + + KJ_EXPECT(promise4.wait(context.waitScope).getS() == "bar"); + KJ_EXPECT(promise5.wait(context.waitScope).getS() == "bar"); + KJ_EXPECT(promise6.wait(context.waitScope).getS() == "bar"); + KJ_EXPECT(callCount == 3); } } // namespace diff --git a/c++/src/capnp/rpc-twoparty-test.c++ b/c++/src/capnp/rpc-twoparty-test.c++ index 27d05b84c8..5bf2215de5 100644 --- a/c++/src/capnp/rpc-twoparty-test.c++ +++ b/c++/src/capnp/rpc-twoparty-test.c++ @@ -19,12 +19,32 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define CAPNP_TESTING_CAPNP 1 + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +// Includes just for need SOL_SOCKET and SO_SNDBUF +#if _WIN32 +#include +#endif + #include "rpc-twoparty.h" #include "test-util.h" #include #include #include #include +#include + +#if _WIN32 +#include +#include +#include +#else +#include +#endif // TODO(cleanup): Auto-generate stringification functions for union discriminants. namespace capnp { @@ -67,6 +87,18 @@ private: int& handleCount; }; +class TestMonotonicClock final: public kj::MonotonicClock { +public: + kj::TimePoint now() const override { + return time; + } + + void reset() { time = kj::systemCoarseMonotonicClock().now(); } + void increment(kj::Duration d) { time += d; } +private: + kj::TimePoint time = kj::systemCoarseMonotonicClock().now(); +}; + kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& callCount, int& handleCount) { return ioProvider.newPipeThread( @@ -97,16 +129,28 @@ Capability::Client getPersistentCap(RpcSystem& client, TEST(TwoPartyNetwork, Basic) { auto ioContext = kj::setupAsyncIo(); + TestMonotonicClock clock; int callCount = 0; int handleCount = 0; auto serverThread = runServer(*ioContext.provider, callCount, handleCount); - TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT); + TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT, capnp::ReaderOptions(), clock); auto rpcClient = makeRpcClient(network); + KJ_EXPECT(network.getCurrentQueueCount() == 0); + KJ_EXPECT(network.getCurrentQueueSize() == 0); + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 0 * kj::SECONDS); + // Request the particular capability from the server. auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER, test::TestSturdyRefObjectId::Tag::TEST_INTERFACE).castAs(); + clock.increment(1 * kj::SECONDS); + + KJ_EXPECT(network.getCurrentQueueCount() == 1); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); + KJ_EXPECT(network.getCurrentQueueSize() > 0); + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 1 * kj::SECONDS); + size_t oldSize = network.getCurrentQueueSize(); // Use the capability. auto request1 = client.fooRequest(); @@ -114,10 +158,23 @@ TEST(TwoPartyNetwork, Basic) { request1.setJ(true); auto promise1 = request1.send(); + KJ_EXPECT(network.getCurrentQueueCount() == 2); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); + KJ_EXPECT(network.getCurrentQueueSize() > oldSize); + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 1 * kj::SECONDS); + oldSize = network.getCurrentQueueSize(); + auto request2 = client.bazRequest(); initTestMessage(request2.initS()); auto promise2 = request2.send(); + KJ_EXPECT(network.getCurrentQueueCount() == 3); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); + KJ_EXPECT(network.getCurrentQueueSize() > oldSize); + oldSize = network.getCurrentQueueSize(); + + clock.increment(1 * kj::SECONDS); + bool barFailed = false; auto request3 = client.barRequest(); auto promise3 = request3.send().then( @@ -129,6 +186,13 @@ TEST(TwoPartyNetwork, Basic) { EXPECT_EQ(0, callCount); + KJ_EXPECT(network.getCurrentQueueCount() == 4); + KJ_EXPECT(network.getCurrentQueueSize() % sizeof(word) == 0); + KJ_EXPECT(network.getCurrentQueueSize() > oldSize); + // Oldest message is now 2 seconds old + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 2 * kj::SECONDS); + oldSize = network.getCurrentQueueSize(); + auto response1 = promise1.wait(ioContext.waitScope); EXPECT_EQ("foo", response1.getX()); @@ -139,6 +203,32 @@ TEST(TwoPartyNetwork, Basic) { EXPECT_EQ(2, callCount); EXPECT_TRUE(barFailed); + + // There's still a `Finish` message queued. + KJ_EXPECT(network.getCurrentQueueCount() > 0); + KJ_EXPECT(network.getCurrentQueueSize() > 0); + // Oldest message was sent, next oldest should be 0 seconds old since we haven't incremented + // the clock yet. + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 0 * kj::SECONDS); + + // Let any I/O finish. + kj::Promise(kj::NEVER_DONE).poll(ioContext.waitScope); + + // Now nothing is queued. + KJ_EXPECT(network.getCurrentQueueCount() == 0); + KJ_EXPECT(network.getCurrentQueueSize() == 0); + + // Ensure that sending a message after not sending one for some time + // doesn't return incorrect waitTime statistics. + clock.increment(10 * kj::SECONDS); + + auto request4 = client.fooRequest(); + request4.setI(123); + request4.setJ(true); + auto promise4 = request4.send(); + + KJ_EXPECT(network.getCurrentQueueCount() == 1); + KJ_EXPECT(network.getOutgoingMessageWaitTime() == 0 * kj::SECONDS); } TEST(TwoPartyNetwork, Pipelining) { @@ -193,6 +283,10 @@ TEST(TwoPartyNetwork, Pipelining) { EXPECT_FALSE(disconnected); // What if we disconnect? + // TODO(cleanup): This is kind of cheating, we are shutting down the underlying socket to + // simulate a disconnect, but it's weird to pull the rug out from under our VatNetwork like + // this and it causes a bit of a race between write failures and read failures. This part of + // the test should maybe be restructured. serverThread.pipe->shutdownWrite(); // The other side should also disconnect. @@ -214,8 +308,19 @@ TEST(TwoPartyNetwork, Pipelining) { .castAs().graultRequest(); auto pipelinePromise2 = pipelineRequest2.send(); - EXPECT_ANY_THROW(pipelinePromise.wait(ioContext.waitScope)); - EXPECT_ANY_THROW(pipelinePromise2.wait(ioContext.waitScope)); + pipelinePromise.then([](auto) { + KJ_FAIL_EXPECT("should have thrown"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getType() == kj::Exception::Type::DISCONNECTED); + // I wish we could test stack traces somehow... oh well. + }).wait(ioContext.waitScope); + + pipelinePromise2.then([](auto) { + KJ_FAIL_EXPECT("should have thrown"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getType() == kj::Exception::Type::DISCONNECTED); + // I wish we could test stack traces somehow... oh well. + }).wait(ioContext.waitScope); EXPECT_EQ(3, callCount); EXPECT_EQ(1, reverseCallCount); @@ -246,7 +351,7 @@ TEST(TwoPartyNetwork, Release) { // There once was a bug where the last outgoing message (and any capabilities attached) would // not get cleaned up (until a new message was sent). This appeared to be a bug in Release, - // becaues if a client received a message and then released a capability from it but then did + // because if a client received a message and then released a capability from it but then did // not make any further calls, then the capability would not be released because the message // introducing it remained the last server -> client message (because a "Release" message has // no reply). Here we are explicitly trying to catch this bug. This proves tricky, because when @@ -301,8 +406,10 @@ TEST(TwoPartyNetwork, Abort) { msg->send(); } - auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope)); - EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs().which()); + { + auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope)); + EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs().which()); + } EXPECT_TRUE(conn->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr); } @@ -352,12 +459,12 @@ TEST(TwoPartyNetwork, HugeMessage) { auto req = client.methodWithDefaultsRequest(); req.initA(100000000); // 100 MB - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than the single-message size limit", + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than our single-message size limit", req.send().ignoreResult().wait(ioContext.waitScope)); } // Oversized response fails. - KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than the single-message size limit", + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than our single-message size limit", client.getEnormousStringRequest().send().ignoreResult().wait(ioContext.waitScope)); // Connection is still up. @@ -417,6 +524,748 @@ TEST(TwoPartyNetwork, BootstrapFactory) { EXPECT_TRUE(bootstrapFactory.called); } +// ======================================================================================= + +#if !_WIN32 && !__CYGWIN__ // Windows and Cygwin don't support SCM_RIGHTS. +KJ_TEST("send FD over RPC") { + auto io = kj::setupAsyncIo(); + + int callCount = 0; + int handleCount = 0; + TwoPartyServer server(kj::heap(callCount, handleCount)); + auto pipe = io.provider->newCapabilityPipe(); + server.accept(kj::mv(pipe.ends[0]), 2); + TwoPartyClient client(*pipe.ends[1], 2); + + auto cap = client.bootstrap().castAs(); + + int pipeFds[2]; + KJ_SYSCALL(kj::miniposix::pipe(pipeFds)); + kj::AutoCloseFd in1(pipeFds[0]); + kj::AutoCloseFd out1(pipeFds[1]); + KJ_SYSCALL(kj::miniposix::pipe(pipeFds)); + kj::AutoCloseFd in2(pipeFds[0]); + kj::AutoCloseFd out2(pipeFds[1]); + + capnp::RemotePromise promise = nullptr; + { + auto req = cap.writeToFdRequest(); + + // Order reversal intentional, just trying to mix things up. + req.setFdCap1(kj::heap(kj::mv(out2))); + req.setFdCap2(kj::heap(kj::mv(out1))); + + promise = req.send(); + } + + int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope)); + KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope) + == "baz"); + + { + auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope + auto response = promise2.wait(io.waitScope); + KJ_EXPECT(response.getSecondFdPresent()); + } + + KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope) + == "bar"); + KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope) + == "foo"); +} + +KJ_TEST("FD per message limit") { + auto io = kj::setupAsyncIo(); + + int callCount = 0; + int handleCount = 0; + TwoPartyServer server(kj::heap(callCount, handleCount)); + auto pipe = io.provider->newCapabilityPipe(); + server.accept(kj::mv(pipe.ends[0]), 1); + TwoPartyClient client(*pipe.ends[1], 1); + + auto cap = client.bootstrap().castAs(); + + int pipeFds[2]; + KJ_SYSCALL(kj::miniposix::pipe(pipeFds)); + kj::AutoCloseFd in1(pipeFds[0]); + kj::AutoCloseFd out1(pipeFds[1]); + KJ_SYSCALL(kj::miniposix::pipe(pipeFds)); + kj::AutoCloseFd in2(pipeFds[0]); + kj::AutoCloseFd out2(pipeFds[1]); + + capnp::RemotePromise promise = nullptr; + { + auto req = cap.writeToFdRequest(); + + // Order reversal intentional, just trying to mix things up. + req.setFdCap1(kj::heap(kj::mv(out2))); + req.setFdCap2(kj::heap(kj::mv(out1))); + + promise = req.send(); + } + + int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope)); + KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope) + == "baz"); + + { + auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope + auto response = promise2.wait(io.waitScope); + KJ_EXPECT(!response.getSecondFdPresent()); + } + + KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope) + == ""); + KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope) + == "foo"); +} +#endif // !_WIN32 && !__CYGWIN__ + +// ======================================================================================= + +class MockSndbufStream final: public kj::AsyncIoStream { +public: + MockSndbufStream(kj::Own inner, size_t& window, size_t& written) + : inner(kj::mv(inner)), window(window), written(written) {} + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->read(buffer, minBytes, maxBytes); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->tryRead(buffer, minBytes, maxBytes); + } + kj::Maybe tryGetLength() override { + return inner->tryGetLength(); + } + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return inner->pumpTo(output, amount); + } + kj::Promise write(const void* buffer, size_t size) override { + written += size; + return inner->write(buffer, size); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + for (auto& piece: pieces) written += piece.size(); + return inner->write(pieces); + } + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount) override { + return inner->tryPumpFrom(input, amount); + } + kj::Promise whenWriteDisconnected() override { return inner->whenWriteDisconnected(); } + void shutdownWrite() override { return inner->shutdownWrite(); } + void abortRead() override { return inner->abortRead(); } + + void getsockopt(int level, int option, void* value, uint* length) override { + if (level == SOL_SOCKET && option == SO_SNDBUF) { + KJ_ASSERT(*length == sizeof(int)); + *reinterpret_cast(value) = window; + } else { + KJ_UNIMPLEMENTED("not implemented for test", level, option); + } + } + +private: + kj::Own inner; + size_t& window; + size_t& written; +}; + +KJ_TEST("Streaming over RPC") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + + size_t window = 1024; + size_t clientWritten = 0; + size_t serverWritten = 0; + + pipe.ends[0] = kj::heap(kj::mv(pipe.ends[0]), window, clientWritten); + pipe.ends[1] = kj::heap(kj::mv(pipe.ends[1]), window, serverWritten); + + auto ownServer = kj::heap(); + auto& server = *ownServer; + test::TestStreaming::Client serverCap(kj::mv(ownServer)); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], serverCap, rpc::twoparty::Side::SERVER); + + auto cap = tpClient.bootstrap().castAs(); + + // Send stream requests until we can't anymore. + kj::Promise promise = kj::READY_NOW; + uint count = 0; + while (promise.poll(waitScope)) { + promise.wait(waitScope); + + auto req = cap.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + } + + // We should have sent... several. + KJ_EXPECT(count > 5); + + // Now, cause calls to finish server-side one-at-a-time and check that this causes the client + // side to be willing to send more. + uint countReceived = 0; + for (uint i = 0; i < 50; i++) { + KJ_EXPECT(server.iSum == ++countReceived); + server.iSum = 0; + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + + auto req = cap.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + if (promise.poll(waitScope)) { + // We'll see a couple of instances where completing one request frees up space to make two + // more. This is because the first few requests we made are a little bit larger than the + // rest due to being pipelined on the bootstrap. Once the bootstrap resolves, the request + // size gets smaller. + promise.wait(waitScope); + req = cap.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + + // We definitely shouldn't have freed up stream space for more than two additional requests! + KJ_ASSERT(!promise.poll(waitScope)); + } + } +} + +KJ_TEST("Streaming over a chain of local and remote RPC calls") { + // This test verifies that a local RPC call that eventually resolves to a remote RPC call will + // still support streaming calls over the remote connection. + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + // Set up a local server that will eventually delegate requests to a remote server. + auto localPaf = kj::newPromiseAndFulfiller(); + test::TestStreaming::Client promisedClient(kj::mv(localPaf.promise)); + + uint count = 0; + auto req = promisedClient.doStreamIRequest(); + req.setI(++count); + auto promise = req.send(); + + // Expect streaming request to be blocked on promised client. + KJ_EXPECT(!promise.poll(waitScope)); + + // Set up a remote server with a flow control window for streaming. + auto pipe = kj::newTwoWayPipe(); + + size_t window = 1024; + size_t clientWritten = 0; + size_t serverWritten = 0; + + pipe.ends[0] = kj::heap(kj::mv(pipe.ends[0]), window, clientWritten); + pipe.ends[1] = kj::heap(kj::mv(pipe.ends[1]), window, serverWritten); + + auto remotePaf = kj::newPromiseAndFulfiller(); + test::TestStreaming::Client serverCap(kj::mv(remotePaf.promise)); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(serverCap), rpc::twoparty::Side::SERVER); + + auto clientCap = tpClient.bootstrap().castAs(); + + // Expect streaming request to be unblocked by fulfilling promised client with remote server. + localPaf.fulfiller->fulfill(kj::mv(clientCap)); + KJ_EXPECT(promise.poll(waitScope)); + + // Send stream requests until we can't anymore. + while (promise.poll(waitScope)) { + promise.wait(waitScope); + + auto req = promisedClient.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + KJ_ASSERT(count < 1000); + } + + // Expect several stream requests to have fit in the flow control window. + KJ_EXPECT(count > 5); + + auto finishReq = promisedClient.finishStreamRequest(); + auto finishPromise = finishReq.send(); + KJ_EXPECT(!finishPromise.poll(waitScope)); + + // Finish calls on server + auto ownServer = kj::heap(); + auto& server = *ownServer; + remotePaf.fulfiller->fulfill(kj::mv(ownServer)); + KJ_EXPECT(!promise.poll(waitScope)); + + uint countReceived = 0; + for (uint i = 0; i < count; i++) { + KJ_EXPECT(server.iSum == ++countReceived); + server.iSum = 0; + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + if (i < count - 1) { + KJ_EXPECT(!finishPromise.poll(waitScope)); + } + } + + KJ_EXPECT(finishPromise.poll(waitScope)); + finishPromise.wait(waitScope); +} + +KJ_TEST("Streaming over RPC then unwrap with CapabilitySet") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + + CapabilityServerSet capSet; + + auto ownServer = kj::heap(); + auto& server = *ownServer; + auto serverCap = capSet.add(kj::mv(ownServer)); + + auto paf = kj::newPromiseAndFulfiller(); + + TwoPartyClient tpClient(*pipe.ends[0], serverCap); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(paf.promise), rpc::twoparty::Side::SERVER); + + auto clientCap = tpClient.bootstrap().castAs(); + + // Send stream requests until we can't anymore. + kj::Promise promise = kj::READY_NOW; + uint count = 0; + while (promise.poll(waitScope)) { + promise.wait(waitScope); + + auto req = clientCap.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + } + + // We should have sent... several. + KJ_EXPECT(count > 10); + + // Now try to unwrap. + auto unwrapPromise = capSet.getLocalServer(clientCap); + + // It won't work yet, obviously, because we haven't resolved the promise. + KJ_EXPECT(!unwrapPromise.poll(waitScope)); + + // So do that. + paf.fulfiller->fulfill(tpServer.bootstrap().castAs()); + clientCap.whenResolved().wait(waitScope); + + // But the unwrap still doesn't resolve because streaming requests are queued up. + KJ_EXPECT(!unwrapPromise.poll(waitScope)); + + // OK, let's resolve a streaming request. + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + + // All of our call promises have now completed from the client's perspective. + promise.wait(waitScope); + + // But we still can't unwrap, because calls are queued server-side. + KJ_EXPECT(!unwrapPromise.poll(waitScope)); + + // Let's even make one more call now. But this is actually a local call since the promise + // resolved. + { + auto req = clientCap.doStreamIRequest(); + req.setI(++count); + promise = req.send(); + } + + // Because it's a local call, it doesn't resolve early. The window is no longer in effect. + KJ_EXPECT(!promise.poll(waitScope)); + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + + // Our unwrap promise is also still not resolved. + KJ_EXPECT(!unwrapPromise.poll(waitScope)); + + // Close out stream calls until it does resolve! + while (!unwrapPromise.poll(waitScope)) { + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + } + + // Now we can unwrap! + KJ_EXPECT(&KJ_ASSERT_NONNULL(unwrapPromise.wait(waitScope)) == &server); + + // But our last stream call still isn't done. + KJ_EXPECT(!promise.poll(waitScope)); + + // Finish it. + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + promise.wait(waitScope); +} + +KJ_TEST("promise cap resolves between starting request and sending it") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + auto pipe = kj::newTwoWayPipe(); + + // Client exports TestCallOrderImpl as its bootstrap. + TwoPartyClient client(*pipe.ends[0], kj::heap(), rpc::twoparty::Side::CLIENT); + + // Server exports a promise, which will later resolve to loop back to the capability the client + // exported. + auto paf = kj::newPromiseAndFulfiller(); + TwoPartyClient server(*pipe.ends[1], kj::mv(paf.promise), rpc::twoparty::Side::SERVER); + + // Create a request but don't send it yet. + auto cap = client.bootstrap().castAs(); + auto req1 = cap.getCallSequenceRequest(); + + // Fulfill the promise now so that the server's bootstrap loops back to the client's bootstrap. + paf.fulfiller->fulfill(server.bootstrap()); + cap.whenResolved().wait(waitScope); + + // Send the request we created earlier, and also create and send a second request. + auto promise1 = req1.send(); + auto promise2 = cap.getCallSequenceRequest().send(); + + // They should arrive in order of send()s. + auto n1 = promise1.wait(waitScope).getN(); + KJ_EXPECT(n1 == 0, n1); + auto n2 = promise2.wait(waitScope).getN(); + KJ_EXPECT(n2 == 1, n2); +} + +KJ_TEST("write error propagates to read error") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + auto frontPipe = kj::newTwoWayPipe(); + auto backPipe = kj::newTwoWayPipe(); + + TwoPartyClient client(*frontPipe.ends[0]); + + int callCount; + TwoPartyClient server(*backPipe.ends[1], kj::heap(callCount), + rpc::twoparty::Side::SERVER); + + auto pumpUpTask = frontPipe.ends[1]->pumpTo(*backPipe.ends[0]); + auto pumpDownTask = backPipe.ends[0]->pumpTo(*frontPipe.ends[1]); + + auto cap = client.bootstrap().castAs(); + + // Make sure the connections work. + { + auto req = cap.fooRequest(); + req.setI(123); + req.setJ(true); + auto resp = req.send().wait(waitScope); + EXPECT_EQ("foo", resp.getX()); + } + + // Disconnect upstream task in such a way that future writes on the client will fail, but the + // server doesn't notice the disconnect and so won't react. + pumpUpTask = nullptr; + frontPipe.ends[1]->abortRead(); // causes write() on ends[0] to fail in the future + + { + auto req = cap.fooRequest(); + req.setI(123); + req.setJ(true); + auto promise = req.send().then([](auto) { + KJ_FAIL_EXPECT("expected exception"); + }, [](kj::Exception&& e) { + KJ_ASSERT(e.getDescription() == "abortRead() has been called"); + }); + + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } +} + +class TestStreamingCancellationBug final: public test::TestStreaming::Server { +public: + uint iSum = 0; + kj::Maybe>> fulfiller; + + kj::Promise doStreamI(DoStreamIContext context) override { + auto paf = kj::newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + return paf.promise.then([this,context]() mutable { + // Don't count the sum until here so we actually detect if the call is canceled. + iSum += context.getParams().getI(); + }); + } + + kj::Promise finishStream(FinishStreamContext context) override { + auto results = context.getResults(); + results.setTotalI(iSum); + return kj::READY_NOW; + } +}; + +KJ_TEST("Streaming over RPC no premature cancellation when client dropped") { + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + + auto ownServer = kj::heap(); + auto& server = *ownServer; + test::TestStreaming::Client serverCap = kj::mv(ownServer); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(serverCap), rpc::twoparty::Side::SERVER); + + auto client = tpClient.bootstrap().castAs(); + + kj::Promise promise1 = nullptr, promise2 = nullptr; + + { + auto req = client.doStreamIRequest(); + req.setI(123); + promise1 = req.send(); + } + { + auto req = client.doStreamIRequest(); + req.setI(456); + promise2 = req.send(); + } + + auto finishPromise = client.finishStreamRequest().send(); + + KJ_EXPECT(server.iSum == 0); + + // Drop the client. This shouldn't cause a problem for the already-running RPCs. + { auto drop = kj::mv(client); } + + while (!finishPromise.poll(waitScope)) { + KJ_ASSERT_NONNULL(server.fulfiller)->fulfill(); + } + + finishPromise.wait(waitScope); + KJ_EXPECT(server.iSum == 579); +} + +KJ_TEST("Dropping capability during call doesn't destroy server") { + class TestInterfaceImpl final: public test::TestInterface::Server { + // An object which increments a count in the constructor and decrements it in the destructor, + // to detect when it is destroyed. The object's foo() method also sets a fulfiller to use to + // cause the method to complete. + public: + TestInterfaceImpl(uint& count, kj::Maybe>>& fulfillerSlot) + : count(count), fulfillerSlot(fulfillerSlot) { ++count; } + ~TestInterfaceImpl() noexcept(false) { --count; } + + kj::Promise foo(FooContext context) override { + auto paf = kj::newPromiseAndFulfiller(); + fulfillerSlot = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + private: + uint& count; + kj::Maybe>>& fulfillerSlot; + }; + + class TestBootstrapImpl final: public test::TestMoreStuff::Server { + // Bootstrap object which just vends instances of `TestInterfaceImpl`. + public: + TestBootstrapImpl(uint& count, kj::Maybe>>& fulfillerSlot) + : count(count), fulfillerSlot(fulfillerSlot) {} + + kj::Promise getHeld(GetHeldContext context) override { + context.initResults().setCap(kj::heap(count, fulfillerSlot)); + return kj::READY_NOW; + } + + private: + uint& count; + kj::Maybe>>& fulfillerSlot; + }; + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + auto pipe = kj::newTwoWayPipe(); + + uint count = 0; + kj::Maybe>> fulfillerSlot; + test::TestMoreStuff::Client bootstrap = kj::heap(count, fulfillerSlot); + + TwoPartyClient tpClient(*pipe.ends[0]); + TwoPartyClient tpServer(*pipe.ends[1], kj::mv(bootstrap), rpc::twoparty::Side::SERVER); + + auto cap = tpClient.bootstrap().castAs().getHeldRequest().send().getCap(); + + waitScope.poll(); + auto promise = cap.fooRequest().send(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(fulfillerSlot != nullptr); + + // Dropping the capability should not destroy the server as long as the call is still + // outstanding. + {auto drop = kj::mv(cap);} + + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(count == 1); + + // Cancelling the call still should not destroy the server because the call is not marked to + // allow cancellation. So the call should keep running. + {auto drop = kj::mv(promise);} + + waitScope.poll(); + KJ_EXPECT(count == 1); + + // When the call completes, only then should the server be dropped. + KJ_ASSERT_NONNULL(fulfillerSlot)->fulfill(); + + waitScope.poll(); + KJ_EXPECT(count == 0); +} + +RemotePromise getCallSequence( + test::TestCallOrder::Client& client, uint expected) { + auto req = client.getCallSequenceRequest(); + req.setExpected(expected); + return req.send(); +} + +KJ_TEST("Two-hop embargo") { + // Copied from `TEST(Rpc, Embargo)` in `rpc-test.c++`, adapted to involve a two-hop path through + // a proxy. This tests what happens when disembargoes on multiple hops are happening in parallel. + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0, handleCount = 0; + + // Set up two two-party RPC connections in series. The middle node just proxies requests through. + auto frontPipe = kj::newTwoWayPipe(); + auto backPipe = kj::newTwoWayPipe(); + TwoPartyClient tpClient(*frontPipe.ends[0]); + TwoPartyClient proxyBack(*backPipe.ends[0]); + TwoPartyClient proxyFront(*frontPipe.ends[1], proxyBack.bootstrap(), rpc::twoparty::Side::SERVER); + TwoPartyClient tpServer(*backPipe.ends[1], kj::heap(callCount, handleCount), + rpc::twoparty::Side::SERVER); + + // Perform some logic that does a bunch of promise pipelining, including passing a capability + // from the client to the server and back to the client, and making promise-pipelined calls on + // that capability. This should exercise the promise resolution and disembargo code. + auto client = tpClient.bootstrap().castAs(); + + auto cap = test::TestCallOrder::Client(kj::heap()); + + auto earlyCall = client.getCallSequenceRequest().send(); + + auto echoRequest = client.echoRequest(); + echoRequest.setCap(cap); + auto echo = echoRequest.send(); + + auto pipeline = echo.getCap(); + + auto call0 = getCallSequence(pipeline, 0); + auto call1 = getCallSequence(pipeline, 1); + + earlyCall.wait(waitScope); + + auto call2 = getCallSequence(pipeline, 2); + + auto resolved = echo.wait(waitScope).getCap(); + + auto call3 = getCallSequence(pipeline, 3); + auto call4 = getCallSequence(pipeline, 4); + auto call5 = getCallSequence(pipeline, 5); + + EXPECT_EQ(0, call0.wait(waitScope).getN()); + EXPECT_EQ(1, call1.wait(waitScope).getN()); + EXPECT_EQ(2, call2.wait(waitScope).getN()); + EXPECT_EQ(3, call3.wait(waitScope).getN()); + EXPECT_EQ(4, call4.wait(waitScope).getN()); + EXPECT_EQ(5, call5.wait(waitScope).getN()); +} + +class TestCallOrderImplAsPromise final: public test::TestCallOrder::Server { + // This is an implementation of TestCallOrder that presents itself as a promise by implementing + // `shortenPath()`, although it never resolves to anything (`shortenPath()` never completes). + // This tests deeper code paths in promise resolution and embargo code. +public: + template + TestCallOrderImplAsPromise(Params&&... params): inner(kj::fwd(params)...) {} + + kj::Promise getCallSequence(GetCallSequenceContext context) override { + return inner.getCallSequence(context); + } + + kj::Maybe> shortenPath() override { + // Make this object appear to be a promise. + return kj::Promise(kj::NEVER_DONE); + } + +private: + TestCallOrderImpl inner; +}; + +KJ_TEST("Two-hop embargo") { + // Same as above, but the eventual resolution is itself a promise. This verifies that + // handleDisembargo() only waits for the target to resolve back to the capability that the + // disembargo should reflect to, but not beyond that. + + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + int callCount = 0, handleCount = 0; + + // Set up two two-party RPC connections in series. The middle node just proxies requests through. + auto frontPipe = kj::newTwoWayPipe(); + auto backPipe = kj::newTwoWayPipe(); + TwoPartyClient tpClient(*frontPipe.ends[0]); + TwoPartyClient proxyBack(*backPipe.ends[0]); + TwoPartyClient proxyFront(*frontPipe.ends[1], proxyBack.bootstrap(), rpc::twoparty::Side::SERVER); + TwoPartyClient tpServer(*backPipe.ends[1], kj::heap(callCount, handleCount), + rpc::twoparty::Side::SERVER); + + // Perform some logic that does a bunch of promise pipelining, including passing a capability + // from the client to the server and back to the client, and making promise-pipelined calls on + // that capability. This should exercise the promise resolution and disembargo code. + auto client = tpClient.bootstrap().castAs(); + + auto cap = test::TestCallOrder::Client(kj::heap()); + + auto earlyCall = client.getCallSequenceRequest().send(); + + auto echoRequest = client.echoRequest(); + echoRequest.setCap(cap); + auto echo = echoRequest.send(); + + auto pipeline = echo.getCap(); + + auto call0 = getCallSequence(pipeline, 0); + auto call1 = getCallSequence(pipeline, 1); + + earlyCall.wait(waitScope); + + auto call2 = getCallSequence(pipeline, 2); + + auto resolved = echo.wait(waitScope).getCap(); + + auto call3 = getCallSequence(pipeline, 3); + auto call4 = getCallSequence(pipeline, 4); + auto call5 = getCallSequence(pipeline, 5); + + EXPECT_EQ(0, call0.wait(waitScope).getN()); + EXPECT_EQ(1, call1.wait(waitScope).getN()); + EXPECT_EQ(2, call2.wait(waitScope).getN()); + EXPECT_EQ(3, call3.wait(waitScope).getN()); + EXPECT_EQ(4, call4.wait(waitScope).getN()); + EXPECT_EQ(5, call5.wait(waitScope).getN()); +} + } // namespace } // namespace _ } // namespace capnp diff --git a/c++/src/capnp/rpc-twoparty.c++ b/c++/src/capnp/rpc-twoparty.c++ index 751904df59..09c84bbb55 100644 --- a/c++/src/capnp/rpc-twoparty.c++ +++ b/c++/src/capnp/rpc-twoparty.c++ @@ -22,13 +22,25 @@ #include "rpc-twoparty.h" #include "serialize-async.h" #include +#include namespace capnp { -TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, - ReaderOptions receiveOptions) - : stream(stream), side(side), peerVatId(4), - receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) { +TwoPartyVatNetwork::TwoPartyVatNetwork( + kj::OneOf>&& stream, + uint maxFdsPerMessage, + rpc::twoparty::Side side, + ReaderOptions receiveOptions, + const kj::MonotonicClock& clock) + + : stream(kj::mv(stream)), + maxFdsPerMessage(maxFdsPerMessage), + side(side), + peerVatId(4), + receiveOptions(receiveOptions), + previousWrite(kj::READY_NOW), + clock(clock), + currentOutgoingMessageSendTime(clock.now()) { peerVatId.initRoot().setSide( side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER : rpc::twoparty::Side::CLIENT); @@ -38,6 +50,49 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty: disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller); } +TwoPartyVatNetwork::TwoPartyVatNetwork(capnp::MessageStream& stream, + rpc::twoparty::Side side, ReaderOptions receiveOptions, + const kj::MonotonicClock& clock) + : TwoPartyVatNetwork(stream, 0, side, receiveOptions, clock) {} + +TwoPartyVatNetwork::TwoPartyVatNetwork( + capnp::MessageStream& stream, + uint maxFdsPerMessage, + rpc::twoparty::Side side, + ReaderOptions receiveOptions, + const kj::MonotonicClock& clock) + : TwoPartyVatNetwork(&stream, maxFdsPerMessage, side, receiveOptions, clock) {} + +TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, + ReaderOptions receiveOptions, + const kj::MonotonicClock& clock) + : TwoPartyVatNetwork( + kj::Own(kj::heap( + stream, IncomingRpcMessage::getShortLivedCallback())), + 0, side, receiveOptions, clock) {} + +TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage, + rpc::twoparty::Side side, ReaderOptions receiveOptions, + const kj::MonotonicClock& clock) + : TwoPartyVatNetwork( + kj::Own(kj::heap( + stream, IncomingRpcMessage::getShortLivedCallback())), + maxFdsPerMessage, side, receiveOptions, clock) {} + +TwoPartyVatNetwork::~TwoPartyVatNetwork() noexcept(false) {}; + +MessageStream& TwoPartyVatNetwork::getStream() { + KJ_SWITCH_ONEOF(stream) { + KJ_CASE_ONEOF(s, MessageStream*) { + return *s; + } + KJ_CASE_ONEOF(s, kj::Own) { + return *s; + } + } + KJ_UNREACHABLE; +} + void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const { if (--refcount == 0) { fulfiller->fulfill(); @@ -81,24 +136,72 @@ public: return message.getRoot(); } + void setFds(kj::Array fds) override { + if (network.maxFdsPerMessage > 0) { + this->fds = kj::mv(fds); + } + } + void send() override { size_t size = 0; for (auto& segment: message.getSegmentsForOutput()) { size += segment.size(); } - KJ_REQUIRE(size < ReaderOptions().traversalLimitInWords, size, - "Trying to send Cap'n Proto message larger than the single-message size limit. The " - "other side probably won't accept it and would abort the connection, so I won't " - "send it.") { + KJ_REQUIRE(size < network.receiveOptions.traversalLimitInWords, size, + "Trying to send Cap'n Proto message larger than our single-message size limit. The " + "other side probably won't accept it (assuming its traversalLimitInWords matches " + "ours) and would abort the connection, so I won't send it.") { return; } - network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down") - .then([&]() { - // Note that if the write fails, all further writes will be skipped due to the exception. - // We never actually handle this exception because we assume the read end will fail as well - // and it's cleaner to handle the failure there. - return writeMessage(network.stream, message); + auto sendTime = network.clock.now(); + if (network.queuedMessages.size() == 0) { + // Optimistically set sendTime when there's no messages in the queue. Without this, sending + // a message after a long delay could cause getOutgoingMessageWaitTime() to return excessively + // long wait times if it is called during the time period after send() is called, + // but before the write occurs, as we increment currentQueueCount synchronously, but + // asynchronously update currentOutgoingMessageSendTime. + network.currentOutgoingMessageSendTime = sendTime; + } + + // Instead of sending each new message as soon as possible, we attempt to batch together small + // messages by delaying when we send them using evalLast. This allows us to group together + // related small messages, reducing the number of syscalls we make. + auto& previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down"); + bool alreadyPendingSend = !network.queuedMessages.empty(); + network.currentQueueSize += message.sizeInWords() * sizeof(word); + network.queuedMessages.add(kj::addRef(*this)); + if (alreadyPendingSend) { + // The first send sets up an evalLast that will clear out pendingMessages when it's sent. + // If pendingMessages is non-empty, then there must already be a callback waiting to send + // them. + return; + } + + // On the other hand, if pendingMessages was empty, then we should set up the delayed write. + network.previousWrite = previousWrite.then([this, sendTime]() { + return kj::evalLast([this, sendTime]() -> kj::Promise { + network.currentOutgoingMessageSendTime = sendTime; + // Swap out the connection's pending messages and write all of them together. + auto ownMessages = kj::mv(network.queuedMessages); + network.currentQueueSize = 0; + auto messages = + kj::heapArray(ownMessages.size()); + for (int i = 0; i < messages.size(); ++i) { + messages[i].segments = ownMessages[i]->message.getSegmentsForOutput(); + messages[i].fds = ownMessages[i]->fds; + } + return network.getStream().writeMessages(messages).attach(kj::mv(ownMessages), kj::mv(messages)); + }).catch_([this](kj::Exception&& e) { + // Since no one checks write failures, we need to propagate them into read failures, + // otherwise we might get stuck sending all messages into a black hole and wondering why + // the peer never replies. + network.readCancelReason = kj::cp(e); + if (!network.readCanceler.isEmpty()) { + network.readCanceler.cancel(kj::cp(e)); + } + kj::throwRecoverableException(kj::mv(e)); + }); }).attach(kj::addRef(*this)) // Note that it's important that the eagerlyEvaluate() come *after* the attach() because // otherwise the message (and any capabilities in it) will not be released until a new @@ -106,23 +209,91 @@ public: .eagerlyEvaluate(nullptr); } + size_t sizeInWords() override { + return message.sizeInWords(); + } + private: TwoPartyVatNetwork& network; MallocMessageBuilder message; + kj::Array fds; }; +kj::Duration TwoPartyVatNetwork::getOutgoingMessageWaitTime() { + if (queuedMessages.size() > 0) { + return clock.now() - currentOutgoingMessageSendTime; + } else { + return 0 * kj::SECONDS; + } +} + class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage { public: IncomingMessageImpl(kj::Own message): message(kj::mv(message)) {} + IncomingMessageImpl(MessageReaderAndFds init, kj::Array fdSpace) + : message(kj::mv(init.reader)), + fdSpace(kj::mv(fdSpace)), + fds(init.fds) { + KJ_DASSERT(this->fds.begin() == this->fdSpace.begin()); + } + AnyPointer::Reader getBody() override { return message->getRoot(); } + kj::ArrayPtr getAttachedFds() override { + return fds; + } + + size_t sizeInWords() override { + return message->sizeInWords(); + } + private: kj::Own message; + kj::Array fdSpace; + kj::ArrayPtr fds; }; +kj::Own TwoPartyVatNetwork::newStream() { + return RpcFlowController::newVariableWindowController(*this); +} + +size_t TwoPartyVatNetwork::getWindow() { + // The socket's send buffer size -- as returned by getsockopt(SO_SNDBUF) -- tells us how much + // data the kernel itself is willing to buffer. The kernel will increase the send buffer size if + // needed to fill the connection's congestion window. So we can cheat and use it as our stream + // window, too, to make sure we saturate said congestion window. + // + // TODO(perf): Unfortunately, this hack breaks down in the presence of proxying. What we really + // want is the window all the way to the endpoint, which could cross multiple connections. The + // first-hop window could be either too big or too small: it's too big if the first hop has + // much higher bandwidth than the full path (causing buffering at the bottleneck), and it's + // too small if the first hop has much lower latency than the full path (causing not enough + // data to be sent to saturate the connection). To handle this, we could either: + // 1. Have proxies be aware of streaming, by flagging streaming calls in the RPC protocol. The + // proxies would then handle backpressure at each hop. This seems simple to implement but + // requires base RPC protocol changes and might require thinking carefully about e-ordering + // implications. Also, it only fixes underutilization; it does not fix buffer bloat. + // 2. Do our own BBR-like computation, where the client measures the end-to-end latency and + // bandwidth based on the observed sends and returns, and then compute the window based on + // that. This seems complicated, but avoids the need for any changes to the RPC protocol. + // In theory it solves both underutilization and buffer bloat. Note that this approach would + // require the RPC system to use a clock, which feels dirty and adds non-determinism. + + if (solSndbufUnimplemented) { + return RpcFlowController::DEFAULT_WINDOW_SIZE; + } else { + KJ_IF_MAYBE(bufSize, getStream().getSendBufferSize()) { + return *bufSize; + } else { + solSndbufUnimplemented = true; + return RpcFlowController::DEFAULT_WINDOW_SIZE; + } + } +} + rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() { return peerVatId.getRoot(); } @@ -132,12 +303,27 @@ kj::Own TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg } kj::Promise>> TwoPartyVatNetwork::receiveIncomingMessage() { - return kj::evalLater([&]() { - return tryReadMessage(stream, receiveOptions) - .then([&](kj::Maybe>&& message) - -> kj::Maybe> { - KJ_IF_MAYBE(m, message) { - return kj::Own(kj::heap(kj::mv(*m))); + return kj::evalLater([this]() -> kj::Promise>> { + KJ_IF_MAYBE(e, readCancelReason) { + // A previous write failed; propagate the failure to reads, too. + return kj::cp(*e); + } + + kj::Array fdSpace = nullptr; + if(maxFdsPerMessage > 0) { + fdSpace = kj::heapArray(maxFdsPerMessage); + } + auto promise = readCanceler.wrap(getStream().tryReadMessage(fdSpace, receiveOptions)); + return promise.then([fdSpace = kj::mv(fdSpace)] + (kj::Maybe&& messageAndFds) mutable + -> kj::Maybe> { + KJ_IF_MAYBE(m, messageAndFds) { + if (m->fds.size() > 0) { + return kj::Own( + kj::heap(kj::mv(*m), kj::mv(fdSpace))); + } else { + return kj::Own(kj::heap(kj::mv(m->reader))); + } } else { return nullptr; } @@ -147,7 +333,7 @@ kj::Promise>> TwoPartyVatNetwork::receiveI kj::Promise TwoPartyVatNetwork::shutdown() { kj::Promise result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() { - stream.shutdownWrite(); + return getStream().end(); }); previousWrite = nullptr; return kj::mv(result); @@ -155,29 +341,82 @@ kj::Promise TwoPartyVatNetwork::shutdown() { // ======================================================================================= -TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface) - : bootstrapInterface(kj::mv(bootstrapInterface)), tasks(*this) {} +TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface, + kj::Maybe> traceEncoder) + : bootstrapInterface(kj::mv(bootstrapInterface)), + traceEncoder(kj::mv(traceEncoder)), + tasks(*this) {} struct TwoPartyServer::AcceptedConnection { kj::Own connection; TwoPartyVatNetwork network; RpcSystem rpcSystem; - explicit AcceptedConnection(Capability::Client bootstrapInterface, + explicit AcceptedConnection(TwoPartyServer& parent, kj::Own&& connectionParam) : connection(kj::mv(connectionParam)), network(*connection, rpc::twoparty::Side::SERVER), - rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} + rpcSystem(makeRpcServer(network, kj::cp(parent.bootstrapInterface))) { + init(parent); + } + + explicit AcceptedConnection(TwoPartyServer& parent, + kj::Own&& connectionParam, + uint maxFdsPerMessage) + : connection(kj::mv(connectionParam)), + network(kj::downcast(*connection), + maxFdsPerMessage, rpc::twoparty::Side::SERVER), + rpcSystem(makeRpcServer(network, kj::cp(parent.bootstrapInterface))) { + init(parent); + } + + void init(TwoPartyServer& parent) { + KJ_IF_MAYBE(t, parent.traceEncoder) { + rpcSystem.setTraceEncoder([&func = *t](const kj::Exception& e) { + return func(e); + }); + } + } }; void TwoPartyServer::accept(kj::Own&& connection) { - auto connectionState = kj::heap(bootstrapInterface, kj::mv(connection)); + auto connectionState = kj::heap(*this, kj::mv(connection)); // Run the connection until disconnect. auto promise = connectionState->network.onDisconnect(); tasks.add(promise.attach(kj::mv(connectionState))); } +void TwoPartyServer::accept( + kj::Own&& connection, uint maxFdsPerMessage) { + auto connectionState = kj::heap( + *this, kj::mv(connection), maxFdsPerMessage); + + // Run the connection until disconnect. + auto promise = connectionState->network.onDisconnect(); + tasks.add(promise.attach(kj::mv(connectionState))); +} + +kj::Promise TwoPartyServer::accept(kj::AsyncIoStream& connection) { + auto connectionState = kj::heap(*this, + kj::Own(&connection, kj::NullDisposer::instance)); + + // Run the connection until disconnect. + auto promise = connectionState->network.onDisconnect(); + return promise.attach(kj::mv(connectionState)); +} + +kj::Promise TwoPartyServer::accept( + kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) { + auto connectionState = kj::heap(*this, + kj::Own(&connection, kj::NullDisposer::instance), + maxFdsPerMessage); + + // Run the connection until disconnect. + auto promise = connectionState->network.onDisconnect(); + return promise.attach(kj::mv(connectionState)); +} + kj::Promise TwoPartyServer::listen(kj::ConnectionReceiver& listener) { return listener.accept() .then([this,&listener](kj::Own&& connection) mutable { @@ -186,6 +425,15 @@ kj::Promise TwoPartyServer::listen(kj::ConnectionReceiver& listener) { }); } +kj::Promise TwoPartyServer::listenCapStreamReceiver( + kj::ConnectionReceiver& listener, uint maxFdsPerMessage) { + return listener.accept() + .then([this,&listener,maxFdsPerMessage](kj::Own&& connection) mutable { + accept(connection.downcast(), maxFdsPerMessage); + return listenCapStreamReceiver(listener, maxFdsPerMessage); + }); +} + void TwoPartyServer::taskFailed(kj::Exception&& exception) { KJ_LOG(ERROR, exception); } @@ -195,14 +443,26 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection) rpcSystem(makeRpcClient(network)) {} +TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) + : network(connection, maxFdsPerMessage, rpc::twoparty::Side::CLIENT), + rpcSystem(makeRpcClient(network)) {} + TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface, rpc::twoparty::Side side) : network(connection, side), rpcSystem(network, bootstrapInterface) {} +TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage, + Capability::Client bootstrapInterface, + rpc::twoparty::Side side) + : network(connection, maxFdsPerMessage, side), + rpcSystem(network, bootstrapInterface) {} + Capability::Client TwoPartyClient::bootstrap() { - MallocMessageBuilder message(4); + capnp::word scratch[4]; + memset(&scratch, 0, sizeof(scratch)); + capnp::MallocMessageBuilder message(scratch); auto vatId = message.getRoot(); vatId.setSide(network.getSide() == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER @@ -210,4 +470,8 @@ Capability::Client TwoPartyClient::bootstrap() { return rpcSystem.bootstrap(vatId); } +void TwoPartyClient::setTraceEncoder(kj::Function func) { + rpcSystem.setTraceEncoder(kj::mv(func)); +} + } // namespace capnp diff --git a/c++/src/capnp/rpc-twoparty.capnp b/c++/src/capnp/rpc-twoparty.capnp index 0b670e8ac3..5f0e2150e7 100644 --- a/c++/src/capnp/rpc-twoparty.capnp +++ b/c++/src/capnp/rpc-twoparty.capnp @@ -162,8 +162,6 @@ struct JoinResult { # implements the join by waiting for all the `JoinKeyParts` and then performing its own join on # them, then going back and answering all the join requests afterwards. - cap @2 :AnyPointer; + cap @2 :Capability; # One of the JoinResults will have a non-null `cap` which is the joined capability. - # - # TODO(cleanup): Change `AnyPointer` to `Capability` when that is supported. } diff --git a/c++/src/capnp/rpc-twoparty.capnp.c++ b/c++/src/capnp/rpc-twoparty.capnp.c++ index 64ae32bf2e..6809cebcf4 100644 --- a/c++/src/capnp/rpc-twoparty.capnp.c++ +++ b/c++/src/capnp/rpc-twoparty.capnp.c++ @@ -38,7 +38,7 @@ static const ::capnp::_::AlignedData<26> b_9fd69ebc87b9719c = { static const uint16_t m_9fd69ebc87b9719c[] = {1, 0}; const ::capnp::_::RawSchema s_9fd69ebc87b9719c = { 0x9fd69ebc87b9719c, b_9fd69ebc87b9719c.words, 26, nullptr, m_9fd69ebc87b9719c, - 0, 2, nullptr, nullptr, nullptr, { &s_9fd69ebc87b9719c, nullptr, nullptr, 0, 0, nullptr } + 0, 2, nullptr, nullptr, nullptr, { &s_9fd69ebc87b9719c, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE CAPNP_DEFINE_ENUM(Side_9fd69ebc87b9719c, 9fd69ebc87b9719c); @@ -86,7 +86,7 @@ static const uint16_t m_d20b909fee733a8e[] = {0}; static const uint16_t i_d20b909fee733a8e[] = {0}; const ::capnp::_::RawSchema s_d20b909fee733a8e = { 0xd20b909fee733a8e, b_d20b909fee733a8e.words, 33, d_d20b909fee733a8e, m_d20b909fee733a8e, - 1, 1, i_d20b909fee733a8e, nullptr, nullptr, { &s_d20b909fee733a8e, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_d20b909fee733a8e, nullptr, nullptr, { &s_d20b909fee733a8e, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_b88d09a9c5f39817 = { @@ -131,7 +131,7 @@ static const uint16_t m_b88d09a9c5f39817[] = {0}; static const uint16_t i_b88d09a9c5f39817[] = {0}; const ::capnp::_::RawSchema s_b88d09a9c5f39817 = { 0xb88d09a9c5f39817, b_b88d09a9c5f39817.words, 34, nullptr, m_b88d09a9c5f39817, - 0, 1, i_b88d09a9c5f39817, nullptr, nullptr, { &s_b88d09a9c5f39817, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_b88d09a9c5f39817, nullptr, nullptr, { &s_b88d09a9c5f39817, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<18> b_89f389b6fd4082c1 = { @@ -158,7 +158,7 @@ static const ::capnp::_::AlignedData<18> b_89f389b6fd4082c1 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_89f389b6fd4082c1 = { 0x89f389b6fd4082c1, b_89f389b6fd4082c1.words, 18, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_89f389b6fd4082c1, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_89f389b6fd4082c1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<19> b_b47f4979672cb59d = { @@ -186,7 +186,7 @@ static const ::capnp::_::AlignedData<19> b_b47f4979672cb59d = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_b47f4979672cb59d = { 0xb47f4979672cb59d, b_b47f4979672cb59d.words, 19, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_b47f4979672cb59d, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_b47f4979672cb59d, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_95b29059097fca83 = { @@ -262,7 +262,7 @@ static const uint16_t m_95b29059097fca83[] = {0, 1, 2}; static const uint16_t i_95b29059097fca83[] = {0, 1, 2}; const ::capnp::_::RawSchema s_95b29059097fca83 = { 0x95b29059097fca83, b_95b29059097fca83.words, 65, nullptr, m_95b29059097fca83, - 0, 3, i_95b29059097fca83, nullptr, nullptr, { &s_95b29059097fca83, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_95b29059097fca83, nullptr, nullptr, { &s_95b29059097fca83, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_9d263a3630b7ebee = { @@ -325,7 +325,7 @@ static const ::capnp::_::AlignedData<65> b_9d263a3630b7ebee = { 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, @@ -338,7 +338,7 @@ static const uint16_t m_9d263a3630b7ebee[] = {2, 0, 1}; static const uint16_t i_9d263a3630b7ebee[] = {0, 1, 2}; const ::capnp::_::RawSchema s_9d263a3630b7ebee = { 0x9d263a3630b7ebee, b_9d263a3630b7ebee.words, 65, nullptr, m_9d263a3630b7ebee, - 0, 3, i_9d263a3630b7ebee, nullptr, nullptr, { &s_9d263a3630b7ebee, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_9d263a3630b7ebee, nullptr, nullptr, { &s_9d263a3630b7ebee, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE } // namespace schemas @@ -351,51 +351,75 @@ namespace rpc { namespace twoparty { // VatId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t VatId::_capnpPrivate::dataWordSize; constexpr uint16_t VatId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind VatId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* VatId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ProvisionId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ProvisionId::_capnpPrivate::dataWordSize; constexpr uint16_t ProvisionId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ProvisionId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ProvisionId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // RecipientId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t RecipientId::_capnpPrivate::dataWordSize; constexpr uint16_t RecipientId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind RecipientId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* RecipientId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ThirdPartyCapId +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ThirdPartyCapId::_capnpPrivate::dataWordSize; constexpr uint16_t ThirdPartyCapId::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ThirdPartyCapId::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ThirdPartyCapId::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // JoinKeyPart +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t JoinKeyPart::_capnpPrivate::dataWordSize; constexpr uint16_t JoinKeyPart::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind JoinKeyPart::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* JoinKeyPart::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // JoinResult +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t JoinResult::_capnpPrivate::dataWordSize; constexpr uint16_t JoinResult::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind JoinResult::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* JoinResult::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/c++/src/capnp/rpc-twoparty.capnp.h b/c++/src/capnp/rpc-twoparty.capnp.h index 9d7820646a..403cd6ed8c 100644 --- a/c++/src/capnp/rpc-twoparty.capnp.h +++ b/c++/src/capnp/rpc-twoparty.capnp.h @@ -1,16 +1,23 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: rpc-twoparty.capnp -#ifndef CAPNP_INCLUDED_a184c7885cdaf2a1_ -#define CAPNP_INCLUDED_a184c7885cdaf2a1_ +#pragma once #include +#include +#if !CAPNP_LITE +#include +#endif // !CAPNP_LITE -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { @@ -530,7 +537,9 @@ class JoinResult::Reader { inline bool getSucceeded() const; inline bool hasCap() const; - inline ::capnp::AnyPointer::Reader getCap() const; +#if !CAPNP_LITE + inline ::capnp::Capability::Client getCap() const; +#endif // !CAPNP_LITE private: ::capnp::_::StructReader _reader; @@ -567,8 +576,13 @@ class JoinResult::Builder { inline void setSucceeded(bool value); inline bool hasCap(); - inline ::capnp::AnyPointer::Builder getCap(); - inline ::capnp::AnyPointer::Builder initCap(); +#if !CAPNP_LITE + inline ::capnp::Capability::Client getCap(); + inline void setCap( ::capnp::Capability::Client&& value); + inline void setCap( ::capnp::Capability::Client& value); + inline void adoptCap(::capnp::Orphan< ::capnp::Capability>&& value); + inline ::capnp::Orphan< ::capnp::Capability> disownCap(); +#endif // !CAPNP_LITE private: ::capnp::_::StructBuilder _builder; @@ -588,6 +602,7 @@ class JoinResult::Pipeline { inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) : _typeless(kj::mv(typeless)) {} + inline ::capnp::Capability::Client getCap(); private: ::capnp::AnyPointer::Pipeline _typeless; friend class ::capnp::PipelineHook; @@ -704,23 +719,40 @@ inline bool JoinResult::Builder::hasCap() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::AnyPointer::Reader JoinResult::Reader::getCap() const { - return ::capnp::AnyPointer::Reader(_reader.getPointerField( +#if !CAPNP_LITE +inline ::capnp::Capability::Client JoinResult::Reader::getCap() const { + return ::capnp::_::PointerHelpers< ::capnp::Capability>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::AnyPointer::Builder JoinResult::Builder::getCap() { - return ::capnp::AnyPointer::Builder(_builder.getPointerField( +inline ::capnp::Capability::Client JoinResult::Builder::getCap() { + return ::capnp::_::PointerHelpers< ::capnp::Capability>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::AnyPointer::Builder JoinResult::Builder::initCap() { - auto result = ::capnp::AnyPointer::Builder(_builder.getPointerField( +inline ::capnp::Capability::Client JoinResult::Pipeline::getCap() { + return ::capnp::Capability::Client(_typeless.getPointerField(0).asCap()); +} +inline void JoinResult::Builder::setCap( ::capnp::Capability::Client&& cap) { + ::capnp::_::PointerHelpers< ::capnp::Capability>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(cap)); +} +inline void JoinResult::Builder::setCap( ::capnp::Capability::Client& cap) { + ::capnp::_::PointerHelpers< ::capnp::Capability>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), cap); +} +inline void JoinResult::Builder::adoptCap( + ::capnp::Orphan< ::capnp::Capability>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Capability>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Capability> JoinResult::Builder::disownCap() { + return ::capnp::_::PointerHelpers< ::capnp::Capability>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); - result.clear(); - return result; } +#endif // !CAPNP_LITE } // namespace } // namespace } // namespace -#endif // CAPNP_INCLUDED_a184c7885cdaf2a1_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/rpc-twoparty.h b/c++/src/capnp/rpc-twoparty.h index 093c1fecdf..c280e62d40 100644 --- a/c++/src/capnp/rpc-twoparty.h +++ b/c++/src/capnp/rpc-twoparty.h @@ -19,17 +19,16 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_RPC_TWOPARTY_H_ -#define CAPNP_RPC_TWOPARTY_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "rpc.h" -#include "message.h" +#include #include +#include #include +#include + +CAPNP_BEGIN_HEADER namespace capnp { @@ -44,7 +43,8 @@ typedef VatNetwork onDisconnect() { return disconnectPromise.addBranch(); } // Returns a promise that resolves when the peer disconnects. rpc::twoparty::Side getSide() { return side; } + size_t getCurrentQueueSize() { return currentQueueSize; } + // Get the number of bytes worth of outgoing messages that are currently queued in memory waiting + // to be sent on this connection. This may be useful for backpressure. + + size_t getCurrentQueueCount() { return queuedMessages.size(); } + // Get the count of outgoing messages that are currently queued in memory waiting + // to be sent on this connection. This may be useful for backpressure. + + kj::Duration getOutgoingMessageWaitTime(); + // Get how long the current outgoing message has been waiting to be sent on this connection. + // Returns 0 if the queue is empty. This may be useful for backpressure. + // implements VatNetwork ----------------------------------------------------- kj::Maybe> connect( @@ -71,12 +109,23 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, class OutgoingMessageImpl; class IncomingMessageImpl; - kj::AsyncIoStream& stream; + kj::OneOf> stream; + // The underlying stream, which we may or may not own. Get a reference to + // this with getStream, rather than reading it directly. + + uint maxFdsPerMessage; rpc::twoparty::Side side; MallocMessageBuilder peerVatId; ReaderOptions receiveOptions; bool accepted = false; + bool solSndbufUnimplemented = false; + // Whether stream.getsockopt(SO_SNDBUF) has been observed to throw UNIMPLEMENTED. + + kj::Canceler readCanceler; + kj::Maybe readCancelReason; + // Used to propagate write errors into (permanent) read errors. + kj::Maybe> previousWrite; // Resolves when the previous write completes. This effectively serves as the write queue. // Becomes null when shutdown() is called. @@ -87,6 +136,11 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, kj::ForkedPromise disconnectPromise = nullptr; + kj::Vector> queuedMessages; + size_t currentQueueSize = 0; + const kj::MonotonicClock& clock; + kj::TimePoint currentOutgoingMessageSendTime; + class FulfillerDisposer: public kj::Disposer { // Hack: TwoPartyVatNetwork is both a VatNetwork and a VatNetwork::Connection. When the RPC // system detects (or initiates) a disconnection, it drops its reference to the Connection. @@ -102,34 +156,74 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, }; FulfillerDisposer disconnectFulfiller; + + TwoPartyVatNetwork( + kj::OneOf>&& stream, + uint maxFdsPerMessage, + rpc::twoparty::Side side, + ReaderOptions receiveOptions, + const kj::MonotonicClock& clock); + + MessageStream& getStream(); + kj::Own asConnection(); // Returns a pointer to this with the disposer set to disconnectFulfiller. // implements Connection ----------------------------------------------------- + kj::Own newStream() override; rpc::twoparty::VatId::Reader getPeerVatId() override; kj::Own newOutgoingMessage(uint firstSegmentWordSize) override; kj::Promise>> receiveIncomingMessage() override; kj::Promise shutdown() override; + + // implements WindowGetter --------------------------------------------------- + + size_t getWindow() override; }; class TwoPartyServer: private kj::TaskSet::ErrorHandler { // Convenience class which implements a simple server which accepts connections on a listener - // socket and serices them as two-party connections. + // socket and services them as two-party connections. public: - explicit TwoPartyServer(Capability::Client bootstrapInterface); + explicit TwoPartyServer(Capability::Client bootstrapInterface, + kj::Maybe> traceEncoder = nullptr); + // `traceEncoder`, if provided, will be passed on to `rpcSystem.setTraceEncoder()`. void accept(kj::Own&& connection); + void accept(kj::Own&& connection, uint maxFdsPerMessage); // Accepts the connection for servicing. + kj::Promise accept(kj::AsyncIoStream& connection) KJ_WARN_UNUSED_RESULT; + kj::Promise accept(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) + KJ_WARN_UNUSED_RESULT; + // Accept connection without taking ownership. The returned promise resolves when the client + // disconnects. Dropping the promise forcefully cancels the RPC protocol. + // + // You probably can't do anything with `connection` after the RPC protocol has terminated, other + // than to close it. The main reason to use these methods rather than the ownership-taking ones + // is if your stream object becomes invalid outside some scope, so you want to make sure to + // cancel all usage of it before that by cancelling the promise. + kj::Promise listen(kj::ConnectionReceiver& listener); // Listens for connections on the given listener. The returned promise never resolves unless an // exception is thrown while trying to accept. You may discard the returned promise to cancel // listening. + kj::Promise listenCapStreamReceiver( + kj::ConnectionReceiver& listener, uint maxFdsPerMessage); + // Listen with support for FD transfers. `listener.accept()` must return instances of + // AsyncCapabilityStream, otherwise this will crash. + + kj::Promise drain() { return tasks.onEmpty(); } + // Resolves when all clients have disconnected. + // + // Only considers clients whose connections TwoPartyServer took ownership of. + private: Capability::Client bootstrapInterface; + kj::Maybe> traceEncoder; kj::TaskSet tasks; struct AcceptedConnection; @@ -142,14 +236,25 @@ class TwoPartyClient { public: explicit TwoPartyClient(kj::AsyncIoStream& connection); + explicit TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage); TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface, rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT); + TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage, + Capability::Client bootstrapInterface, + rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT); Capability::Client bootstrap(); // Get the server's bootstrap interface. inline kj::Promise onDisconnect() { return network.onDisconnect(); } + void setTraceEncoder(kj::Function func); + // Forwarded to rpcSystem.setTraceEncoder(). + + size_t getCurrentQueueSize() { return network.getCurrentQueueSize(); } + size_t getCurrentQueueCount() { return network.getCurrentQueueCount(); } + kj::Duration getOutgoingMessageWaitTime() { return network.getOutgoingMessageWaitTime(); } + private: TwoPartyVatNetwork network; RpcSystem rpcSystem; @@ -157,4 +262,4 @@ class TwoPartyClient { } // namespace capnp -#endif // CAPNP_RPC_TWOPARTY_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/rpc.c++ b/c++/src/capnp/rpc.c++ index 6d9177b9e1..6118e3166d 100644 --- a/c++/src/capnp/rpc.c++ +++ b/c++/src/capnp/rpc.c++ @@ -31,6 +31,8 @@ #include #include #include +#include +#include namespace capnp { namespace _ { // private @@ -55,7 +57,9 @@ constexpr const uint CAP_DESCRIPTOR_SIZE_HINT = sizeInWords( constexpr const uint64_t MAX_SIZE_HINT = 1 << 20; uint copySizeHint(MessageSize size) { - uint64_t sizeHint = size.wordCount + size.capCount * CAP_DESCRIPTOR_SIZE_HINT; + uint64_t sizeHint = size.wordCount + size.capCount * CAP_DESCRIPTOR_SIZE_HINT + // if capCount > 0, the cap descriptor list has a 1-word tag + + (size.capCount > 0); return kj::min(MAX_SIZE_HINT, sizeHint); } @@ -108,14 +112,24 @@ Orphan> fromPipelineOps( } kj::Exception toException(const rpc::Exception::Reader& exception) { - return kj::Exception(static_cast(exception.getType()), - "(remote)", 0, kj::str("remote exception: ", exception.getReason())); -} + auto reason = [&]() { + if (exception.getReason().startsWith("remote exception: ")) { + return kj::str(exception.getReason()); + } else { + return kj::str("remote exception: ", exception.getReason()); + } + }(); -void fromException(const kj::Exception& exception, rpc::Exception::Builder builder) { - // TODO(someday): Indicate the remote server name as part of the stack trace. Maybe even - // transmit stack traces? + kj::Exception result(static_cast(exception.getType()), + "(remote)", 0, kj::mv(reason)); + if (exception.hasTrace()) { + result.setRemoteTrace(kj::str(exception.getTrace())); + } + return result; +} +void fromException(const kj::Exception& exception, rpc::Exception::Builder builder, + kj::Maybe&> traceEncoder) { kj::StringPtr description = exception.getDescription(); // Include context, if any. @@ -137,6 +151,10 @@ void fromException(const kj::Exception& exception, rpc::Exception::Builder build builder.setReason(description); builder.setType(static_cast(exception.getType())); + KJ_IF_MAYBE(t, traceEncoder) { + builder.setTrace((*t)(exception)); + } + if (exception.getType() == kj::Exception::Type::FAILED && !exception.getDescription().startsWith("remote exception:")) { KJ_LOG(INFO, "returning failure over rpc", exception); @@ -147,15 +165,33 @@ uint exceptionSizeHint(const kj::Exception& exception) { return sizeInWords() + exception.getDescription().size() / sizeof(word) + 1; } +ClientHook::CallHints callHintsFromReader(rpc::Call::Reader reader) { + ClientHook::CallHints hints; + hints.noPromisePipelining = reader.getNoPromisePipelining(); + hints.onlyPromisePipeline = reader.getOnlyPromisePipeline(); + return hints; +} + // ======================================================================================= +template +static constexpr Id highBit() { + return 1u << (sizeof(Id) * 8 - 1); +} + template class ExportTable { // Table mapping integers to T, where the integers are chosen locally. public: + bool isHigh(Id& id) { + return (id & highBit()) != 0; + } + kj::Maybe find(Id id) { - if (id < slots.size() && slots[id] != nullptr) { + if (isHigh(id)) { + return highSlots.find(id); + } else if (id < slots.size() && slots[id] != nullptr) { return slots[id]; } else { return nullptr; @@ -168,16 +204,23 @@ public: // `entry` is a reference to the entry being released -- we require this in order to prove // that the caller has already done a find() to check that this entry exists. We can't check // ourselves because the caller may have nullified the entry in the meantime. - KJ_DREQUIRE(&entry == &slots[id]); - T toRelease = kj::mv(slots[id]); - slots[id] = T(); - freeIds.push(id); - return toRelease; + + if (isHigh(id)) { + auto& slot = KJ_REQUIRE_NONNULL(highSlots.findEntry(id)); + return highSlots.release(slot).value; + } else { + KJ_DREQUIRE(&entry == &slots[id]); + T toRelease = kj::mv(slots[id]); + slots[id] = T(); + freeIds.push(id); + return toRelease; + } } T& next(Id& id) { if (freeIds.empty()) { id = slots.size(); + KJ_ASSERT(!isHigh(id), "2^31 concurrent questions?!!?!"); return slots.add(); } else { id = freeIds.top(); @@ -186,6 +229,25 @@ public: } } + T& nextHigh(Id& id) { + // Choose an ID with the top bit set in round-robin fashion, but don't choose an ID that + // is still in use. + + KJ_ASSERT(highSlots.size() < Id(kj::maxValue) / 2); // avoid infinite loop below. + + bool created = false; + T* slot; + while (!created) { + id = highCounter++ | highBit(); + slot = &highSlots.findOrCreate(id, [&]() { + created = true; + return typename kj::HashMap::Entry { id, T() }; + }); + } + + return *slot; + } + template void forEach(Func&& func) { for (Id i = 0; i < slots.size(); i++) { @@ -193,11 +255,24 @@ public: func(i, slots[i]); } } + for (auto& slot: highSlots) { + func(slot.key, slot.value); + } + } + + void release() { + // Release memory backing the table. + { auto drop = kj::mv(slots); } + { auto drop = kj::mv(freeIds); } + { auto drop = kj::mv(highSlots); } } private: kj::Vector slots; std::priority_queue, std::greater> freeIds; + + kj::HashMap highSlots; + Id highCounter = 0; }; template @@ -265,14 +340,14 @@ public: }; RpcConnectionState(BootstrapFactoryBase& bootstrapFactory, - kj::Maybe::Client> gateway, kj::Maybe restorer, kj::Own&& connectionParam, kj::Own>&& disconnectFulfiller, - size_t flowLimit) - : bootstrapFactory(bootstrapFactory), gateway(kj::mv(gateway)), + size_t flowLimit, + kj::Maybe&> traceEncoder) + : bootstrapFactory(bootstrapFactory), restorer(restorer), disconnectFulfiller(kj::mv(disconnectFulfiller)), flowLimit(flowLimit), - tasks(*this) { + traceEncoder(traceEncoder), tasks(*this) { connection.init(kj::mv(connectionParam)); tasks.add(messageLoop()); } @@ -315,6 +390,17 @@ public: } void disconnect(kj::Exception&& exception) { + // Shut down the connection with the given error. + // + // This will cancel `tasks`, so cannot be called from inside a task in `tasks`. Instead, use + // `tasks.add(exception)` to schedule a shutdown, since any error thrown by a task will be + // passed to `disconnect()` later. + + // After disconnect(), the RpcSystem could be destroyed, making `traceEncoder` a dangling + // reference, so null it out before we return from here. We don't need it anymore once + // disconnected anyway. + KJ_DEFER(traceEncoder = nullptr); + if (!connection.is()) { // Already disconnected. return; @@ -323,39 +409,67 @@ public: kj::Exception networkException(kj::Exception::Type::DISCONNECTED, exception.getFile(), exception.getLine(), kj::heapString(exception.getDescription())); + // Don't throw away the stack trace. + if (exception.getRemoteTrace() != nullptr) { + networkException.setRemoteTrace(kj::str(exception.getRemoteTrace())); + } + for (void* addr: exception.getStackTrace()) { + networkException.addTrace(addr); + } + // If your stack trace points here, it means that the exception became the reason that the + // RPC connection was disconnected. The exception was then thrown by all in-flight calls and + // all future calls on this connection. + networkException.addTraceHere(); + + // Set our connection state to Disconnected now so that no one tries to write any messages to + // it in their destructors. + auto dyingConnection = kj::mv(connection.get()); + connection.init(kj::cp(networkException)); + KJ_IF_MAYBE(newException, kj::runCatchingExceptions([&]() { // Carefully pull all the objects out of the tables prior to releasing them because their // destructors could come back and mess with the tables. kj::Vector> pipelinesToRelease; kj::Vector> clientsToRelease; - kj::Vector>> tailCallsToRelease; + kj::Vector tasksToRelease; kj::Vector> resolveOpsToRelease; + KJ_DEFER(tasks.clear()); // All current questions complete with exceptions. questions.forEach([&](QuestionId id, Question& question) { KJ_IF_MAYBE(questionRef, question.selfRef) { // QuestionRef still present. questionRef->reject(kj::cp(networkException)); + + // We need to fully disconnect each QuestionRef otherwise it holds a reference back to + // the connection state. Meanwhile `tasks` may hold streaming calls that end up holding + // these QuestionRefs. Technically this is a cyclic reference, but as long as the cycle + // is broken on disconnect (which happens when the RpcSystem itself is destroyed), then + // we're OK. + questionRef->disconnect(); } }); + // Since we've disconnected the QuestionRefs, they won't clean up the questions table for + // us, so do that here. + questions.release(); answers.forEach([&](AnswerId id, Answer& answer) { KJ_IF_MAYBE(p, answer.pipeline) { pipelinesToRelease.add(kj::mv(*p)); } - KJ_IF_MAYBE(promise, answer.redirectedResults) { - tailCallsToRelease.add(kj::mv(*promise)); - } + tasksToRelease.add(kj::mv(answer.task)); KJ_IF_MAYBE(context, answer.callContext) { - context->requestCancel(); + context->finish(); } }); exports.forEach([&](ExportId id, Export& exp) { clientsToRelease.add(kj::mv(exp.clientHook)); - resolveOpsToRelease.add(kj::mv(exp.resolveOp)); + KJ_IF_MAYBE(op, exp.resolveOp) { + resolveOpsToRelease.add(kj::mv(*op)); + } exp = Export(); }); @@ -379,25 +493,36 @@ public: // Send an abort message, but ignore failure. kj::runCatchingExceptions([&]() { - auto message = connection.get()->newOutgoingMessage( + auto message = dyingConnection->newOutgoingMessage( messageSizeHint() + exceptionSizeHint(exception)); fromException(exception, message->getBody().getAs().initAbort()); message->send(); }); // Indicate disconnect. - auto shutdownPromise = connection.get()->shutdown() - .attach(kj::mv(connection.get())) + auto shutdownPromise = dyingConnection->shutdown() + .attach(kj::mv(dyingConnection)) .then([]() -> kj::Promise { return kj::READY_NOW; }, - [](kj::Exception&& e) -> kj::Promise { + [this, origException = kj::mv(exception)](kj::Exception&& shutdownException) -> kj::Promise { // Don't report disconnects as an error. - if (e.getType() != kj::Exception::Type::DISCONNECTED) { - return kj::mv(e); + if (shutdownException.getType() == kj::Exception::Type::DISCONNECTED) { + return kj::READY_NOW; } - return kj::READY_NOW; + // If the error is just what was passed in to disconnect(), don't report it back out + // since it shouldn't be anything the caller doesn't already know about. + if (shutdownException.getType() == origException.getType() && + shutdownException.getDescription() == origException.getDescription()) { + return kj::READY_NOW; + } + // We are shutting down after receive error, ignore shutdown exception since underlying + // transport is probably broken. + if (receiveIncomingMessageError) { + return kj::READY_NOW; + } + return kj::mv(shutdownException); }); disconnectFulfiller->fulfill(DisconnectInfo { kj::mv(shutdownPromise) }); - connection.init(kj::mv(networkException)); + canceler.cancel(networkException); } void setFlowLimit(size_t words) { @@ -450,6 +575,11 @@ private: bool skipFinish = false; // If true, don't send a Finish message. + // + // This is used in two cases: + // * The `Return` message had the `noFinishNeeded` hint. + // * Our attempt to send the `Call` threw an exception, therefore the peer never even received + // the call in the first place and would not expect a `Finish`. inline bool operator==(decltype(nullptr)) const { return !isAwaitingReturn && selfRef == nullptr; @@ -471,7 +601,17 @@ private: kj::Maybe> pipeline; // Send pipelined calls here. Becomes null as soon as a `Finish` is received. - kj::Maybe>> redirectedResults; + using Running = kj::Promise; + struct Finished {}; + using Redirected = kj::Promise>; + + kj::OneOf task; + // While the RPC is running locally, `task` is a `Promise` representing the task to execute + // the RPC. + // + // When `Finish` is received (and results are not redirected), `task` becomes `Finished`, which + // cancels it if it's still running. + // // For locally-redirected calls (Call.sendResultsTo.yourself), this is a promise for the call // result, to be picked up by a subsequent `Return`. @@ -490,7 +630,7 @@ private: kj::Own clientHook; - kj::Promise resolveOp = nullptr; + kj::Maybe> resolveOp = nullptr; // If this export is a promise (not a settled capability), the `resolveOp` represents the // ongoing operation to wait for that promise to resolve and then send a `Resolve` message. @@ -533,7 +673,6 @@ private: // OK, now we can define RpcConnectionState's member data. BootstrapFactoryBase& bootstrapFactory; - kj::Maybe::Client> gateway; kj::Maybe restorer; typedef kj::Own Connected; @@ -542,6 +681,11 @@ private: // Once the connection has failed, we drop it and replace it with an exception, which will be // thrown from all further calls. + kj::Canceler canceler; + // Will be canceled if and when `connection` is changed from `Connected` to `Disconnected`. + // TODO(cleanup): `Connected` should be a struct that contains the connection and the Canceler, + // but that's more refactoring than I want to do right now. + kj::Own> disconnectFulfiller; ExportTable exports; @@ -565,8 +709,23 @@ private: // If non-null, we're currently blocking incoming messages waiting for callWordsInFlight to drop // below flowLimit. Fulfill this to un-block. + kj::Maybe&> traceEncoder; + kj::TaskSet tasks; + bool gotReturnForHighQuestionId = false; + // Becomes true if we ever get a `Return` message for a high question ID (with top bit set), + // which we use in cases where we've hinted to the peer that we don't want a `Return`. If the + // peer sends us one anyway then it seemingly doesn't not implement our hints. We need to stop + // using the hints in this case before the high question ID space wraps around since otherwise + // we might reuse an ID that the peer thinks is still in use. + + bool sentCapabilitiesInPipelineOnlyCall = false; + // Becomes true if `sendPipelineOnly()` is ever called with parameters that include capabilities. + + bool receiveIncomingMessageError = false; + // Becomes true when receiveIncomingMessage resulted in exception. + // ===================================================================================== // ClientHook implementations @@ -575,7 +734,15 @@ private: RpcClient(RpcConnectionState& connectionState) : connectionState(kj::addRef(connectionState)) {} - virtual kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor) = 0; + ~RpcClient() noexcept(false) { + KJ_IF_MAYBE(f, this->flowController) { + // Destroying the client should not cancel outstanding streaming calls. + connectionState->tasks.add(f->get()->waitAllAcked().attach(kj::mv(*f))); + } + } + + virtual kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor, + kj::Vector& fds) = 0; // Writes a CapDescriptor referencing this client. The CapDescriptor must be sent as part of // the very next message sent on the connection, as it may become invalid if other things // happen. @@ -598,47 +765,36 @@ private: // that other client -- return a reference to the other client, transitively. Otherwise, // return a new reference to *this. - // implements ClientHook ----------------------------------------- - - Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { - if (interfaceId == typeId>() && methodId == 0) { - KJ_IF_MAYBE(g, connectionState->gateway) { - // Wait, this is a call to Persistent.save() and we need to translate it through our - // gateway. - // - // We pull a neat trick here: We actually end up returning a RequestHook for an import - // request on the gateway cap, but with the "root" of the request actually pointing - // to the "params" field of the real request. - - sizeHint = sizeHint.map([](MessageSize hint) { - ++hint.capCount; - hint.wordCount += sizeInWords::ImportParams>(); - return hint; - }); - - auto request = g->importRequest(sizeHint); - request.setCap(Persistent<>::Client(kj::refcounted(*this))); - - // Awkwardly, request.initParams() would return a SaveParams struct, but to construct - // the Request to return we need an AnyPointer::Builder, and you - // can't go backwards from a struct builder to an AnyPointer builder. So instead we - // manually get at the pointer by converting the outer request to AnyStruct and then - // pulling the pointer from the pointer section. - auto pointers = toAny(request).getPointerSection(); - KJ_ASSERT(pointers.size() >= 2); - auto paramsPtr = pointers[1]; - KJ_ASSERT(paramsPtr.isNull()); + virtual void adoptFlowController(kj::Own flowController) { + // Called when a PromiseClient resolves to another RpcClient. If streaming calls were + // outstanding on the old client, we'd like to keep using the same FlowController on the new + // client, so as to keep the flow steady. - return Request(paramsPtr, RequestHook::from(kj::mv(request))); - } + if (this->flowController == nullptr) { + // We don't have any existing flowController so we can adopt this one, yay! + this->flowController = kj::mv(flowController); + } else { + // Apparently, there is an existing flowController. This is an unusual scenario: Apparently + // we had two stream capabilities, we were streaming to both of them, and they later + // resolved to the same capability. This probably never happens because streaming use cases + // normally call for there to be only one client. But, it's certainly possible, and we need + // to handle it. We'll do the conservative thing and just make sure that all the calls + // finish. This may mean we'll over-buffer temporarily; oh well. + connectionState->tasks.add(flowController->waitAllAcked().attach(kj::mv(flowController))); } + } + + // implements ClientHook ----------------------------------------- - return newCallNoIntercept(interfaceId, methodId, sizeHint); + Request newCall( + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + return newCallNoIntercept(interfaceId, methodId, sizeHint, hints); } Request newCallNoIntercept( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) { + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) { if (!connectionState->connection.is()) { return newBrokenRequest(kj::cp(connectionState->connection.get()), sizeHint); } @@ -650,49 +806,28 @@ private: callBuilder.setInterfaceId(interfaceId); callBuilder.setMethodId(methodId); + callBuilder.setNoPromisePipelining(hints.noPromisePipelining); + callBuilder.setOnlyPromisePipeline(hints.onlyPromisePipeline); auto root = request->getRoot(); return Request(root, kj::mv(request)); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - if (interfaceId == typeId>() && methodId == 0) { - KJ_IF_MAYBE(g, connectionState->gateway) { - // Wait, this is a call to Persistent.save() and we need to translate it through our - // gateway. - auto params = context->getParams().getAs::SaveParams>(); - - auto requestSize = params.totalSize(); - ++requestSize.capCount; - requestSize.wordCount += sizeInWords::ImportParams>(); - - auto request = g->importRequest(requestSize); - request.setCap(Persistent<>::Client(kj::refcounted(*this))); - request.setParams(params); - - context->allowCancellation(); - context->releaseParams(); - return context->directTailCall(RequestHook::from(kj::mv(request))); - } - } - - return callNoIntercept(interfaceId, methodId, kj::mv(context)); + kj::Own&& context, CallHints hints) override { + return callNoIntercept(interfaceId, methodId, kj::mv(context), hints); } VoidPromiseAndPipeline callNoIntercept(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) { + kj::Own&& context, CallHints hints) { // Implement call() by copying params and results messages. auto params = context->getParams(); - auto request = newCallNoIntercept(interfaceId, methodId, params.targetSize()); + auto request = newCallNoIntercept(interfaceId, methodId, params.targetSize(), hints); request.set(params); context->releaseParams(); - // We can and should propagate cancellation. - context->allowCancellation(); - return context->directTailCall(RequestHook::from(kj::mv(request))); } @@ -704,14 +839,18 @@ private: } kj::Own connectionState; + + kj::Maybe> flowController; + // Becomes non-null the first time a streaming call is made on this capability. }; class ImportClient final: public RpcClient { // A ClientHook that wraps an entry in the import table. public: - ImportClient(RpcConnectionState& connectionState, ImportId importId) - : RpcClient(connectionState), importId(importId) {} + ImportClient(RpcConnectionState& connectionState, ImportId importId, + kj::Maybe fd) + : RpcClient(connectionState), importId(importId), fd(kj::mv(fd)) {} ~ImportClient() noexcept(false) { unwindDetector.catchExceptionsIfUnwinding([&]() { @@ -736,12 +875,19 @@ private: }); } + void setFdIfMissing(kj::Maybe newFd) { + if (fd == nullptr) { + fd = kj::mv(newFd); + } + } + void addRemoteRef() { // Add a new RemoteRef and return a new ref to this client representing it. ++remoteRefcount; } - kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { + kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor, + kj::Vector& fds) override { descriptor.setReceiverHosted(importId); return nullptr; } @@ -766,8 +912,13 @@ private: return nullptr; } + kj::Maybe getFd() override { + return fd.map([](auto& f) { return f.get(); }); + } + private: ImportId importId; + kj::Maybe fd; uint remoteRefcount = 0; // Number of times we've received this import from the peer. @@ -784,7 +935,8 @@ private: kj::Array&& ops) : RpcClient(connectionState), questionRef(kj::mv(questionRef)), ops(kj::mv(ops)) {} - kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { + kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor, + kj::Vector& fds) override { auto promisedAnswer = descriptor.initReceiverAnswer(); promisedAnswer.setQuestionId(questionRef->getId()); promisedAnswer.adoptTransform(fromPipelineOps( @@ -814,6 +966,10 @@ private: return nullptr; } + kj::Maybe getFd() override { + return nullptr; + } + private: kj::Own questionRef; kj::Array ops; @@ -825,31 +981,25 @@ private: public: PromiseClient(RpcConnectionState& connectionState, - kj::Own initial, + kj::Own initial, kj::Promise> eventual, kj::Maybe importId) : RpcClient(connectionState), - isResolved(false), cap(kj::mv(initial)), importId(importId), - fork(eventual.fork()), - resolveSelfPromise(fork.addBranch().then( + fork(eventual.then( [this](kj::Own&& resolution) { - resolve(kj::mv(resolution), false); + return resolve(kj::mv(resolution)); }, [this](kj::Exception&& exception) { - resolve(newBrokenCap(kj::mv(exception)), true); - }).eagerlyEvaluate([&](kj::Exception&& e) { + return resolve(newBrokenCap(kj::mv(exception))); + }).catch_([&](kj::Exception&& e) { // Make any exceptions thrown from resolve() go to the connection's TaskSet which // will cause the connection to be terminated. - connectionState.tasks.add(kj::mv(e)); - })) { - // Create a client that starts out forwarding all calls to `initial` but, once `eventual` - // resolves, will forward there instead. In addition, `whenMoreResolved()` will return a fork - // of `eventual`. Note that this means the application could hold on to `eventual` even after - // the `PromiseClient` is destroyed; `eventual` must therefore make sure to hold references to - // anything that needs to stay alive in order to resolve it correctly (such as making sure the - // import ID is not released). - } + connectionState.tasks.add(kj::cp(e)); + return newBrokenCap(kj::mv(e)); + }).fork()) {} + // Create a client that starts out forwarding all calls to `initial` but, once `eventual` + // resolves, will forward there instead. ~PromiseClient() noexcept(false) { KJ_IF_MAYBE(id, importId) { @@ -867,9 +1017,10 @@ private: } } - kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { + kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor, + kj::Vector& fds) override { receivedCall = true; - return connectionState->writeDescriptor(*cap, descriptor); + return connectionState->writeDescriptor(*cap, descriptor, fds); } kj::Maybe> writeTarget( @@ -883,52 +1034,38 @@ private: return connectionState->getInnermostClient(*cap); } + void adoptFlowController(kj::Own flowController) override { + if (cap->getBrand() == connectionState.get()) { + // Pass the flow controller on to our inner cap. + kj::downcast(*cap).adoptFlowController(kj::mv(flowController)); + } else { + // We resolved to a capability that isn't another RPC capability. We should simply make + // sure that all the calls complete. + connectionState->tasks.add(flowController->waitAllAcked().attach(kj::mv(flowController))); + } + } + // implements ClientHook ----------------------------------------- Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { - if (!isResolved && interfaceId == typeId>() && methodId == 0 && - connectionState->gateway != nullptr) { - // This is a call to Persistent.save(), and we're not resolved yet, and the underlying - // remote capability will perform a gateway translation. This isn't right if the promise - // ultimately resolves to a local capability. Instead, we'll need to queue the call until - // the promise resolves. - return newLocalPromiseClient(fork.addBranch()) - ->newCall(interfaceId, methodId, sizeHint); - } - + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { receivedCall = true; - return cap->newCall(interfaceId, methodId, sizeHint); + + // IMPORTANT: We must call our superclass's version of newCall(), NOT cap->newCall(), because + // the Request object we create needs to check at send() time whether the promise has + // resolved and, if so, redirect to the new target. + return RpcClient::newCall(interfaceId, methodId, sizeHint, hints); } VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - if (!isResolved && interfaceId == typeId>() && methodId == 0 && - connectionState->gateway != nullptr) { - // This is a call to Persistent.save(), and we're not resolved yet, and the underlying - // remote capability will perform a gateway translation. This isn't right if the promise - // ultimately resolves to a local capability. Instead, we'll need to queue the call until - // the promise resolves. - - auto vpapPromises = fork.addBranch().then(kj::mvCapture(context, - [interfaceId,methodId](kj::Own&& context, - kj::Own resolvedCap) { - auto vpap = resolvedCap->call(interfaceId, methodId, kj::mv(context)); - return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline)); - })).split(); - - return { - kj::mv(kj::get<0>(vpapPromises)), - newLocalPromisePipeline(kj::mv(kj::get<1>(vpapPromises))), - }; - } - + kj::Own&& context, CallHints hints) override { receivedCall = true; - return cap->call(interfaceId, methodId, kj::mv(context)); + return cap->call(interfaceId, methodId, kj::mv(context), hints); } kj::Maybe getResolved() override { - if (isResolved) { + if (isResolved()) { return *cap; } else { return nullptr; @@ -939,24 +1076,133 @@ private: return fork.addBranch(); } + kj::Maybe getFd() override { + if (isResolved()) { + return cap->getFd(); + } else { + // In theory, before resolution, the ImportClient for the promise could have an FD + // attached, if the promise itself was presented with an attached FD. However, we can't + // really return that one here because it may be closed when we get the Resolve message + // later. In theory we could have the PromiseClient itself take ownership of an FD that + // arrived attached to a promise cap, but the use case for that is questionable. I'm + // keeping it simple for now. + return nullptr; + } + } + private: - bool isResolved; kj::Own cap; kj::Maybe importId; kj::ForkedPromise> fork; - // Keep this last, because the continuation uses *this, so it should be destroyed first to - // ensure the continuation is not still running. - kj::Promise resolveSelfPromise; - bool receivedCall = false; - void resolve(kj::Own replacement, bool isError) { + enum { + UNRESOLVED, + // Not resolved at all yet. + + REMOTE, + // Remote promise resolved to a remote settled capability (or null/error). + + REFLECTED, + // Remote promise resolved to one of our own exports. + + MERGED, + // Remote promise resolved to another remote promise which itself wasn't resolved yet, so we + // merged them. In this case, `cap` is guaranteed to point to another PromiseClient. + + BROKEN + // Resolved to null or error. + } resolutionType = UNRESOLVED; + + inline bool isResolved() { + return resolutionType != UNRESOLVED; + } + + kj::Promise> resolve(kj::Own replacement) { + KJ_DASSERT(!isResolved()); + const void* replacementBrand = replacement->getBrand(); - if (replacementBrand != connectionState.get() && - replacementBrand != &ClientHook::NULL_CAPABILITY_BRAND && - receivedCall && !isError && connectionState->connection.is()) { + bool isSameConnection = replacementBrand == connectionState.get(); + if (isSameConnection) { + // We resolved to some other RPC capability hosted by the same peer. + KJ_IF_MAYBE(promise, replacement->whenMoreResolved()) { + // We resolved to another remote promise. If *that* promise eventually resolves back + // to us, we'll need a disembargo. Possibilities: + // 1. The other promise hasn't resolved at all yet. In that case we can simply set its + // `receivedCall` flag and let it handle the disembargo later. + // 2. The other promise has received a Resolve message and decided to initiate a + // disembargo which it is still waiting for. In that case we will certainly also need + // a disembargo for the same reason that the other promise did. And, we can't simply + // wait for their disembargo; we need to start a new one of our own. + // 3. The other promise has resolved already (with or without a disembargo). In this + // case we should treat it as if we resolved directly to the other promise's result, + // possibly requiring a disembargo under the same conditions. + + // We know the other object is a PromiseClient because it's the only ClientHook + // type in the RPC implementation which returns non-null for `whenMoreResolved()`. + PromiseClient* other = &kj::downcast(*replacement); + while (other->resolutionType == MERGED) { + // There's no need to resolve to a thing that's just going to resolve to another thing. + replacement = other->cap->addRef(); + other = &kj::downcast(*replacement); + + // Note that replacementBrand is unchanged since we'd only merge with other + // PromiseClients on the same connection. + KJ_DASSERT(replacement->getBrand() == replacementBrand); + } + + if (other->isResolved()) { + // The other capability resolved already. If it determined that it resolved as + // reflected, then we determine the same. + resolutionType = other->resolutionType; + } else { + // The other capability hasn't resolved yet, so we can safely merge with it and do a + // single combined disembargo if needed later. + other->receivedCall = other->receivedCall || receivedCall; + resolutionType = MERGED; + } + } else { + resolutionType = REMOTE; + } + } else { + if (replacementBrand == &ClientHook::NULL_CAPABILITY_BRAND || + replacementBrand == &ClientHook::BROKEN_CAPABILITY_BRAND) { + // We don't consider null or broken capabilities as "reflected" because they may have + // been communicated to us literally as a null pointer or an exception on the wire, + // rather than as a reference to one of our exports, in which case a disembargo won't + // work. But also, call ordering is completely irrelevant with these so there's no need + // to disembargo anyway. + resolutionType = BROKEN; + } else { + resolutionType = REFLECTED; + } + } + + // Every branch above ends by setting resolutionType to something other than UNRESOLVED. + KJ_DASSERT(isResolved()); + + // If the original capability was used for streaming calls, it will have a + // `flowController` that might still be shepherding those calls. We'll need make sure that + // it doesn't get thrown away. Note that we know that *cap is an RpcClient because resolve() + // is only called once and our constructor required that the initial capability is an + // RpcClient. + KJ_IF_MAYBE(f, kj::downcast(*cap).flowController) { + if (isSameConnection) { + // The new target is on the same connection. It would make a lot of sense to keep using + // the same flow controller if possible. + kj::downcast(*replacement).adoptFlowController(kj::mv(*f)); + } else { + // The new target is something else. The best we can do is wait for the controller to + // drain. New calls will be flow-controlled in a new way without knowing about the old + // controller. + connectionState->tasks.add(f->get()->waitAllAcked().attach(kj::mv(*f))); + } + } + + if (resolutionType == REFLECTED && receivedCall && + connectionState->connection.is()) { // The new capability is hosted locally, not on the remote machine. And, we had made calls // to the promise. We need to make sure those calls echo back to us before we allow new // calls to go directly to the local capability, so we need to set a local embargo and send @@ -982,10 +1228,9 @@ private: embargo.fulfiller = kj::mv(paf.fulfiller); // Make a promise which resolves to `replacement` as soon as the `Disembargo` comes back. - auto embargoPromise = paf.promise.then( - kj::mvCapture(replacement, [](kj::Own&& replacement) { - return kj::mv(replacement); - })); + auto embargoPromise = paf.promise.then([replacement = kj::mv(replacement)]() mutable { + return kj::mv(replacement); + }); // We need to queue up calls in the meantime, so we'll resolve ourselves to a local promise // client instead. @@ -995,61 +1240,14 @@ private: message->send(); } - cap = kj::mv(replacement); - isResolved = true; - } - }; - - class NoInterceptClient final: public RpcClient { - // A wrapper around an RpcClient which bypasses special handling of "save" requests. When we - // intercept a "save" request and invoke a RealmGateway, we give it a version of the capability - // with intercepting disabled, since usually the first thing the RealmGateway will do is turn - // around and call save() again. - // - // This is admittedly sort of backwards: the interception of "save" ought to be the part - // implemented by a wrapper. However, that would require placing a wrapper around every - // RpcClient we create whereas NoInterceptClient only needs to be injected after a save() - // request occurs and is intercepted. + cap = replacement->addRef(); - public: - NoInterceptClient(RpcClient& inner) - : RpcClient(*inner.connectionState), - inner(kj::addRef(inner)) {} - - kj::Maybe writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { - return inner->writeDescriptor(descriptor); + return kj::mv(replacement); } - - kj::Maybe> writeTarget(rpc::MessageTarget::Builder target) override { - return inner->writeTarget(target); - } - - kj::Own getInnermostClient() override { - return inner->getInnermostClient(); - } - - Request newCall( - uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint) override { - return inner->newCallNoIntercept(interfaceId, methodId, sizeHint); - } - VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, - kj::Own&& context) override { - return inner->callNoIntercept(interfaceId, methodId, kj::mv(context)); - } - - kj::Maybe getResolved() override { - return nullptr; - } - - kj::Maybe>> whenMoreResolved() override { - return nullptr; - } - - private: - kj::Own inner; }; - kj::Maybe writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor) { + kj::Maybe writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor, + kj::Vector& fds) { // Write a descriptor for the given capability. // Find the innermost wrapped capability. @@ -1062,15 +1260,24 @@ private: } } + KJ_IF_MAYBE(fd, inner->getFd()) { + descriptor.setAttachedFd(fds.size()); + fds.add(kj::mv(*fd)); + } + if (inner->getBrand() == this) { - return kj::downcast(*inner).writeDescriptor(descriptor); + return kj::downcast(*inner).writeDescriptor(descriptor, fds); } else { auto iter = exportsByCap.find(inner); if (iter != exportsByCap.end()) { // We've already seen and exported this capability before. Just up the refcount. auto& exp = KJ_ASSERT_NONNULL(exports.find(iter->second)); ++exp.refcount; - descriptor.setSenderHosted(iter->second); + if (exp.resolveOp == nullptr) { + descriptor.setSenderHosted(iter->second); + } else { + descriptor.setSenderPromise(iter->second); + } return iter->second; } else { // This is the first time we've seen this capability. @@ -1094,12 +1301,17 @@ private: } kj::Array writeDescriptors(kj::ArrayPtr>> capTable, - rpc::Payload::Builder payload) { + rpc::Payload::Builder payload, kj::Vector& fds) { + if (capTable.size() == 0) { + // Calling initCapTable(0) will still allocate a 1-word tag, which we'd like to avoid... + return nullptr; + } + auto capTableBuilder = payload.initCapTable(capTable.size()); kj::Vector exports(capTable.size()); for (uint i: kj::indices(capTable)) { KJ_IF_MAYBE(cap, capTable[i]) { - KJ_IF_MAYBE(exportId, writeDescriptor(**cap, capTableBuilder[i])) { + KJ_IF_MAYBE(exportId, writeDescriptor(**cap, capTableBuilder[i], fds)) { exports.add(*exportId); } } else { @@ -1199,7 +1411,9 @@ private: messageSizeHint() + sizeInWords() + 16); auto resolve = message->getBody().initAs().initResolve(); resolve.setPromiseId(exportId); - writeDescriptor(*exp.clientHook, resolve.initCap()); + kj::Vector fds; + writeDescriptor(*exp.clientHook, resolve.initCap(), fds); + message->setFds(fds.releaseAsArray()); message->send(); return kj::READY_NOW; @@ -1217,10 +1431,14 @@ private: }); } + void fromException(const kj::Exception& exception, rpc::Exception::Builder builder) { + _::fromException(exception, builder, traceEncoder); + } + // ===================================================================================== // Interpreting CapDescriptor - kj::Own import(ImportId importId, bool isPromise) { + kj::Own import(ImportId importId, bool isPromise, kj::Maybe fd) { // Receive a new import. auto& import = imports[importId]; @@ -1229,8 +1447,17 @@ private: // Create the ImportClient, or if one already exists, use it. KJ_IF_MAYBE(c, import.importClient) { importClient = kj::addRef(*c); + + // If the same import is introduced multiple times, and it is missing an FD the first time, + // but it has one on a later attempt, we want to attach the later one. This could happen + // because the first introduction was part of a message that had too many other FDs and went + // over the per-message limit. Perhaps the protocol design is such that this other message + // doesn't really care if the FDs are transferred or not, but the later message really does + // care; it would be bad if the previous message blocked later messages from delivering the + // FD just because it happened to reference the same capability. + importClient->setFdIfMissing(kj::mv(fd)); } else { - importClient = kj::refcounted(*this, importId); + importClient = kj::refcounted(*this, importId, kj::mv(fd)); import.importClient = *importClient; } @@ -1262,19 +1489,115 @@ private: } } - kj::Maybe> receiveCap(rpc::CapDescriptor::Reader descriptor) { + class TribbleRaceBlocker: public ClientHook, public kj::Refcounted { + // Hack to work around a problem that arises during the Tribble 4-way Race Condition as + // described in rpc.capnp in the documentation for the `Disembargo` message. + // + // Consider a remote promise that is resolved by a `Resolve` message. PromiseClient::resolve() + // is eventually called and given the `ClientHook` for the resolution. Imagine that the + // `ClientHook` it receives turns out to be an `ImportClient`. There are two ways this could + // have happened: + // + // 1. The `Resolve` message contained a `CapDescriptor` of type `senderHosted`, naming an entry + // in the sender's export table, and the `ImportClient` refers to the corresponding slot on + // the receiver's import table. In this case, no embargo is needed, because messages to the + // resolved location traverse the same path as messages to the promise would have. + // + // 2. The `Resolve` message contained a `CapDescriptor` of type `receiverHosted`, naming an + // entry in the receiver's export table. That entry just happened to contain an + // `ImportClient` referring back to the sender. This specifically happens when the entry + // in question had previously itself referred to a promise, and that promise has since + // resolved to a remote capability, at which point the export table entry was replaced by + // the appropriate `ImportClient` representing that. Presumably, the peer *did not yet know* + // about this resolution, which is why it sent a `receiverHosted` pointing to something that + // reflects back to the sender, rather than sending `senderHosted` in the first place. + // + // In this case, an embargo *is* required, because peer may still be reflecting messages + // sent to this promise back to us. In fact, the peer *must* continue reflecting messages, + // even when it eventually learns that the eventual destination is one of its own + // capabilities, due to the Tribble 4-way Race Condition rule. + // + // Since this case requires an embargo, somehow PromiseClient::resolve() must be able to + // distinguish it from the case (1). One solution would be for us to pass some extra flag + // all the way from where the `Resolve` messages is received to `PromiseClient::resolve()`. + // That solution is reasonably easy in the `Resolve` case, but gets notably more difficult + // in the case of `Return`s, which also resolve promises and are subject to all the same + // problems. In the case of a `Return`, some non-RPC-specific code is involved in the + // resolution, making it harder to pass along a flag. + // + // Instead, we use this hack: When we read an entry in the export table and discover that + // it actually contains an `ImportClient` or a `PipelineClient` reflecting back over our + // own connection, then we wrap it in a `TribbleRaceBlocker`. This wrapper prevents + // `PromiseClient` from recognizing the capability as being remote, so it instead treats it + // as local. That causes it to set up an embargo as desired. + // + // TODO(perf): This actually blocks further promise resolution in the case where the + // ImportClient or PipelineClient itself ends up being yet another promise that resolves + // back over the connection again. What we probably really need to do here is, instead of + // placing `ImportClient` or `PipelineClient` on the export table, place a special type there + // that both knows what to do with future incoming messages to that export ID, but also knows + // what to do when that export is the subject of a `Resolve`. + + public: + TribbleRaceBlocker(kj::Own inner): inner(kj::mv(inner)) {} + + Request newCall( + uint64_t interfaceId, uint16_t methodId, kj::Maybe sizeHint, + CallHints hints) override { + return inner->newCall(interfaceId, methodId, sizeHint, hints); + } + VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, + kj::Own&& context, CallHints hints) override { + return inner->call(interfaceId, methodId, kj::mv(context), hints); + } + kj::Maybe getResolved() override { + // We always wrap either PipelineClient or ImportClient, both of which return null for this + // anyway. + return nullptr; + } + kj::Maybe>> whenMoreResolved() override { + // We always wrap either PipelineClient or ImportClient, both of which return null for this + // anyway. + return nullptr; + } + kj::Own addRef() override { + return kj::addRef(*this); + } + const void* getBrand() override { + return nullptr; + } + kj::Maybe getFd() override { + return inner->getFd(); + } + + private: + kj::Own inner; + }; + + kj::Maybe> receiveCap(rpc::CapDescriptor::Reader descriptor, + kj::ArrayPtr fds) { + uint fdIndex = descriptor.getAttachedFd(); + kj::Maybe fd; + if (fdIndex < fds.size() && fds[fdIndex] != nullptr) { + fd = kj::mv(fds[fdIndex]); + } + switch (descriptor.which()) { case rpc::CapDescriptor::NONE: return nullptr; case rpc::CapDescriptor::SENDER_HOSTED: - return import(descriptor.getSenderHosted(), false); + return import(descriptor.getSenderHosted(), false, kj::mv(fd)); case rpc::CapDescriptor::SENDER_PROMISE: - return import(descriptor.getSenderPromise(), true); + return import(descriptor.getSenderPromise(), true, kj::mv(fd)); case rpc::CapDescriptor::RECEIVER_HOSTED: KJ_IF_MAYBE(exp, exports.find(descriptor.getReceiverHosted())) { - return exp->clientHook->addRef(); + auto result = exp->clientHook->addRef(); + if (result->getBrand() == this) { + result = kj::refcounted(kj::mv(result)); + } + return kj::mv(result); } else { return newBrokenCap("invalid 'receiverHosted' export ID"); } @@ -1286,7 +1609,11 @@ private: if (answer->active) { KJ_IF_MAYBE(pipeline, answer->pipeline) { KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) { - return pipeline->get()->getPipelinedCap(*ops); + auto result = pipeline->get()->getPipelinedCap(*ops); + if (result->getBrand() == this) { + result = kj::refcounted(kj::mv(result)); + } + return kj::mv(result); } else { return newBrokenCap("unrecognized pipeline ops"); } @@ -1299,7 +1626,7 @@ private: case rpc::CapDescriptor::THIRD_PARTY_HOSTED: // We don't support third-party caps, so use the vine instead. - return import(descriptor.getThirdPartyHosted().getVineId(), false); + return import(descriptor.getThirdPartyHosted().getVineId(), false, kj::mv(fd)); default: KJ_FAIL_REQUIRE("unknown CapDescriptor type") { break; } @@ -1307,10 +1634,11 @@ private: } } - kj::Array>> receiveCaps(List::Reader capTable) { + kj::Array>> receiveCaps(List::Reader capTable, + kj::ArrayPtr fds) { auto result = kj::heapArrayBuilder>>(capTable.size()); for (auto cap: capTable) { - result.add(receiveCap(cap)); + result.add(receiveCap(cap, fds)); } return result.finish(); } @@ -1325,26 +1653,40 @@ private: public: inline QuestionRef( RpcConnectionState& connectionState, QuestionId id, - kj::Own>>> fulfiller) + kj::Maybe>>>> fulfiller) : connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)) {} - ~QuestionRef() { - unwindDetector.catchExceptionsIfUnwinding([&]() { + ~QuestionRef() noexcept { + // Contrary to KJ style, we declare this destructor `noexcept` because if anything in here + // throws (without being caught) we're probably in pretty bad shape and going to be crashing + // later anyway. Better to abort now. + + KJ_IF_MAYBE(c, connectionState) { + auto& connectionState = *c; + auto& question = KJ_ASSERT_NONNULL( connectionState->questions.find(id), "Question ID no longer on table?"); // Send the "Finish" message (if the connection is not already broken). if (connectionState->connection.is() && !question.skipFinish) { - auto message = connectionState->connection.get()->newOutgoingMessage( - messageSizeHint()); - auto builder = message->getBody().getAs().initFinish(); - builder.setQuestionId(id); - // If we're still awaiting a return, then this request is being canceled, and we're going - // to ignore any capabilities in the return message, so set releaseResultCaps true. If we - // already received the return, then we've already built local proxies for the caps and - // will send Release messages when those are destroyed. - builder.setReleaseResultCaps(question.isAwaitingReturn); - message->send(); + KJ_IF_MAYBE(e, kj::runCatchingExceptions([&]() { + auto message = connectionState->connection.get()->newOutgoingMessage( + messageSizeHint()); + auto builder = message->getBody().getAs().initFinish(); + builder.setQuestionId(id); + // If we're still awaiting a return, then this request is being canceled, and we're going + // to ignore any capabilities in the return message, so set releaseResultCaps true. If we + // already received the return, then we've already built local proxies for the caps and + // will send Release messages when those are destroyed. + builder.setReleaseResultCaps(question.isAwaitingReturn); + + // Let the peer know we don't have the early cancellation bug. + builder.setRequireEarlyCancellationWorkaround(false); + + message->send(); + })) { + connectionState->tasks.add(kj::mv(*e)); + } } // Check if the question has returned and, if so, remove it from the table. @@ -1357,28 +1699,37 @@ private: // Call has already returned, so we can now remove it from the table. connectionState->questions.erase(id, question); } - }); + } } inline QuestionId getId() const { return id; } void fulfill(kj::Own&& response) { - fulfiller->fulfill(kj::mv(response)); + KJ_IF_MAYBE(f, fulfiller) { + f->get()->fulfill(kj::mv(response)); + } } void fulfill(kj::Promise>&& promise) { - fulfiller->fulfill(kj::mv(promise)); + KJ_IF_MAYBE(f, fulfiller) { + f->get()->fulfill(kj::mv(promise)); + } } void reject(kj::Exception&& exception) { - fulfiller->reject(kj::mv(exception)); + KJ_IF_MAYBE(f, fulfiller) { + f->get()->reject(kj::mv(exception)); + } + } + + void disconnect() { + connectionState = nullptr; } private: - kj::Own connectionState; + kj::Maybe> connectionState; QuestionId id; - kj::Own>>> fulfiller; - kj::UnwindDetector unwindDetector; + kj::Maybe>>>> fulfiller; }; class RpcRequest final: public RequestHook { @@ -1403,6 +1754,7 @@ private: RemotePromise send() override { if (!connectionState->connection.is()) { // Connection is broken. + // TODO(bug): Seems like we should check for redirect before this? const kj::Exception& e = connectionState->connection.get(); return RemotePromise( kj::Promise>(kj::cp(e)), @@ -1414,19 +1766,29 @@ private: // We'll have to make a new request and do a copy. Ick. auto replacement = redirect->get()->newCall( - callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize()); + callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize(), + callHintsFromReader(callBuilder)); replacement.set(paramsBuilder); return replacement.send(); } else { + bool noPromisePipelining = callBuilder.getNoPromisePipelining(); + auto sendResult = sendInternal(false); - auto forkedPromise = sendResult.promise.fork(); + kj::Own pipeline; + if (noPromisePipelining) { + pipeline = getDisabledPipeline(); + } else { + auto forkedPromise = sendResult.promise.fork(); + + // The pipeline must get notified of resolution before the app does to maintain ordering. + pipeline = kj::refcounted( + *connectionState, kj::mv(sendResult.questionRef), forkedPromise.addBranch()); - // The pipeline must get notified of resolution before the app does to maintain ordering. - auto pipeline = kj::refcounted( - *connectionState, kj::mv(sendResult.questionRef), forkedPromise.addBranch()); + sendResult.promise = forkedPromise.addBranch(); + } - auto appPromise = forkedPromise.addBranch().then( + auto appPromise = sendResult.promise.then( [=](kj::Own&& response) { auto reader = response->getResults(); return Response(reader, kj::mv(response)); @@ -1438,6 +1800,55 @@ private: } } + kj::Promise sendStreaming() override { + if (!connectionState->connection.is()) { + // Connection is broken. + // TODO(bug): Seems like we should check for redirect before this? + return kj::cp(connectionState->connection.get()); + } + + KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) { + // Whoops, this capability has been redirected while we were building the request! + // We'll have to make a new request and do a copy. Ick. + + auto replacement = redirect->get()->newCall( + callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize(), + callHintsFromReader(callBuilder)); + replacement.set(paramsBuilder); + return RequestHook::from(kj::mv(replacement))->sendStreaming(); + } else { + return sendStreamingInternal(false); + } + } + + AnyPointer::Pipeline sendForPipeline() override { + if (!connectionState->connection.is()) { + // Connection is broken. + // TODO(bug): Seems like we should check for redirect before this? + const kj::Exception& e = connectionState->connection.get(); + return AnyPointer::Pipeline(newBrokenPipeline(kj::cp(e))); + } + + KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) { + // Whoops, this capability has been redirected while we were building the request! + // We'll have to make a new request and do a copy. Ick. + + auto replacement = redirect->get()->newCall( + callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize(), + callHintsFromReader(callBuilder)); + replacement.set(paramsBuilder); + return replacement.sendForPipeline(); + } else if (connectionState->gotReturnForHighQuestionId) { + // Peer doesn't implement our hints. Fall back to a regular send(). + return send(); + } else { + auto questionRef = sendForPipelineInternal(); + kj::Own pipeline = kj::refcounted( + *connectionState, kj::mv(questionRef)); + return AnyPointer::Pipeline(kj::mv(pipeline)); + } + } + struct TailInfo { QuestionId questionId; kj::Promise promise; @@ -1472,7 +1883,13 @@ private: QuestionId questionId = sendResult.questionRef->getId(); - auto pipeline = kj::refcounted(*connectionState, kj::mv(sendResult.questionRef)); + kj::Own pipeline; + bool noPromisePipelining = callBuilder.getNoPromisePipelining(); + if (noPromisePipelining) { + pipeline = getDisabledPipeline(); + } else { + pipeline = kj::refcounted(*connectionState, kj::mv(sendResult.questionRef)); + } return TailInfo { questionId, kj::mv(promise), kj::mv(pipeline) }; } @@ -1495,10 +1912,21 @@ private: kj::Promise> promise = nullptr; }; - SendInternalResult sendInternal(bool isTailCall) { + struct SetupSendResult: public SendInternalResult { + QuestionId questionId; + Question& question; + + SetupSendResult(SendInternalResult&& super, QuestionId questionId, Question& question) + : SendInternalResult(kj::mv(super)), questionId(questionId), question(question) {} + // TODO(cleanup): This constructor is implicit in C++17. + }; + + SetupSendResult setupSend(bool isTailCall) { // Build the cap table. + kj::Vector fds; auto exports = connectionState->writeDescriptors( - capTable.getTable(), callBuilder.getParams()); + capTable.getTable(), callBuilder.getParams(), fds); + message->setFds(fds.releaseAsArray()); // Init the question table. Do this after writing descriptors to avoid interference. QuestionId questionId; @@ -1515,8 +1943,14 @@ private: question.selfRef = *result.questionRef; result.promise = paf.promise.attach(kj::addRef(*result.questionRef)); + return { kj::mv(result), questionId, question }; + } + + SendInternalResult sendInternal(bool isTailCall) { + auto result = setupSend(isTailCall); + // Finish and send. - callBuilder.setQuestionId(questionId); + callBuilder.setQuestionId(result.questionId); if (isTailCall) { callBuilder.getSendResultsTo().setYourself(); } @@ -1527,14 +1961,91 @@ private: })) { // We can't safely throw the exception from here since we've already modified the question // table state. We'll have to reject the promise instead. - question.isAwaitingReturn = false; - question.skipFinish = true; + // TODO(bug): Attempts to use the pipeline will end up sending a request referencing a + // bogus question ID. Can we rethrow after doing the appropriate cleanup, so the pipeline + // is never created? See the approach in sendForPipelineInternal() below. + result.question.isAwaitingReturn = false; + result.question.skipFinish = true; + connectionState->releaseExports(result.question.paramExports); result.questionRef->reject(kj::mv(*exception)); } // Send and return. return kj::mv(result); } + + kj::Promise sendStreamingInternal(bool isTailCall) { + auto setup = setupSend(isTailCall); + + // Finish and send. + callBuilder.setQuestionId(setup.questionId); + if (isTailCall) { + callBuilder.getSendResultsTo().setYourself(); + } + kj::Promise flowPromise = nullptr; + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + KJ_CONTEXT("sending RPC call", + callBuilder.getInterfaceId(), callBuilder.getMethodId()); + RpcFlowController* flow; + KJ_IF_MAYBE(f, target->flowController) { + flow = *f; + } else { + flow = target->flowController.emplace( + connectionState->connection.get()->newStream()); + } + flowPromise = flow->send(kj::mv(message), setup.promise.ignoreResult()); + })) { + // We can't safely throw the exception from here since we've already modified the question + // table state. We'll have to reject the promise instead. + setup.question.isAwaitingReturn = false; + setup.question.skipFinish = true; + setup.questionRef->reject(kj::cp(*exception)); + return kj::mv(*exception); + } + + return kj::mv(flowPromise); + } + + kj::Own sendForPipelineInternal() { + // Since must of setupSend() is subtly different for this case, we don't reuse it. + + // Build the cap table. + kj::Vector fds; + auto exports = connectionState->writeDescriptors( + capTable.getTable(), callBuilder.getParams(), fds); + message->setFds(fds.releaseAsArray()); + + if (exports.size() > 0) { + connectionState->sentCapabilitiesInPipelineOnlyCall = true; + } + + // Init the question table. Do this after writing descriptors to avoid interference. + QuestionId questionId; + auto& question = connectionState->questions.nextHigh(questionId); + question.isAwaitingReturn = false; // No Return needed + question.paramExports = kj::mv(exports); + question.isTailCall = false; + + // Make the QuentionRef and result promise. + auto questionRef = kj::refcounted(*connectionState, questionId, nullptr); + question.selfRef = *questionRef; + + // If sending throws, we'll need to fix up the state a little... + KJ_ON_SCOPE_FAILURE({ + question.skipFinish = true; + connectionState->releaseExports(question.paramExports); + }); + + // Finish and send. + callBuilder.setQuestionId(questionId); + callBuilder.setOnlyPromisePipeline(true); + + KJ_CONTEXT("sending RPC call", + callBuilder.getInterfaceId(), callBuilder.getMethodId()); + message->send(); + + return kj::mv(questionRef); + } }; class RpcPipeline final: public PipelineHook, public kj::Refcounted { @@ -1581,28 +2092,40 @@ private: } kj::Own getPipelinedCap(kj::Array&& ops) override { - if (state.is()) { - // Wrap a PipelineClient in a PromiseClient. - auto pipelineClient = kj::refcounted( - *connectionState, kj::addRef(*state.get()), kj::heapArray(ops.asPtr())); - - KJ_IF_MAYBE(r, redirectLater) { - auto resolutionPromise = r->addBranch().then(kj::mvCapture(ops, - [](kj::Array ops, kj::Own&& response) { - return response->getResults().getPipelinedCap(ops); - })); - - return kj::refcounted( - *connectionState, kj::mv(pipelineClient), kj::mv(resolutionPromise), nullptr); + return clientMap.findOrCreate(ops.asPtr(), [&]() { + if (state.is()) { + // Wrap a PipelineClient in a PromiseClient. + auto pipelineClient = kj::refcounted( + *connectionState, kj::addRef(*state.get()), kj::heapArray(ops.asPtr())); + + KJ_IF_MAYBE(r, redirectLater) { + auto resolutionPromise = r->addBranch().then( + [ops = kj::heapArray(ops.asPtr())](kj::Own&& response) { + return response->getResults().getPipelinedCap(kj::mv(ops)); + }); + + return kj::HashMap, kj::Own>::Entry { + kj::mv(ops), + kj::refcounted( + *connectionState, kj::mv(pipelineClient), kj::mv(resolutionPromise), nullptr) + }; + } else { + // Oh, this pipeline will never get redirected, so just return the PipelineClient. + return kj::HashMap, kj::Own>::Entry { + kj::mv(ops), kj::mv(pipelineClient) + }; + } + } else if (state.is()) { + auto pipelineClient = state.get()->getResults().getPipelinedCap(ops); + return kj::HashMap, kj::Own>::Entry { + kj::mv(ops), kj::mv(pipelineClient) + }; } else { - // Oh, this pipeline will never get redirected, so just return the PipelineClient. - return kj::mv(pipelineClient); + return kj::HashMap, kj::Own>::Entry { + kj::mv(ops), newBrokenCap(kj::cp(state.get())) + }; } - } else if (state.is()) { - return state.get()->getResults().getPipelinedCap(ops); - } else { - return newBrokenCap(kj::cp(state.get())); - } + })->addRef(); } private: @@ -1614,6 +2137,12 @@ private: typedef kj::Exception Broken; kj::OneOf state; + kj::HashMap, kj::Own> clientMap; + // See QueuedPipeline::clientMap in capability.c++ for a discussion of why we must memoize + // the results of getPipelinedCap(). RpcPipeline has a similar problem when a capability we + // return is later subject to an embargo. It's important that the embargo is correctly applied + // across all calls to the same capability. + // Keep this last, because the continuation uses *this, so it should be destroyed first to // ensure the continuation is not still running. kj::Promise resolveSelfPromise; @@ -1685,22 +2214,30 @@ private: return capTable.imbue(payload.getContent()); } + inline bool hasCapabilities() { + return capTable.getTable().size() > 0; + } + kj::Maybe> send() { // Send the response and return the export list. Returns nullptr if there were no caps. // (Could return a non-null empty array if there were caps but none of them were exports.) // Build the cap table. auto capTable = this->capTable.getTable(); - auto exports = connectionState.writeDescriptors(capTable, payload); + kj::Vector fds; + auto exports = connectionState.writeDescriptors(capTable, payload, fds); + message->setFds(fds.releaseAsArray()); - // Capabilities that we are returning are subject to embargos. See `Disembargo` in rpc.capnp. - // As explained there, in order to deal with the Tribble 4-way race condition, we need to - // make sure that if we're returning any remote promises, that we ignore any subsequent - // resolution of those promises for the purpose of pipelined requests on this answer. Luckily, - // we can modify the cap table in-place. + // Populate `resolutionsAtReturnTime`. for (auto& slot: capTable) { KJ_IF_MAYBE(cap, slot) { - slot = connectionState.getInnermostClient(**cap); + auto inner = connectionState.getInnermostClient(**cap); + if (inner.get() != cap->get()) { + resolutionsAtReturnTime.upsert(cap->get(), kj::mv(inner), + [&](kj::Own& existing, kj::Own&& replacement) { + KJ_ASSERT(existing.get() == replacement.get()); + }); + } } } @@ -1712,11 +2249,40 @@ private: } } + struct Resolution { + kj::Own returnedCap; + // The capabiilty that appeared in the response message in this slot. + + kj::Own unwrapped; + // Exactly what `getInnermostClient(returnedCap)` produced at the time that the return + // message was encoded. + }; + + Resolution getResolutionAtReturnTime(kj::ArrayPtr ops) { + auto returnedCap = getResultsBuilder().asReader().getPipelinedCap(ops); + kj::Own unwrapped; + KJ_IF_MAYBE(u, resolutionsAtReturnTime.find(returnedCap.get())) { + unwrapped = u->get()->addRef(); + } else { + unwrapped = returnedCap->addRef(); + } + return { kj::mv(returnedCap), kj::mv(unwrapped) }; + } + private: RpcConnectionState& connectionState; kj::Own message; BuilderCapabilityTable capTable; rpc::Payload::Builder payload; + + kj::HashMap> resolutionsAtReturnTime; + // For each capability in `capTable` as of the time when the call returned, this map stores + // the result of calling `getInnermostClient()` on that capability. This is needed in order + // to solve the Tribble 4-way race condition described in the documentation for `Disembargo` + // in `rpc.capnp`. `PostReturnRpcPipeline`, below, uses this. + // + // As an optimization, if the innermost client is exactly the same object then nothing is + // stored in the map. }; class LocallyRedirectedRpcResponse final @@ -1742,25 +2308,95 @@ private: MallocMessageBuilder message; }; + class PostReturnRpcPipeline final: public PipelineHook, public kj::Refcounted { + // Once an incoming call has returned, we may need to replace the `PipelineHook` with one that + // correctly handles the Tribble 4-way race condition. Namely, we must ensure that if the + // response contained any capabilities pointing back out to the network, then any further + // pipelined calls received targetting those capabilities (as well as any Disembargo messages) + // will resolve to the same network capability forever, *even if* that network capability is + // itself a promise which later resolves to somewhere else. + public: + PostReturnRpcPipeline(kj::Own inner, + RpcServerResponseImpl& response, + kj::Own context) + : inner(kj::mv(inner)), response(response), context(kj::mv(context)) {} + + kj::Own addRef() override { + return kj::addRef(*this); + } + + kj::Own getPipelinedCap(kj::ArrayPtr ops) override { + auto resolved = response.getResolutionAtReturnTime(ops); + auto original = inner->getPipelinedCap(ops); + return getResolutionAtReturnTime(kj::mv(original), kj::mv(resolved)); + } + + kj::Own getPipelinedCap(kj::Array&& ops) override { + auto resolved = response.getResolutionAtReturnTime(ops); + auto original = inner->getPipelinedCap(kj::mv(ops)); + return getResolutionAtReturnTime(kj::mv(original), kj::mv(resolved)); + } + + private: + kj::Own inner; + RpcServerResponseImpl& response; + kj::Own context; // owns `response` + + kj::Own getResolutionAtReturnTime( + kj::Own original, RpcServerResponseImpl::Resolution resolution) { + // Wait for `original` to resolve to `resolution.returnedCap`, then return + // `resolution.unwrapped`. + + ClientHook* ptr = original.get(); + for (;;) { + if (ptr == resolution.returnedCap.get()) { + return kj::mv(resolution.unwrapped); + } else KJ_IF_MAYBE(r, ptr->getResolved()) { + ptr = r; + } else { + break; + } + } + + KJ_IF_MAYBE(p, ptr->whenMoreResolved()) { + return newLocalPromiseClient(p->then( + [this, original = kj::mv(original), resolution = kj::mv(resolution)] + (kj::Own r) mutable { + return getResolutionAtReturnTime(kj::mv(r), kj::mv(resolution)); + })); + } else if (ptr->isError() || ptr->isNull()) { + // This is already a broken capability, the error probably explains what went wrong. In + // any case, message ordering is irrelevant here since all calls will throw anyway. + return ptr->addRef(); + } else { + return newBrokenCap( + "An RPC call's capnp::PipelineHook object resolved a pipelined capability to a " + "different final object than what was returned in the actual response. This could " + "be a bug in Cap'n Proto, or could be due to a use of context.setPipeline() that " + "was inconsistent with the later results."); + } + } + }; + class RpcCallContext final: public CallContextHook, public kj::Refcounted { public: RpcCallContext(RpcConnectionState& connectionState, AnswerId answerId, kj::Own&& request, kj::Array>> capTableArray, const AnyPointer::Reader& params, - bool redirectResults, kj::Own>&& cancelFulfiller, - uint64_t interfaceId, uint16_t methodId) + bool redirectResults, uint64_t interfaceId, uint16_t methodId, + ClientHook::CallHints hints) : connectionState(kj::addRef(connectionState)), answerId(answerId), + hints(hints), interfaceId(interfaceId), methodId(methodId), - requestSize(request->getBody().targetSize().wordCount), + requestSize(request->sizeInWords()), request(kj::mv(request)), paramsCapTable(kj::mv(capTableArray)), params(paramsCapTable.imbue(params)), returnMessage(nullptr), - redirectResults(redirectResults), - cancelFulfiller(kj::mv(cancelFulfiller)) { + redirectResults(redirectResults) { connectionState.callWordsInFlight += requestSize; } @@ -1768,8 +2404,10 @@ private: if (isFirstResponder()) { // We haven't sent a return yet, so we must have been canceled. Send a cancellation return. unwindDetector.catchExceptionsIfUnwinding([&]() { - // Don't send anything if the connection is broken. - if (connectionState->connection.is()) { + // Don't send anything if the connection is broken, or if the onlyPromisePipeline hint + // was used (in which case the caller doesn't care to receive a `Return`). + bool shouldFreePipeline = true; + if (connectionState->connection.is() && !hints.onlyPromisePipeline) { auto message = connectionState->connection.get()->newOutgoingMessage( messageSizeHint() + sizeInWords()); auto builder = message->getBody().initAs().initReturn(); @@ -1781,6 +2419,9 @@ private: // The reason we haven't sent a return is because the results were sent somewhere // else. builder.setResultsSentElsewhere(); + + // The pipeline could still be valid and in-use in this case. + shouldFreePipeline = false; } else { builder.setCanceled(); } @@ -1788,7 +2429,7 @@ private: message->send(); } - cleanupAnswerTable(nullptr, true); + cleanupAnswerTable(nullptr, shouldFreePipeline); }); } } @@ -1805,10 +2446,11 @@ private: void sendReturn() { KJ_ASSERT(!redirectResults); + KJ_ASSERT(!hints.onlyPromisePipeline); // Avoid sending results if canceled so that we don't have to figure out whether or not // `releaseResultCaps` was set in the already-received `Finish`. - if (!(cancellationFlags & CANCEL_REQUESTED) && isFirstResponder()) { + if (!receivedFinish && isFirstResponder()) { KJ_ASSERT(connectionState->connection.is(), "Cancellation should have been requested on disconnect.") { return; @@ -1819,17 +2461,43 @@ private: returnMessage.setAnswerId(answerId); returnMessage.setReleaseParamCaps(false); + auto& responseImpl = kj::downcast(*KJ_ASSERT_NONNULL(response)); + if (!responseImpl.hasCapabilities()) { + returnMessage.setNoFinishNeeded(true); + + // Tell ourselves that a finsih was already received, so that `cleanupAnswerTable()` + // removes the answer table entry. + receivedFinish = true; + + // HACK: The answer table's `task` is the thing which is calling `sendReturn()`. We can't + // cancel ourselves. However, we know calling `sendReturn()` is the last thing it does, + // so we can safely detach() it. + auto& answer = KJ_ASSERT_NONNULL(connectionState->answers.find(answerId)); + auto& selfPromise = KJ_ASSERT_NONNULL(answer.task.tryGet()); + selfPromise.detach([](kj::Exception&&) {}); + } + kj::Maybe> exports; KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { - // Debug info incase send() fails due to overside message. + // Debug info in case send() fails due to overside message. KJ_CONTEXT("returning from RPC call", interfaceId, methodId); - exports = kj::downcast(*KJ_ASSERT_NONNULL(response)).send(); + exports = responseImpl.send(); })) { responseSent = false; sendErrorReturn(kj::mv(*exception)); return; } + if (responseImpl.hasCapabilities()) { + auto& answer = KJ_ASSERT_NONNULL(connectionState->answers.find(answerId)); + // Swap out the `pipeline` in the answer table for one that will return capabilities + // consistent with whatever the result caps resolved to as of the time the return was sent. + answer.pipeline = answer.pipeline.map([&](kj::Own& inner) { + return kj::refcounted( + kj::mv(inner), responseImpl, kj::addRef(*this)); + }); + } + KJ_IF_MAYBE(e, exports) { // Caps were returned, so we can't free the pipeline yet. cleanupAnswerTable(kj::mv(*e), false); @@ -1841,6 +2509,7 @@ private: } void sendErrorReturn(kj::Exception&& exception) { KJ_ASSERT(!redirectResults); + KJ_ASSERT(!hints.onlyPromisePipeline); if (isFirstResponder()) { if (connectionState->connection.is()) { auto message = connectionState->connection.get()->newOutgoingMessage( @@ -1849,7 +2518,13 @@ private: builder.setAnswerId(answerId); builder.setReleaseParamCaps(false); - fromException(exception, builder.initException()); + connectionState->fromException(exception, builder.initException()); + + // Note that even though the response contains no capabilities, we don't want to set + // `noFinishNeeded` here because if any pipelined calls were made, we want them to + // fail with the correct exception. (Perhaps if the request had `noPromisePipelining`, + // then we could set `noFinishNeeded`, but optimizing the error case doesn't seem that + // important.) message->send(); } @@ -1859,24 +2534,35 @@ private: cleanupAnswerTable(nullptr, false); } } + void sendRedirectReturn() { + KJ_ASSERT(redirectResults); + KJ_ASSERT(!hints.onlyPromisePipeline); + + if (isFirstResponder()) { + auto message = connectionState->connection.get()->newOutgoingMessage( + messageSizeHint()); + auto builder = message->getBody().initAs().initReturn(); - void requestCancel() { - // Hints that the caller wishes to cancel this call. At the next time when cancellation is - // deemed safe, the RpcCallContext shall send a canceled Return -- or if it never becomes - // safe, the RpcCallContext will send a normal return when the call completes. Either way - // the RpcCallContext is now responsible for cleaning up the entry in the answer table, since - // a Finish message was already received. + builder.setAnswerId(answerId); + builder.setReleaseParamCaps(false); + builder.setResultsSentElsewhere(); - bool previouslyAllowedButNotRequested = cancellationFlags == CANCEL_ALLOWED; - cancellationFlags |= CANCEL_REQUESTED; + // TODO(perf): Could `noFinishNeeded` be used here? The `Finish` messages are pretty + // redundant after a redirect, but as this case is less common and more complicated I + // don't want to fully think through the implications right now. - if (previouslyAllowedButNotRequested) { - // We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate - // the cancellation. - cancelFulfiller->fulfill(); + message->send(); + + cleanupAnswerTable(nullptr, false); } } + void finish() { + // Called when a `Finish` message is received while this object still exists. + + receivedFinish = true; + } + // implements CallContextHook ------------------------------------ AnyPointer::Reader getParams() override { @@ -1908,6 +2594,11 @@ private: return results; } } + void setPipeline(kj::Own&& pipeline) override { + KJ_IF_MAYBE(f, tailCallPipelineFulfiller) { + f->get()->fulfill(AnyPointer::Pipeline(kj::mv(pipeline))); + } + } kj::Promise tailCall(kj::Own&& request) override { auto result = directTailCall(kj::mv(request)); KJ_IF_MAYBE(f, tailCallPipelineFulfiller) { @@ -1919,9 +2610,13 @@ private: KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct."); - if (request->getBrand() == connectionState.get() && !redirectResults) { + if (request->getBrand() == connectionState.get() && + !redirectResults && !hints.noPromisePipelining) { // The tail call is headed towards the peer that called us in the first place, so we can // optimize out the return trip. + // + // If the noPromisePipelining hint was sent, we skip this trick since the caller will + // ignore the `Return` message anyway. KJ_IF_MAYBE(tailInfo, kj::downcast(*request).tailSend()) { if (isFirstResponder()) { @@ -1946,6 +2641,14 @@ private: } // Just forwarding to another local call. + + if (hints.onlyPromisePipeline) { + return { + kj::NEVER_DONE, + PipelineHook::from(request->sendForPipeline()) + }; + } + auto promise = request->send(); // Wait for response. @@ -1963,16 +2666,6 @@ private: tailCallPipelineFulfiller = kj::mv(paf.fulfiller); return kj::mv(paf.promise); } - void allowCancellation() override { - bool previouslyRequestedButNotAllowed = cancellationFlags == CANCEL_REQUESTED; - cancellationFlags |= CANCEL_ALLOWED; - - if (previouslyRequestedButNotAllowed) { - // We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Initiate - // the cancellation. - cancelFulfiller->fulfill(); - } - } kj::Own addRef() override { return kj::addRef(*this); } @@ -1981,6 +2674,8 @@ private: kj::Own connectionState; AnswerId answerId; + ClientHook::CallHints hints; + uint64_t interfaceId; uint16_t methodId; // For debugging. @@ -2002,18 +2697,9 @@ private: // Cancellation state ---------------------------------- - enum CancellationFlags { - CANCEL_REQUESTED = 1, - CANCEL_ALLOWED = 2 - }; - - uint8_t cancellationFlags = 0; - // When both flags are set, the cancellation process will begin. - - kj::Own> cancelFulfiller; - // Fulfilled when cancellation has been both requested and permitted. The fulfilled promise is - // exclusive-joined with the outermost promise waiting on the call return, so fulfilling it - // cancels that promise. + bool receivedFinish = false; + // True if a `Finish` message has been recevied OR we sent a `Return` with `noFinishNedeed`. + // In either case, it is our responsibility to clean up the answer table. kj::UnwindDetector unwindDetector; @@ -2033,7 +2719,7 @@ private: // answer table. Or we might even be responsible for removing the entire answer table // entry. - if (cancellationFlags & CANCEL_REQUESTED) { + if (receivedFinish) { // Already received `Finish` so it's our job to erase the table entry. We shouldn't have // sent results if canceled, so we shouldn't have an export list to deal with. KJ_ASSERT(resultExports.size() == 0); @@ -2083,21 +2769,37 @@ private: }); } - return connection.get()->receiveIncomingMessage().then( + return canceler.wrap(connection.get()->receiveIncomingMessage()).then( [this](kj::Maybe>&& message) { KJ_IF_MAYBE(m, message) { handleMessage(kj::mv(*m)); return true; } else { - disconnect(KJ_EXCEPTION(DISCONNECTED, "Peer disconnected.")); + tasks.add(KJ_EXCEPTION(DISCONNECTED, "Peer disconnected.")); return false; } + }, [this](kj::Exception&& exception) { + receiveIncomingMessageError = true; + kj::throwRecoverableException(kj::mv(exception)); + return false; }).then([this](bool keepGoing) { // No exceptions; continue loop. // // (We do this in a separate continuation to handle the case where exceptions are // disabled.) - if (keepGoing) tasks.add(messageLoop()); + // + // TODO(perf): We add an evalLater() here so that anything we needed to do in reaction to + // the previous message has a chance to complete before the next message is handled. In + // particular, without this, I observed an ordering problem: I saw a case where a `Return` + // message was followed by a `Resolve` message, but the `PromiseClient` associated with the + // `Resolve` had its `resolve()` method invoked _before_ any `PromiseClient`s associated + // with pipelined capabilities resolved by the `Return`. This could lead to an + // incorrectly-ordered interaction between `PromiseClient`s when they resolve to each + // other. This is probably really a bug in the way `Return`s are handled -- apparently, + // resolution of `PromiseClient`s based on returned capabilities does not occur in a + // depth-first way, when it should. If we could fix that then we can probably remove this + // `evalLater()`. However, the `evalLater()` is not that bad and solves the problem... + if (keepGoing) tasks.add(kj::evalLater([this]() { return messageLoop(); })); }); } @@ -2130,7 +2832,7 @@ private: break; case rpc::Message::RESOLVE: - handleResolve(reader.getResolve()); + handleResolve(kj::mv(message), reader.getResolve()); break; case rpc::Message::RELEASE: @@ -2262,8 +2964,17 @@ private: auto capTableArray = capTable.getTable(); KJ_DASSERT(capTableArray.size() == 1); - resultExports = writeDescriptors(capTableArray, payload); - capHook = KJ_ASSERT_NONNULL(capTableArray[0])->addRef(); + kj::Vector fds; + resultExports = writeDescriptors(capTableArray, payload, fds); + response->setFds(fds.releaseAsArray()); + + // If we're returning a capability that turns out to be an PromiseClient pointing back on + // this same network, it's important we remove the `PromiseClient` layer and use the inner + // capability instead. This achieves the same effect that `PostReturnRpcPipeline` does for + // regular call returns. + // + // This single line of code represents two hours of my life. + capHook = getInnermostClient(*KJ_ASSERT_NONNULL(capTableArray[0])); })) { fromException(*exception, ret.initException()); capHook = newBrokenCap(kj::mv(*exception)); @@ -2307,15 +3018,19 @@ private: } auto payload = call.getParams(); - auto capTableArray = receiveCaps(payload.getCapTable()); - auto cancelPaf = kj::newPromiseAndFulfiller(); + auto capTableArray = receiveCaps(payload.getCapTable(), message->getAttachedFds()); AnswerId answerId = call.getQuestionId(); + auto hints = callHintsFromReader(call); + + // Don't honor onlyPromisePipeline if results are redirected, because this situation isn't + // useful in practice and would be complicated to handle "correctly". + if (redirectResults) hints.onlyPromisePipeline = false; + auto context = kj::refcounted( *this, answerId, kj::mv(message), kj::mv(capTableArray), payload.getContent(), - redirectResults, kj::mv(cancelPaf.fulfiller), - call.getInterfaceId(), call.getMethodId()); + redirectResults, call.getInterfaceId(), call.getMethodId(), hints); // No more using `call` after this point, as it now belongs to the context. @@ -2331,7 +3046,7 @@ private: } auto promiseAndPipeline = startCall( - call.getInterfaceId(), call.getMethodId(), kj::mv(capability), context->addRef()); + call.getInterfaceId(), call.getMethodId(), kj::mv(capability), context->addRef(), hints); // Things may have changed -- in particular if startCall() immediately called // context->directTailCall(). @@ -2342,91 +3057,38 @@ private: answer.pipeline = kj::mv(promiseAndPipeline.pipeline); if (redirectResults) { - auto resultsPromise = promiseAndPipeline.promise.then( - kj::mvCapture(context, [](kj::Own&& context) { + answer.task = promiseAndPipeline.promise.then( + [context=kj::mv(context)]() mutable { return context->consumeRedirectedResponse(); - })); - - // If the call that later picks up `redirectedResults` decides to discard it, we need to - // make sure our call is not itself canceled unless it has called allowCancellation(). - // So we fork the promise and join one branch with the cancellation promise, in order to - // hold on to it. - auto forked = resultsPromise.fork(); - answer.redirectedResults = forked.addBranch(); - - cancelPaf.promise - .exclusiveJoin(forked.addBranch().then([](kj::Own&&){})) - .detach([](kj::Exception&&) {}); + }); + } else if (hints.onlyPromisePipeline) { + // The promise is probably fake anyway, so don't bother adding a .then(). We do, however, + // have to attach `context` to this, since we destroy `task` upon receiving a `Finish` + // message, and we want `RpcCallContext` to be destroyed no earlier than that. + answer.task = promiseAndPipeline.promise.attach(kj::mv(context)); } else { // Hack: Both the success and error continuations need to use the context. We could // refcount, but both will be destroyed at the same time anyway. - RpcCallContext* contextPtr = context; - - promiseAndPipeline.promise.then( - [contextPtr]() { - contextPtr->sendReturn(); - }, [contextPtr](kj::Exception&& exception) { - contextPtr->sendErrorReturn(kj::mv(exception)); - }).catch_([&](kj::Exception&& exception) { + RpcCallContext& contextRef = *context; + + answer.task = promiseAndPipeline.promise.then( + [context = kj::mv(context)]() mutable { + context->sendReturn(); + }, [&contextRef](kj::Exception&& exception) { + contextRef.sendErrorReturn(kj::mv(exception)); + }).eagerlyEvaluate([&](kj::Exception&& exception) { // Handle exceptions that occur in sendReturn()/sendErrorReturn(). taskFailed(kj::mv(exception)); - }).attach(kj::mv(context)) - .exclusiveJoin(kj::mv(cancelPaf.promise)) - .detach([](kj::Exception&&) {}); + }); } } } ClientHook::VoidPromiseAndPipeline startCall( uint64_t interfaceId, uint64_t methodId, - kj::Own&& capability, kj::Own&& context) { - if (interfaceId == typeId>() && methodId == 0) { - KJ_IF_MAYBE(g, gateway) { - // Wait, this is a call to Persistent.save() and we need to translate it through our - // gateway. - - KJ_IF_MAYBE(resolvedPromise, capability->whenMoreResolved()) { - // The plot thickens: We're looking at a promise capability. It could end up resolving - // to a capability outside the gateway, in which case we don't want to translate at all. - - auto promises = resolvedPromise->then(kj::mvCapture(context, - [this,interfaceId,methodId](kj::Own&& context, - kj::Own resolvedCap) { - auto vpap = startCall(interfaceId, methodId, kj::mv(resolvedCap), kj::mv(context)); - return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline)); - })).attach(addRef(*this)).split(); - - return { - kj::mv(kj::get<0>(promises)), - newLocalPromisePipeline(kj::mv(kj::get<1>(promises))), - }; - } - - if (capability->getBrand() == this) { - // This capability is one of our own, pointing back out over the network. That means - // that it would be inappropriate to apply the gateway transformation. We just want to - // reflect the call back. - return kj::downcast(*capability) - .callNoIntercept(interfaceId, methodId, kj::mv(context)); - } - - auto params = context->getParams().getAs::SaveParams>(); - - auto requestSize = params.totalSize(); - ++requestSize.capCount; - requestSize.wordCount += sizeInWords::ExportParams>(); - - auto request = g->exportRequest(requestSize); - request.setCap(Persistent<>::Client(capability->addRef())); - request.setParams(params); - - context->allowCancellation(); - context->releaseParams(); - return context->directTailCall(RequestHook::from(kj::mv(request))); - } - } - - return capability->call(interfaceId, methodId, kj::mv(context)); + kj::Own&& capability, kj::Own&& context, + ClientHook::CallHints hints) { + return capability->call(interfaceId, methodId, kj::mv(context), hints); } kj::Maybe> getMessageTarget(const rpc::MessageTarget::Reader& target) { @@ -2446,13 +3108,14 @@ private: auto promisedAnswer = target.getPromisedAnswer(); kj::Own pipeline; - auto& base = answers[promisedAnswer.getQuestionId()]; - KJ_REQUIRE(base.active, "PromisedAnswer.questionId is not a current question.") { - return nullptr; + KJ_IF_MAYBE(answer, answers.find(promisedAnswer.getQuestionId())) { + if (answer->active) { + KJ_IF_MAYBE(p, answer->pipeline) { + pipeline = p->get()->addRef(); + } + } } - KJ_IF_MAYBE(p, base.pipeline) { - pipeline = p->get()->addRef(); - } else { + if (pipeline.get() == nullptr) { pipeline = newBrokenPipeline(KJ_EXCEPTION(FAILED, "Pipeline call on a request that returned no capabilities or was already closed.")); } @@ -2479,9 +3142,41 @@ private: // pointer into it, so make sure these destructors run later. kj::Array exportsToRelease; KJ_DEFER(releaseExports(exportsToRelease)); - kj::Maybe>> promiseToRelease; + kj::Maybe promiseToRelease; + + QuestionId questionId = ret.getAnswerId(); + if (questions.isHigh(questionId)) { + // We sent hints with this question saying we didn't want a `Return` but we got one anyway. + // We cannot even look up the question on the question table because it's (remotely) possible + // that we already removed it and re-allocated the ID to something else. So, we should ignore + // the `Return`. But we might want to make note to stop using these hints, to protect against + // the (again, remote) possibility of our ID space wrapping around and leading to confusion. + if (ret.getReleaseParamCaps() && sentCapabilitiesInPipelineOnlyCall) { + // Oh no, it appears the peer wants us to release any capabilities in the params, something + // which only a level 0 peer would request (no version of the C++ RPC system has ever done + // this). And it appears we did send capabilities in at least one pipeline-only call + // previously. But we have no record of which capabilities were sent in *this* call, so + // we cannot release them. Log an error about the leak. + // + // This scenario is unlikely to happen in practice, because sendForPipeline() is not useful + // when talking to a peer that doesn't support capability-passing -- they couldn't possibly + // return a capability to pipeline on! So, I'm not going to spend time to find a solution + // for this corner case. We will log an error, though, just in case someone hits this + // somehow. + KJ_LOG(ERROR, + "sendForPipeline() was used when sending an RPC to a peer, the parameters of that " + "RPC included capabilities, but the peer seems to implement Cap'n Proto at level 0, " + "meaning it does not support capability passing (or, at least, it sent a `Return` " + "with `releaseParamCaps = true`). The capabilities that were sent may have been " + "leaked (they won't be dropped until the connection closes)."); + + sentCapabilitiesInPipelineOnlyCall = false; // don't log again + } + gotReturnForHighQuestionId = true; + return; + } - KJ_IF_MAYBE(question, questions.find(ret.getAnswerId())) { + KJ_IF_MAYBE(question, questions.find(questionId)) { KJ_REQUIRE(question->isAwaitingReturn, "Duplicate Return.") { return; } question->isAwaitingReturn = false; @@ -2491,6 +3186,10 @@ private: question->paramExports = nullptr; } + if (ret.getNoFinishNeeded()) { + question->skipFinish = true; + } + KJ_IF_MAYBE(questionRef, question->selfRef) { switch (ret.which()) { case rpc::Return::RESULTS: { @@ -2500,7 +3199,7 @@ private: } auto payload = ret.getResults(); - auto capTableArray = receiveCaps(payload.getCapTable()); + auto capTableArray = receiveCaps(payload.getCapTable(), message->getAttachedFds()); questionRef->fulfill(kj::refcounted( *this, kj::addRef(*questionRef), kj::mv(message), kj::mv(capTableArray), payload.getContent())); @@ -2532,8 +3231,15 @@ private: case rpc::Return::TAKE_FROM_OTHER_QUESTION: KJ_IF_MAYBE(answer, answers.find(ret.getTakeFromOtherQuestion())) { - KJ_IF_MAYBE(response, answer->redirectedResults) { + KJ_IF_MAYBE(response, answer->task.tryGet()) { questionRef->fulfill(kj::mv(*response)); + answer->task = Answer::Finished(); + + KJ_IF_MAYBE(context, answer->callContext) { + // Send the `Return` message for the call of which we're taking ownership, so + // that the peer knows it can now tear down the call state. + context->sendRedirectReturn(); + } } else { KJ_FAIL_REQUIRE("`Return.takeFromOtherQuestion` referenced a call that did not " "use `sendResultsTo.yourself`.") { return; } @@ -2548,10 +3254,23 @@ private: KJ_FAIL_REQUIRE("Unknown 'Return' type.") { return; } } } else { + // This is a response to a question that we canceled earlier. + if (ret.isTakeFromOtherQuestion()) { - // Be sure to release the tail call's promise. + // This turned out to be a tail call back to us! We now take ownership of the tail call. + // Since the caller canceled, we need to cancel out the tail call, if it still exists. + KJ_IF_MAYBE(answer, answers.find(ret.getTakeFromOtherQuestion())) { - promiseToRelease = kj::mv(answer->redirectedResults); + // Indeed, it does still exist. + + // Throw away the result promise. + promiseToRelease = kj::mv(answer->task); + + KJ_IF_MAYBE(context, answer->callContext) { + // Send the `Return` message for the call of which we're taking ownership, so + // that the peer knows it can now tear down the call state. + context->sendRedirectReturn(); + } } } @@ -2573,9 +3292,13 @@ private: KJ_DEFER(releaseExports(exportsToRelease)); Answer answerToRelease; kj::Maybe> pipelineToRelease; + kj::Maybe promiseToRelease; KJ_IF_MAYBE(answer, answers.find(finish.getQuestionId())) { - KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; } + if (!answer->active) { + // Treat the same as if the answer wasn't in the table; see comment below. + return; + } if (finish.getReleaseResultCaps()) { exportsToRelease = kj::mv(answer->resultExports); @@ -2585,29 +3308,68 @@ private: pipelineToRelease = kj::mv(answer->pipeline); - // If the call isn't actually done yet, cancel it. Otherwise, we can go ahead and erase the - // question from the table. KJ_IF_MAYBE(context, answer->callContext) { - context->requestCancel(); + // Destroying answer->task will probably destroy the call context, but we can't prove that + // since it's refcounted. Instead, inform the call context that it is now its job to + // clean up the answer table. Then, cancel the task. + promiseToRelease = kj::mv(answer->task); + answer->task = Answer::Finished(); + context->finish(); } else { + // The call context is already gone so we can tear down the Answer here. answerToRelease = answers.erase(finish.getQuestionId()); } } else { - KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; } + // The `Finish` message targets a qusetion ID that isn't present in our answer table. + // Probably, we send a `Return` with `noFinishNeeded = true`, but the other side didn't + // recognize this hint and sent a `Finish` anyway, or the `Finish` was already in-flight at + // the time we sent the `Return`. We can silently ignore this. + // + // It would be nice to detect invalid finishes somehow, but to do so we would have to + // remember past answer IDs somewhere even when we said `noFinishNeeded`. Assuming the other + // side respects the hint and doesn't send a `Finish`, we'd only be able to clean up these + // records when the other end reuses the question ID, which might never happen. + } + + if (finish.getRequireEarlyCancellationWorkaround()) { + // Defer actual cancellation of the call until the end of the event loop queue. + // + // This is needed for compatibility with older versions of Cap'n Proto (0.10 and prior) in + // which the default was to prohibit cancellation until it was explicitly allowed. In newer + // versions (1.0 and later) cancellation is allowed until explicitly prohibited, that is, if + // we haven't actually delivered the call yet, it can be canceled. This requires less + // bookkeeping and so improved performance. + // + // However, old clients might be inadvertently relying on the old behavior. For example, if + // someone using and old version called `.send()` on a message and then promptly dropped the + // returned Promise, the message would often be delivered. This was not intended to work, but + // did, and could be relied upon by accident. Moreover, the original implementation of + // streaming included a bug where streaming calls *always* sent an immediate Finish. + // + // By deferring cancellation until after a turn of the event loop, we provide an opportunity + // for any `Call` messages we've received to actually be delivered, so that they can opt out + // of cancellation if desired. + KJ_IF_MAYBE(task, promiseToRelease) { + KJ_IF_MAYBE(running, task->tryGet()) { + tasks.add(kj::evalLast([running = kj::mv(*running)]() { + // Just drop `running` here to cancel the call. + })); + } + } } } // --------------------------------------------------------------------------- // Level 1 - void handleResolve(const rpc::Resolve::Reader& resolve) { + void handleResolve(kj::Own&& message, const rpc::Resolve::Reader& resolve) { kj::Own replacement; kj::Maybe exception; // Extract the replacement capability. switch (resolve.which()) { case rpc::Resolve::CAP: - KJ_IF_MAYBE(cap, receiveCap(resolve.getCap())) { + KJ_IF_MAYBE(cap, receiveCap(resolve.getCap(), message->getAttachedFds())) { replacement = kj::mv(*cap); } else { KJ_FAIL_REQUIRE("'Resolve' contained 'CapDescriptor.none'.") { return; } @@ -2683,26 +3445,40 @@ private: return; } - for (;;) { - KJ_IF_MAYBE(r, target->getResolved()) { - target = r->addRef(); - } else { - break; - } - } + EmbargoId embargoId = context.getSenderLoopback(); - KJ_REQUIRE(target->getBrand() == this, - "'Disembargo' of type 'senderLoopback' sent to an object that does not point " - "back to the sender.") { - return; - } + // It's possible that `target` is a promise capability that hasn't resolved yet, in which + // case we must wait for the resolution. In particular this can happen in the case where + // we have Alice -> Bob -> Carol, Alice makes a call that proxies from Bob to Carol, and + // Carol returns a capability from this call that points all the way back though Bob to + // Alice. When this return capability passes through Bob, Bob will resolve the previous + // promise-pipeline capability to it. However, Bob has to send a Disembargo to Carol before + // completing this resolution. In the meantime, though, Bob returns the final repsonse to + // Alice. Alice then *also* sends a Disembargo to Bob. The Alice -> Bob Disembargo might + // arrive at Bob before the Bob -> Carol Disembargo has resolved, in which case the + // Disembargo is delivered to a promise capability. + auto promise = target->whenResolved() + .then([]() { + // We also need to insert an evalLast() here to make sure that any pending calls towards + // this cap have had time to find their way through the event loop. + return kj::evalLast([]() {}); + }); - EmbargoId embargoId = context.getSenderLoopback(); + tasks.add(promise.then([this, embargoId, target = kj::mv(target)]() mutable { + for (;;) { + KJ_IF_MAYBE(r, target->getResolved()) { + target = r->addRef(); + } else { + break; + } + } + + KJ_REQUIRE(target->getBrand() == this, + "'Disembargo' of type 'senderLoopback' sent to an object that does not point " + "back to the sender.") { + return; + } - // We need to insert an evalLater() here to make sure that any pending calls towards this - // cap have had time to find their way through the event loop. - tasks.add(kj::evalLater(kj::mvCapture( - target, [this,embargoId](kj::Own&& target) { if (!connection.is()) { return; } @@ -2722,8 +3498,8 @@ private: // any promise with a direct node in order to solve the Tribble 4-way race condition. // See the documentation of Disembargo in rpc.capnp for more. KJ_REQUIRE(redirect == nullptr, - "'Disembargo' of type 'senderLoopback' sent to an object that does not " - "appear to have been the subject of a previous 'Resolve' message.") { + "'Disembargo' of type 'senderLoopback' sent to an object that does not " + "appear to have been the subject of a previous 'Resolve' message.") { return; } } @@ -2731,7 +3507,7 @@ private: builder.getContext().setReceiverLoopback(embargoId); message->send(); - }))); + })); break; } @@ -2761,21 +3537,18 @@ private: class RpcSystemBase::Impl final: private BootstrapFactoryBase, private kj::TaskSet::ErrorHandler { public: - Impl(VatNetworkBase& network, kj::Maybe bootstrapInterface, - kj::Maybe::Client> gateway) + Impl(VatNetworkBase& network, kj::Maybe bootstrapInterface) : network(network), bootstrapInterface(kj::mv(bootstrapInterface)), - bootstrapFactory(*this), gateway(kj::mv(gateway)), tasks(*this) { - tasks.add(acceptLoop()); + bootstrapFactory(*this), tasks(*this) { + acceptLoopPromise = acceptLoop().eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); } - Impl(VatNetworkBase& network, BootstrapFactoryBase& bootstrapFactory, - kj::Maybe::Client> gateway) - : network(network), bootstrapFactory(bootstrapFactory), - gateway(kj::mv(gateway)), tasks(*this) { - tasks.add(acceptLoop()); + Impl(VatNetworkBase& network, BootstrapFactoryBase& bootstrapFactory) + : network(network), bootstrapFactory(bootstrapFactory), tasks(*this) { + acceptLoopPromise = acceptLoop().eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); } Impl(VatNetworkBase& network, SturdyRefRestorerBase& restorer) : network(network), bootstrapFactory(*this), restorer(restorer), tasks(*this) { - tasks.add(acceptLoop()); + acceptLoopPromise = acceptLoop().eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); } ~Impl() noexcept(false) { @@ -2784,7 +3557,7 @@ public: // disassemble it. if (!connections.empty()) { kj::Vector> deleteMe(connections.size()); - kj::Exception shutdownException = KJ_EXCEPTION(FAILED, "RpcSystem was destroyed."); + kj::Exception shutdownException = KJ_EXCEPTION(DISCONNECTED, "RpcSystem was destroyed."); for (auto& entry: connections) { entry.second->disconnect(kj::cp(shutdownException)); deleteMe.add(kj::mv(entry.second)); @@ -2803,11 +3576,16 @@ public: KJ_IF_MAYBE(connection, network.baseConnect(vatId)) { auto& state = getConnectionState(kj::mv(*connection)); return Capability::Client(state.restore(objectId)); + } else if (objectId.isNull()) { + // Turns out `vatId` refers to ourselves, so we can also pass it as the client ID for + // baseCreateFor(). + return bootstrapFactory.baseCreateFor(vatId); } else KJ_IF_MAYBE(r, restorer) { return r->baseRestore(objectId); } else { return Capability::Client(newBrokenCap( - "SturdyRef referred to a local object but there is no local SturdyRef restorer.")); + "This vat only supports a bootstrap interface, not the old Cap'n-Proto-0.4-style " + "named exports.")); } } @@ -2819,13 +3597,20 @@ public: } } + void setTraceEncoder(kj::Function func) { + traceEncoder = kj::mv(func); + } + + kj::Promise run() { return kj::mv(acceptLoopPromise); } + private: VatNetworkBase& network; kj::Maybe bootstrapInterface; BootstrapFactoryBase& bootstrapFactory; - kj::Maybe::Client> gateway; kj::Maybe restorer; size_t flowLimit = kj::maxValue; + kj::Maybe> traceEncoder; + kj::Promise acceptLoopPromise = nullptr; kj::TaskSet tasks; typedef std::unordered_map> @@ -2845,8 +3630,8 @@ private: tasks.add(kj::mv(info.shutdownPromise)); })); auto newState = kj::refcounted( - bootstrapFactory, gateway, restorer, kj::mv(connection), - kj::mv(onDisconnect.fulfiller), flowLimit); + bootstrapFactory, restorer, kj::mv(connection), + kj::mv(onDisconnect.fulfiller), flowLimit, traceEncoder); RpcConnectionState& result = *newState; connections.insert(std::make_pair(connectionPtr, kj::mv(newState))); return result; @@ -2856,16 +3641,10 @@ private: } kj::Promise acceptLoop() { - auto receive = network.baseAccept().then( + return network.baseAccept().then( [this](kj::Own&& connection) { getConnectionState(kj::mv(connection)); - }); - return receive.then([this]() { - // No exceptions; continue loop. - // - // (We do this in a separate continuation to handle the case where exceptions are - // disabled.) - tasks.add(acceptLoop()); + return acceptLoop(); }); } @@ -2888,13 +3667,11 @@ private: }; RpcSystemBase::RpcSystemBase(VatNetworkBase& network, - kj::Maybe bootstrapInterface, - kj::Maybe::Client> gateway) - : impl(kj::heap(network, kj::mv(bootstrapInterface), kj::mv(gateway))) {} + kj::Maybe bootstrapInterface) + : impl(kj::heap(network, kj::mv(bootstrapInterface))) {} RpcSystemBase::RpcSystemBase(VatNetworkBase& network, - BootstrapFactoryBase& bootstrapFactory, - kj::Maybe::Client> gateway) - : impl(kj::heap(network, bootstrapFactory, kj::mv(gateway))) {} + BootstrapFactoryBase& bootstrapFactory) + : impl(kj::heap(network, bootstrapFactory)) {} RpcSystemBase::RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& restorer) : impl(kj::heap(network, restorer)) {} RpcSystemBase::RpcSystemBase(RpcSystemBase&& other) noexcept = default; @@ -2913,5 +3690,171 @@ void RpcSystemBase::baseSetFlowLimit(size_t words) { return impl->setFlowLimit(words); } +void RpcSystemBase::setTraceEncoder(kj::Function func) { + impl->setTraceEncoder(kj::mv(func)); +} + +kj::Promise RpcSystemBase::run() { + return impl->run(); +} + } // namespace _ (private) + +// ======================================================================================= + +namespace { + +class WindowFlowController final: public RpcFlowController, private kj::TaskSet::ErrorHandler { +public: + WindowFlowController(RpcFlowController::WindowGetter& windowGetter) + : windowGetter(windowGetter), tasks(*this) { + state.init(); + } + + kj::Promise send(kj::Own message, kj::Promise ack) override { + auto size = message->sizeInWords() * sizeof(capnp::word); + maxMessageSize = kj::max(size, maxMessageSize); + + // We are REQUIRED to send the message NOW to maintain correct ordering. + message->send(); + + inFlight += size; + tasks.add(ack.then([this, size]() { + inFlight -= size; + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(blockedSends, Running) { + if (isReady()) { + // Release all fulfillers. + for (auto& fulfiller: blockedSends) { + fulfiller->fulfill(); + } + blockedSends.clear(); + + } + + KJ_IF_MAYBE(f, emptyFulfiller) { + if (inFlight == 0) { + f->get()->fulfill(tasks.onEmpty()); + } + } + } + KJ_CASE_ONEOF(exception, kj::Exception) { + // A previous call failed, but this one -- which was already in-flight at the time -- + // ended up succeeding. That may indicate that the server side is not properly + // handling streaming error propagation. Nothing much we can do about it here though. + } + } + })); + + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(blockedSends, Running) { + if (isReady()) { + return kj::READY_NOW; + } else { + auto paf = kj::newPromiseAndFulfiller(); + blockedSends.add(kj::mv(paf.fulfiller)); + return kj::mv(paf.promise); + } + } + KJ_CASE_ONEOF(exception, kj::Exception) { + return kj::cp(exception); + } + } + KJ_UNREACHABLE; + } + + kj::Promise waitAllAcked() override { + KJ_IF_MAYBE(q, state.tryGet()) { + if (!q->empty()) { + auto paf = kj::newPromiseAndFulfiller>(); + emptyFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + } + return tasks.onEmpty(); + } + +private: + RpcFlowController::WindowGetter& windowGetter; + size_t inFlight = 0; + size_t maxMessageSize = 0; + + typedef kj::Vector>> Running; + kj::OneOf state; + + kj::Maybe>>> emptyFulfiller; + + kj::TaskSet tasks; + + void taskFailed(kj::Exception&& exception) override { + KJ_SWITCH_ONEOF(state) { + KJ_CASE_ONEOF(blockedSends, Running) { + // Fail out all pending sends. + for (auto& fulfiller: blockedSends) { + fulfiller->reject(kj::cp(exception)); + } + // Fail out all future sends. + state = kj::mv(exception); + } + KJ_CASE_ONEOF(exception, kj::Exception) { + // ignore redundant exception + } + } + } + + bool isReady() { + // We extend the window by maxMessageSize to avoid a pathological situation when a message + // is larger than the window size. Otherwise, after sending that message, we would end up + // not sending any others until the ack was received, wasting a round trip's worth of + // bandwidth. + return inFlight <= maxMessageSize // avoid getWindow() call if unnecessary + || inFlight < windowGetter.getWindow() + maxMessageSize; + } +}; + +class FixedWindowFlowController final + : public RpcFlowController, public RpcFlowController::WindowGetter { +public: + FixedWindowFlowController(size_t windowSize): windowSize(windowSize), inner(*this) {} + + kj::Promise send(kj::Own message, kj::Promise ack) override { + return inner.send(kj::mv(message), kj::mv(ack)); + } + + kj::Promise waitAllAcked() override { + return inner.waitAllAcked(); + } + + size_t getWindow() override { return windowSize; } + +private: + size_t windowSize; + WindowFlowController inner; +}; + +} // namespace + +kj::Own RpcFlowController::newFixedWindowController(size_t windowSize) { + return kj::heap(windowSize); +} +kj::Own RpcFlowController::newVariableWindowController(WindowGetter& getter) { + return kj::heap(getter); +} + +bool IncomingRpcMessage::isShortLivedRpcMessage(AnyPointer::Reader body) { + switch (body.getAs().which()) { + case rpc::Message::CALL: + case rpc::Message::RETURN: + return false; + default: + return true; + } +} + +kj::Function IncomingRpcMessage::getShortLivedCallback() { + return [](MessageReader& reader) { + return IncomingRpcMessage::isShortLivedRpcMessage(reader.getRoot()); + }; +} + } // namespace capnp diff --git a/c++/src/capnp/rpc.capnp b/c++/src/capnp/rpc.capnp index cd808b39f7..0e718d5bf2 100644 --- a/c++/src/capnp/rpc.capnp +++ b/c++/src/capnp/rpc.capnp @@ -233,11 +233,11 @@ struct Message { abort @1 :Exception; # Sent when a connection is being aborted due to an unrecoverable error. This could be e.g. - # because the sender received an invalid or nonsensical message (`isCallersFault` is true) or - # because the sender had an internal error (`isCallersFault` is false). The sender will shut - # down the outgoing half of the connection after `abort` and will completely close the - # connection shortly thereafter (it's up to the sender how much of a time buffer they want to - # offer for the client to receive the `abort` before the connection is reset). + # because the sender received an invalid or nonsensical message or because the sender had an + # internal error. The sender will shut down the outgoing half of the connection after `abort` + # and will completely close the connection shortly thereafter (it's up to the sender how much + # of a time buffer they want to offer for the client to receive the `abort` before the + # connection is reset). # Level 0 features ----------------------------------------------- @@ -316,8 +316,8 @@ struct Bootstrap { # A Vat may export multiple bootstrap interfaces. In this case, `deprecatedObjectId` specifies # which one to return. If this pointer is null, then the default bootstrap interface is returned. # - # As of verison 0.5, use of this field is deprecated. If a service wants to export multiple - # bootstrap interfaces, it should instead define a single bootstarp interface that has methods + # As of version 0.5, use of this field is deprecated. If a service wants to export multiple + # bootstrap interfaces, it should instead define a single bootstrap interface that has methods # that return each of the other interfaces. # # **History** @@ -352,7 +352,7 @@ struct Bootstrap { # - Overloading "Restore" also had a security problem: Often, "main" or "well-known" # capabilities exported by a vat are in fact not public: they are intended to be accessed only # by clients who are capable of forming a connection to the vat. This can lead to trouble if - # the client itself has other clients and wishes to foward some `Restore` requests from those + # the client itself has other clients and wishes to forward some `Restore` requests from those # external clients -- it has to be very careful not to allow through `Restore` requests # addressing the default capability. # @@ -415,6 +415,30 @@ struct Call { # `acceptFromThirdParty`. Level 3 implementations should set this true. Otherwise, the callee # will have to proxy the return in the case of a tail call to a third-party vat. + noPromisePipelining @9 :Bool = false; + # If true, the sender promises that it won't make any promise-pipelined calls on the results of + # this call. If it breaks this promise, the receiver may throw an arbitrary error from such + # calls. + # + # The receiver may use this as an optimization, by skipping the bookkeeping needed for pipelining + # when no pipelined calls are expected. The sender typically sets this to false when the method's + # schema does not specify any return capabilities. + + onlyPromisePipeline @10 :Bool = false; + # If true, the sender only plans to use this call to make pipelined calls. The receiver need not + # send a `Return` message (but is still allowed to do so). + # + # Since the sender does not know whether a `Return` will be sent, it must release all state + # related to the call when it sends `Finish`. However, in the case that the callee does not + # recognize this hint and chooses to send a `Return`, then technically the caller is not allowed + # to reuse the question ID until it receives said `Return`. This creates a conundrum: How does + # the caller decide when it's OK to reuse the ID? To sidestep the problem, the C++ implementation + # uses high-numbered IDs (with the high-order bit set) for such calls, and cycles through the + # IDs in order. If all 2^31 IDs in this space are used without ever seeing a `Return`, then the + # implementation assumes that the other end is in fact honoring the hint, and the ID counter is + # allowed to loop around. If a `Return` is ever seen when `onlyPromisePipeline` was set, then + # the implementation stops using this hint. + params @4 :Payload; # The call parameters. `params.content` is a struct whose fields correspond to the parameters of # the method. @@ -446,23 +470,22 @@ struct Call { # in the calls so that the results need not pass back through Vat B. # # For example: - # - Alice, in Vat A, call foo() on Bob in Vat B. + # - Alice, in Vat A, calls foo() on Bob in Vat B. # - Alice makes a pipelined call bar() on the promise returned by foo(). # - Later on, Bob resolves the promise from foo() to point at Carol, who lives in Vat A (next # to Alice). # - Vat B dutifully forwards the bar() call to Carol. Let us call this forwarded call bar'(). # Notice that bar() and bar'() are travelling in opposite directions on the same network # link. - # - The `Call` for bar'() has `sendResultsTo` set to `yourself`, with the value being the - # question ID originally assigned to the bar() call. + # - The `Call` for bar'() has `sendResultsTo` set to `yourself`. + # - Vat B sends a `Return` for bar() with `takeFromOtherQuestion` set in place of the results, + # with the value set to the question ID of bar'(). Vat B does not wait for bar'() to return, + # as doing so would introduce unnecessary round trip latency. # - Vat A receives bar'() and delivers it to Carol. - # - When bar'() returns, Vat A immediately takes the results and returns them from bar(). - # - Meanwhile, Vat A sends a `Return` for bar'() to Vat B, with `resultsSentElsewhere` set in - # place of results. - # - Vat A sends a `Finish` for that call to Vat B. - # - Vat B receives the `Return` for bar'() and sends a `Return` for bar(), with - # `receivedFromYourself` set in place of the results. - # - Vat B receives the `Finish` for bar() and sends a `Finish` to bar'(). + # - When bar'() returns, Vat A sends a `Return` for bar'() to Vat B, with `resultsSentElsewhere` + # set in place of results. + # - Vat A sends a `Finish` for the bar() call to Vat B. + # - Vat B receives the `Finish` for bar() and sends a `Finish` for bar'(). thirdParty @7 :RecipientId; # **(level 3)** @@ -493,6 +516,16 @@ struct Return { # should always set this true. This defaults true because if level 0 implementations forget to # set it they'll never notice (just silently leak caps), but if level >=1 implementations forget # to set it to false they'll quickly get errors. + # + # The receiver should act as if the sender had sent a release message with count=1 for each + # CapDescriptor in the original Call message. + + noFinishNeeded @8 :Bool = false; + # If true, the sender does not need the receiver to send a `Finish` message; its answer table + # entry has already been cleaned up. This implies that the results do not contain any + # capabilities, since the `Finish` message would normally release those capabilities from + # promise pipelining responsibility. The caller may still send a `Finish` message if it wants, + # which will be silently ignored by the callee. union { results @2 :Payload; @@ -500,9 +533,9 @@ struct Return { # # For regular method calls, `results.content` points to the result struct. # - # For a `Return` in response to an `Accept`, `results` contains a single capability (rather - # than a struct), and `results.content` is just a capability pointer with index 0. A `Finish` - # is still required in this case. + # For a `Return` in response to an `Accept` or `Bootstrap`, `results` contains a single + # capability (rather than a struct), and `results.content` is just a capability pointer with + # index 0. A `Finish` is still required in this case. exception @3 :Exception; # Indicates that the call failed and explains why. @@ -514,11 +547,15 @@ struct Return { resultsSentElsewhere @5 :Void; # This is set when returning from a `Call` that had `sendResultsTo` set to something other # than `caller`. + # + # It doesn't matter too much when this is sent, as the receiver doesn't need to do anything + # with it, but the C++ implementation appears to wait for the call to finish before sending + # this. takeFromOtherQuestion @6 :QuestionId; # The sender has also sent (before this message) a `Call` with the given question ID and with # `sendResultsTo.yourself` set, and the results of that other call should be used as the - # results here. + # results here. `takeFromOtherQuestion` can only used once per question. acceptFromThirdParty @7 :ThirdPartyCapId; # **(level 3)** @@ -558,6 +595,20 @@ struct Finish { # should always set this true. This defaults true because if level 0 implementations forget to # set it they'll never notice (just silently leak caps), but if level >=1 implementations forget # set it false they'll quickly get errors. + + requireEarlyCancellationWorkaround @2 :Bool = true; + # If true, if the RPC system receives this Finish message before the original call has even been + # delivered, it should defer cancellation util after delivery. In particular, this gives the + # destination object a chance to opt out of cancellation, e.g. as controlled by the + # `allowCancellation` annotation defined in `c++.capnp`. + # + # This is a work-around. Versions 1.0 and up of Cap'n Proto always set this to false. However, + # older versions of Cap'n Proto unintentionally exhibited this errant behavior by default, and + # as a result programs built with older versions could be inadvertently relying on their peers + # to implement the behavior. The purpose of this flag is to let newer versions know when the + # peer is an older version, so that it can attempt to work around the issue. + # + # See also comments in handleFinish() in rpc.c++ for more details. } # Level 1 message types ---------------------------------------------- @@ -608,7 +659,7 @@ struct Resolve { # # The sender promises that from this point forth, until `promiseId` is released, it shall # simply forward all messages to the capability designated by `cap`. This is true even if - # `cap` itself happens to desigate another promise, and that other promise later resolves -- + # `cap` itself happens to designate another promise, and that other promise later resolves -- # messages sent to `promiseId` shall still go to that other promise, not to its resolution. # This is important in the case that the receiver of the `Resolve` ends up sending a # `Disembargo` message towards `promiseId` in order to control message ordering -- that @@ -692,7 +743,7 @@ struct Disembargo { # Extending the embargo/disembargo protocol to be able to shorted multiple hops at once seems # difficult. Instead, we make a rule that prevents this case from coming up: # - # One a promise P has been resolved to a remove object reference R, then all further messages + # One a promise P has been resolved to a remote object reference R, then all further messages # received addressed to P will be forwarded strictly to R. Even if it turns out later that R is # itself a promise, and has resolved to some other object Q, messages sent to P will still be # forwarded to R, not directly to Q (R will of course further forward the messages to Q). @@ -701,6 +752,10 @@ struct Disembargo { # is expected that people sending messages to P will shortly start sending them to R instead and # drop P. P is at end-of-life anyway, so it doesn't matter if it ignores chances to further # optimize its path. + # + # Note well: the Tribble 4-way race condition does not require each vat to be *distinct*; as long + # as each resolution crosses a network boundary the race can occur -- so this concerns even level + # 1 implementations, not just level 3 implementations. target @0 :MessageTarget; # What is to be disembargoed. @@ -781,7 +836,7 @@ struct Accept { # Message type sent to pick up a capability hosted by the receiving vat and provided by a third # party. The third party previously designated the capability using `Provide`. # - # This message is also used to pick up a redirected return -- see `Return.redirect`. + # This message is also used to pick up a redirected return -- see `Return.acceptFromThirdParty`. questionId @0 :QuestionId; # A new question ID identifying this accept message, which will eventually receive a Return @@ -849,7 +904,7 @@ struct Join { # - Dana receives the first request and sees that the JoinKeyPart is one of two. She notes that # she doesn't have the other part yet, so she records the request and responds with a # JoinResult. - # - Alice relays the JoinAswer back to Bob. + # - Alice relays the JoinAnswer back to Bob. # - Carol is also proxying a capability from Dana, and so forwards her Join request to Dana as # well. # - Dana receives Carol's request and notes that she now has both parts of a JoinKey. She @@ -940,6 +995,11 @@ struct CapDescriptor { # # Keep in mind that `ExportIds` in a `CapDescriptor` are subject to reference counting. See the # description of `ExportId`. + # + # Note that it is currently not possible to include a broken capability in the CapDescriptor + # table. Instead, create a new export (`senderPromise`) for each broken capability and then + # immediately follow the payload-bearing Call or Return message with one Resolve message for each + # broken capability, resolving it to an exception. union { none @0 :Void; @@ -951,8 +1011,8 @@ struct CapDescriptor { # Hopefully this is unusual. senderHosted @1 :ExportId; - # A capability newly exported by the sender. This is the ID of the new capability in the - # sender's export table (receiver's import table). + # The ID of a capability in the sender's export table (receiver's import table). It may be a + # newly allocated table entry, or an existing entry (increments the reference count). senderPromise @2 :ExportId; # A promise that the sender will resolve later. The sender will send exactly one Resolve @@ -977,6 +1037,63 @@ struct CapDescriptor { # Level 1 and 2 implementations that receive a `thirdPartyHosted` may simply send calls to its # `vine` instead. } + + attachedFd @6 :UInt8 = 0xff; + # If the RPC message in which this CapDescriptor was delivered also had file descriptors + # attached, and `fd` is a valid index into the list of attached file descriptors, then + # that file descriptor should be attached to this capability. If `attachedFd` is out-of-bounds + # for said list, then no FD is attached. + # + # For example, if the RPC message arrived over a Unix socket, then file descriptors may be + # attached by sending an SCM_RIGHTS ancillary message attached to the data bytes making up the + # raw message. Receivers who wish to opt into FD passing should arrange to receive SCM_RIGHTS + # whenever receiving an RPC message. Senders who wish to send FDs need not verify whether the + # receiver knows how to receive them, because the operating system will automatically discard + # ancillary messages like SCM_RIGHTS if the receiver doesn't ask to receive them, including + # automatically closing any FDs. + # + # It is up to the application protocol to define what capabilities are expected to have file + # descriptors attached, and what those FDs mean. But, for example, an application could use this + # to open a file on disk and then transmit the open file descriptor to a sandboxed process that + # does not otherwise have permission to access the filesystem directly. This is usually an + # optimization: the sending process could instead provide an RPC interface supporting all the + # operations needed (such as reading and writing a file), but by passing the file descriptor + # directly, the recipient can often perform operations much more efficiently. Application + # designers are encouraged to provide such RPC interfaces and automatically fall back to them + # when FD passing is not available, so that the application can still work when the parties are + # remote over a network. + # + # An attached FD is most often associated with a `senderHosted` descriptor. It could also make + # sense in the case of `thirdPartyHosted`: in this case, the sender is forwarding the FD that + # they received from the third party, so that the receiver can start using it without first + # interacting with the third party. This is an optional optimization -- the middleman may choose + # not to forward capabilities, in which case the receiver will need to complete the handshake + # with the third party directly before receiving the FD. If an implementation receives a second + # attached FD after having already received one previously (e.g. both in a `thirdPartyHosted` + # CapDescriptor and then later again when receiving the final capability directly from the + # third party), the implementation should discard the later FD and stick with the original. At + # present, there is no known reason why other capability types (e.g. `receiverHosted`) would want + # to carry an attached FD, but we reserve the right to define a meaning for this in the future. + # + # Each file descriptor attached to the message must be used in no more than one CapDescriptor, + # so that the receiver does not need to use dup() or refcounting to handle the possibility of + # multiple capabilities using the same descriptor. If multiple CapDescriptors do point to the + # same FD index, then the receiver can arbitrarily choose which capability ends up having the + # FD attached. + # + # To mitigate DoS attacks, RPC implementations should limit the number of FDs they are willing to + # receive in a single message to a small value. If a message happens to contain more than that, + # the list is truncated. Moreover, in some cases, FD passing needs to be blocked entirely for + # security or implementation reasons, in which case the list may be truncated to zero. Hence, + # `attachedFd` might point past the end of the list, which the implementation should treat as if + # no FD was attached at all. + # + # The type of this field was chosen to be UInt8 because Linux supports sending only a maximum + # of 253 file descriptors in an SCM_RIGHTS message anyway, and CapDescriptor had two bytes of + # padding left -- so after adding this, there is still one byte for a future feature. + # Conveniently, this also means we're able to use 0xff as the default value, which will always + # be out-of-range (of course, the implementation should explicitly enforce that 255 descriptors + # cannot be sent at once, rather than relying on Linux to do so). } struct PromisedAnswer { @@ -1041,7 +1158,7 @@ struct ThirdPartyCapDescriptor { # simply send calls to the vine. Such calls will be forwarded to the third-party by the # sender. # - # * Level 3 implementations must release the vine once they have successfully picked up the + # * Level 3 implementations must release the vine only once they have successfully picked up the # object from the third party. This ensures that the capability is not released by the sender # prematurely. # @@ -1118,7 +1235,7 @@ struct Exception { # start over. This should in turn cause the server to obtain a new copy of the capability that # it lost, thus making everything work. # - # If the client receives another `disconnencted` error in the process of rebuilding the + # If the client receives another `disconnected` error in the process of rebuilding the # capability and retrying the call, it should treat this as an `overloaded` error: the network # is currently unreliable, possibly due to load or other temporary issues. @@ -1133,6 +1250,11 @@ struct Exception { obsoleteDurability @2 :UInt16; # OBSOLETE. See `type` instead. + + trace @4 :Text; + # Stack trace text from the remote server. The format is not specified. By default, + # implementations do not provide stack traces; the application must explicitly enable them + # when desired. } # ======================================================================================== @@ -1220,7 +1342,7 @@ using SturdyRef = AnyPointer; # - How to authenticate the vat after connecting (e.g. a public key fingerprint). # - The identity of a specific object hosted by the vat. Generally, this is an opaque pointer whose # format is defined by the specific vat -- the client has no need to inspect the object ID. -# It is important that the objec ID be unguessable if the object is not public (and objects +# It is important that the object ID be unguessable if the object is not public (and objects # should almost never be public). # # The above are only suggestions. Some networks might work differently. For example, a private @@ -1234,8 +1356,8 @@ using ProvisionId = AnyPointer; # The information that must be sent in an `Accept` message to identify the object being accepted. # # In a network where each vat has a public/private key pair, this could simply be the public key -# fingerprint of the provider vat along with the question ID used in the `Provide` message sent from -# that provider. +# fingerprint of the provider vat along with a nonce matching the one in the `RecipientId` used +# in the `Provide` message sent from that provider. using RecipientId = AnyPointer; # **(level 3)** @@ -1244,8 +1366,12 @@ using RecipientId = AnyPointer; # capability. # # In a network where each vat has a public/private key pair, this could simply be the public key -# fingerprint of the recipient. (CapTP also calls for a nonce to identify the object. In our -# case, the `Provide` message's `questionId` can serve as the nonce.) +# fingerprint of the recipient along with a nonce matching the one in the `ProvisionId`. +# +# As another example, when communicating between processes on the same machine over Unix sockets, +# RecipientId could simply refer to a file descriptor attached to the message via SCM_RIGHTS. +# This file descriptor would be one end of a newly-created socketpair, with the other end having +# been sent to the capability's recipient in ThirdPartyCapId. using ThirdPartyCapId = AnyPointer; # **(level 3)** @@ -1254,8 +1380,13 @@ using ThirdPartyCapId = AnyPointer; # # In a network where each vat has a public/private key pair, this could be a combination of the # third party's public key fingerprint, hints on how to connect to the third party (e.g. an IP -# address), and the question ID used in the corresponding `Provide` message sent to that third party -# (used to identify which capability to pick up). +# address), and the nonce used in the corresponding `Provide` message's `RecipientId` as sent +# to that third party (used to identify which capability to pick up). +# +# As another example, when communicating between processes on the same machine over Unix sockets, +# ThirdPartyCapId could simply refer to a file descriptor attached to the message via SCM_RIGHTS. +# This file descriptor would be one end of a newly-created socketpair, with the other end having +# been sent to the process hosting the capability in RecipientId. using JoinKeyPart = AnyPointer; # **(level 4)** diff --git a/c++/src/capnp/rpc.capnp.c++ b/c++/src/capnp/rpc.capnp.c++ index 0135f9b624..4927995f47 100644 --- a/c++/src/capnp/rpc.capnp.c++ +++ b/c++/src/capnp/rpc.capnp.c++ @@ -259,7 +259,7 @@ static const uint16_t m_91b79f1f808db032[] = {1, 11, 8, 2, 13, 4, 12, 9, 7, 10, static const uint16_t i_91b79f1f808db032[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}; const ::capnp::_::RawSchema s_91b79f1f808db032 = { 0x91b79f1f808db032, b_91b79f1f808db032.words, 232, d_91b79f1f808db032, m_91b79f1f808db032, - 12, 14, i_91b79f1f808db032, nullptr, nullptr, { &s_91b79f1f808db032, nullptr, nullptr, 0, 0, nullptr } + 12, 14, i_91b79f1f808db032, nullptr, nullptr, { &s_91b79f1f808db032, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<51> b_e94ccf8031176ec4 = { @@ -321,10 +321,10 @@ static const uint16_t m_e94ccf8031176ec4[] = {1, 0}; static const uint16_t i_e94ccf8031176ec4[] = {0, 1}; const ::capnp::_::RawSchema s_e94ccf8031176ec4 = { 0xe94ccf8031176ec4, b_e94ccf8031176ec4.words, 51, nullptr, m_e94ccf8031176ec4, - 0, 2, i_e94ccf8031176ec4, nullptr, nullptr, { &s_e94ccf8031176ec4, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_e94ccf8031176ec4, nullptr, nullptr, { &s_e94ccf8031176ec4, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<121> b_836a53ce789d4cd4 = { +static const ::capnp::_::AlignedData<155> b_836a53ce789d4cd4 = { { 0, 0, 0, 0, 5, 0, 6, 0, 212, 76, 157, 120, 206, 83, 106, 131, 16, 0, 0, 0, 1, 0, 3, 0, @@ -334,63 +334,77 @@ static const ::capnp::_::AlignedData<121> b_836a53ce789d4cd4 = { 21, 0, 0, 0, 170, 0, 0, 0, 29, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 25, 0, 0, 0, 143, 1, 0, 0, + 25, 0, 0, 0, 255, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, 99, 46, 99, 97, 112, 110, 112, 58, 67, 97, 108, 108, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 28, 0, 0, 0, 3, 0, 4, 0, + 36, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 181, 0, 0, 0, 90, 0, 0, 0, + 237, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 180, 0, 0, 0, 3, 0, 1, 0, - 192, 0, 0, 0, 2, 0, 1, 0, + 236, 0, 0, 0, 3, 0, 1, 0, + 248, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 189, 0, 0, 0, 58, 0, 0, 0, + 245, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 184, 0, 0, 0, 3, 0, 1, 0, - 196, 0, 0, 0, 2, 0, 1, 0, + 240, 0, 0, 0, 3, 0, 1, 0, + 252, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 193, 0, 0, 0, 98, 0, 0, 0, + 249, 0, 0, 0, 98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 192, 0, 0, 0, 3, 0, 1, 0, - 204, 0, 0, 0, 2, 0, 1, 0, + 248, 0, 0, 0, 3, 0, 1, 0, + 4, 1, 0, 0, 2, 0, 1, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 201, 0, 0, 0, 74, 0, 0, 0, + 1, 1, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 200, 0, 0, 0, 3, 0, 1, 0, - 212, 0, 0, 0, 2, 0, 1, 0, - 5, 0, 0, 0, 1, 0, 0, 0, + 0, 1, 0, 0, 3, 0, 1, 0, + 12, 1, 0, 0, 2, 0, 1, 0, + 7, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 209, 0, 0, 0, 58, 0, 0, 0, + 9, 1, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 204, 0, 0, 0, 3, 0, 1, 0, - 216, 0, 0, 0, 2, 0, 1, 0, - 6, 0, 0, 0, 0, 0, 0, 0, + 4, 1, 0, 0, 3, 0, 1, 0, + 16, 1, 0, 0, 2, 0, 1, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 153, 95, 171, 26, 246, 176, 232, 218, - 213, 0, 0, 0, 114, 0, 0, 0, + 13, 1, 0, 0, 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 128, 0, 0, 0, 0, 0, 1, 0, 8, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 193, 0, 0, 0, 194, 0, 0, 0, + 249, 0, 0, 0, 194, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 252, 0, 0, 0, 3, 0, 1, 0, + 8, 1, 0, 0, 2, 0, 1, 0, + 5, 0, 0, 0, 129, 0, 0, 0, + 0, 0, 1, 0, 9, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 5, 1, 0, 0, 162, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 8, 1, 0, 0, 3, 0, 1, 0, + 20, 1, 0, 0, 2, 0, 1, 0, + 6, 0, 0, 0, 130, 0, 0, 0, + 0, 0, 1, 0, 10, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 17, 1, 0, 0, 162, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 196, 0, 0, 0, 3, 0, 1, 0, - 208, 0, 0, 0, 2, 0, 1, 0, + 20, 1, 0, 0, 3, 0, 1, 0, + 32, 1, 0, 0, 2, 0, 1, 0, 113, 117, 101, 115, 116, 105, 111, 110, 73, 100, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, @@ -439,6 +453,26 @@ static const ::capnp::_::AlignedData<121> b_836a53ce789d4cd4 = { 97, 108, 108, 111, 119, 84, 104, 105, 114, 100, 80, 97, 114, 116, 121, 84, 97, 105, 108, 67, 97, 108, 108, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 110, 111, 80, 114, 111, 109, 105, 115, + 101, 80, 105, 112, 101, 108, 105, 110, + 105, 110, 103, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 111, 110, 108, 121, 80, 114, 111, 109, + 105, 115, 101, 80, 105, 112, 101, 108, + 105, 110, 101, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -454,11 +488,11 @@ static const ::capnp::_::RawSchema* const d_836a53ce789d4cd4[] = { &s_9a0e61223d96743b, &s_dae8b0f61aab5f99, }; -static const uint16_t m_836a53ce789d4cd4[] = {6, 2, 3, 4, 0, 5, 1}; -static const uint16_t i_836a53ce789d4cd4[] = {0, 1, 2, 3, 4, 5, 6}; +static const uint16_t m_836a53ce789d4cd4[] = {6, 2, 3, 7, 8, 4, 0, 5, 1}; +static const uint16_t i_836a53ce789d4cd4[] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; const ::capnp::_::RawSchema s_836a53ce789d4cd4 = { - 0x836a53ce789d4cd4, b_836a53ce789d4cd4.words, 121, d_836a53ce789d4cd4, m_836a53ce789d4cd4, - 3, 7, i_836a53ce789d4cd4, nullptr, nullptr, { &s_836a53ce789d4cd4, nullptr, nullptr, 0, 0, nullptr } + 0x836a53ce789d4cd4, b_836a53ce789d4cd4.words, 155, d_836a53ce789d4cd4, m_836a53ce789d4cd4, + 3, 9, i_836a53ce789d4cd4, nullptr, nullptr, { &s_836a53ce789d4cd4, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<65> b_dae8b0f61aab5f99 = { @@ -537,10 +571,10 @@ static const uint16_t m_dae8b0f61aab5f99[] = {0, 2, 1}; static const uint16_t i_dae8b0f61aab5f99[] = {0, 1, 2}; const ::capnp::_::RawSchema s_dae8b0f61aab5f99 = { 0xdae8b0f61aab5f99, b_dae8b0f61aab5f99.words, 65, d_dae8b0f61aab5f99, m_dae8b0f61aab5f99, - 1, 3, i_dae8b0f61aab5f99, nullptr, nullptr, { &s_dae8b0f61aab5f99, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_dae8b0f61aab5f99, nullptr, nullptr, { &s_dae8b0f61aab5f99, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<148> b_9e19b28d3db3573a = { +static const ::capnp::_::AlignedData<164> b_9e19b28d3db3573a = { { 0, 0, 0, 0, 5, 0, 6, 0, 58, 87, 179, 61, 141, 178, 25, 158, 16, 0, 0, 0, 1, 0, 2, 0, @@ -550,70 +584,77 @@ static const ::capnp::_::AlignedData<148> b_9e19b28d3db3573a = { 21, 0, 0, 0, 186, 0, 0, 0, 29, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 25, 0, 0, 0, 199, 1, 0, 0, + 25, 0, 0, 0, 255, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, 99, 46, 99, 97, 112, 110, 112, 58, 82, 101, 116, 117, 114, 110, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 32, 0, 0, 0, 3, 0, 4, 0, + 36, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 209, 0, 0, 0, 74, 0, 0, 0, + 237, 0, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 208, 0, 0, 0, 3, 0, 1, 0, - 220, 0, 0, 0, 2, 0, 1, 0, + 236, 0, 0, 0, 3, 0, 1, 0, + 248, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 32, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 217, 0, 0, 0, 138, 0, 0, 0, + 245, 0, 0, 0, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 220, 0, 0, 0, 3, 0, 1, 0, - 232, 0, 0, 0, 2, 0, 1, 0, - 2, 0, 255, 255, 0, 0, 0, 0, + 248, 0, 0, 0, 3, 0, 1, 0, + 4, 1, 0, 0, 2, 0, 1, 0, + 3, 0, 255, 255, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 229, 0, 0, 0, 66, 0, 0, 0, + 1, 1, 0, 0, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 224, 0, 0, 0, 3, 0, 1, 0, - 236, 0, 0, 0, 2, 0, 1, 0, - 3, 0, 254, 255, 0, 0, 0, 0, + 252, 0, 0, 0, 3, 0, 1, 0, + 8, 1, 0, 0, 2, 0, 1, 0, + 4, 0, 254, 255, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 233, 0, 0, 0, 82, 0, 0, 0, + 5, 1, 0, 0, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 232, 0, 0, 0, 3, 0, 1, 0, - 244, 0, 0, 0, 2, 0, 1, 0, - 4, 0, 253, 255, 0, 0, 0, 0, + 4, 1, 0, 0, 3, 0, 1, 0, + 16, 1, 0, 0, 2, 0, 1, 0, + 5, 0, 253, 255, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 241, 0, 0, 0, 74, 0, 0, 0, + 13, 1, 0, 0, 74, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 240, 0, 0, 0, 3, 0, 1, 0, - 252, 0, 0, 0, 2, 0, 1, 0, - 5, 0, 252, 255, 0, 0, 0, 0, + 12, 1, 0, 0, 3, 0, 1, 0, + 24, 1, 0, 0, 2, 0, 1, 0, + 6, 0, 252, 255, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 249, 0, 0, 0, 170, 0, 0, 0, + 21, 1, 0, 0, 170, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 252, 0, 0, 0, 3, 0, 1, 0, - 8, 1, 0, 0, 2, 0, 1, 0, - 6, 0, 251, 255, 2, 0, 0, 0, + 24, 1, 0, 0, 3, 0, 1, 0, + 36, 1, 0, 0, 2, 0, 1, 0, + 7, 0, 251, 255, 2, 0, 0, 0, 0, 0, 1, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 5, 1, 0, 0, 178, 0, 0, 0, + 33, 1, 0, 0, 178, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 8, 1, 0, 0, 3, 0, 1, 0, - 20, 1, 0, 0, 2, 0, 1, 0, - 7, 0, 250, 255, 0, 0, 0, 0, + 36, 1, 0, 0, 3, 0, 1, 0, + 48, 1, 0, 0, 2, 0, 1, 0, + 8, 0, 250, 255, 0, 0, 0, 0, 0, 0, 1, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 17, 1, 0, 0, 170, 0, 0, 0, + 45, 1, 0, 0, 170, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 20, 1, 0, 0, 3, 0, 1, 0, - 32, 1, 0, 0, 2, 0, 1, 0, + 48, 1, 0, 0, 3, 0, 1, 0, + 60, 1, 0, 0, 2, 0, 1, 0, + 2, 0, 0, 0, 33, 0, 0, 0, + 0, 0, 1, 0, 8, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 57, 1, 0, 0, 122, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 56, 1, 0, 0, 3, 0, 1, 0, + 68, 1, 0, 0, 2, 0, 1, 0, 97, 110, 115, 119, 101, 114, 73, 100, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, @@ -687,6 +728,15 @@ static const ::capnp::_::AlignedData<148> b_9e19b28d3db3573a = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 110, 111, 70, 105, 110, 105, 115, 104, + 78, 101, 101, 100, 101, 100, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; @@ -696,14 +746,14 @@ static const ::capnp::_::RawSchema* const d_9e19b28d3db3573a[] = { &s_9a0e61223d96743b, &s_d625b7063acf691a, }; -static const uint16_t m_9e19b28d3db3573a[] = {7, 0, 4, 3, 1, 2, 5, 6}; -static const uint16_t i_9e19b28d3db3573a[] = {2, 3, 4, 5, 6, 7, 0, 1}; +static const uint16_t m_9e19b28d3db3573a[] = {7, 0, 4, 3, 8, 1, 2, 5, 6}; +static const uint16_t i_9e19b28d3db3573a[] = {2, 3, 4, 5, 6, 7, 0, 1, 8}; const ::capnp::_::RawSchema s_9e19b28d3db3573a = { - 0x9e19b28d3db3573a, b_9e19b28d3db3573a.words, 148, d_9e19b28d3db3573a, m_9e19b28d3db3573a, - 2, 8, i_9e19b28d3db3573a, nullptr, nullptr, { &s_9e19b28d3db3573a, nullptr, nullptr, 0, 0, nullptr } + 0x9e19b28d3db3573a, b_9e19b28d3db3573a.words, 164, d_9e19b28d3db3573a, m_9e19b28d3db3573a, + 2, 9, i_9e19b28d3db3573a, nullptr, nullptr, { &s_9e19b28d3db3573a, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { +static const ::capnp::_::AlignedData<69> b_d37d2eb2c2f80e63 = { { 0, 0, 0, 0, 5, 0, 6, 0, 99, 14, 248, 194, 178, 46, 125, 211, 16, 0, 0, 0, 1, 0, 1, 0, @@ -713,28 +763,35 @@ static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { 21, 0, 0, 0, 186, 0, 0, 0, 29, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 25, 0, 0, 0, 119, 0, 0, 0, + 25, 0, 0, 0, 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, 99, 46, 99, 97, 112, 110, 112, 58, 70, 105, 110, 105, 115, 104, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 8, 0, 0, 0, 3, 0, 4, 0, + 12, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 41, 0, 0, 0, 90, 0, 0, 0, + 69, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 40, 0, 0, 0, 3, 0, 1, 0, - 52, 0, 0, 0, 2, 0, 1, 0, + 68, 0, 0, 0, 3, 0, 1, 0, + 80, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 32, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 49, 0, 0, 0, 146, 0, 0, 0, + 77, 0, 0, 0, 146, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 52, 0, 0, 0, 3, 0, 1, 0, - 64, 0, 0, 0, 2, 0, 1, 0, + 80, 0, 0, 0, 3, 0, 1, 0, + 92, 0, 0, 0, 2, 0, 1, 0, + 2, 0, 0, 0, 33, 0, 0, 0, + 0, 0, 1, 0, 2, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 89, 0, 0, 0, 26, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 100, 0, 0, 0, 3, 0, 1, 0, + 112, 0, 0, 0, 2, 0, 1, 0, 113, 117, 101, 115, 116, 105, 111, 110, 73, 100, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, @@ -747,6 +804,18 @@ static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { 114, 101, 108, 101, 97, 115, 101, 82, 101, 115, 117, 108, 116, 67, 97, 112, 115, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 114, 101, 113, 117, 105, 114, 101, 69, + 97, 114, 108, 121, 67, 97, 110, 99, + 101, 108, 108, 97, 116, 105, 111, 110, + 87, 111, 114, 107, 97, 114, 111, 117, + 110, 100, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -757,11 +826,11 @@ static const ::capnp::_::AlignedData<50> b_d37d2eb2c2f80e63 = { }; ::capnp::word const* const bp_d37d2eb2c2f80e63 = b_d37d2eb2c2f80e63.words; #if !CAPNP_LITE -static const uint16_t m_d37d2eb2c2f80e63[] = {0, 1}; -static const uint16_t i_d37d2eb2c2f80e63[] = {0, 1}; +static const uint16_t m_d37d2eb2c2f80e63[] = {0, 1, 2}; +static const uint16_t i_d37d2eb2c2f80e63[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d37d2eb2c2f80e63 = { - 0xd37d2eb2c2f80e63, b_d37d2eb2c2f80e63.words, 50, nullptr, m_d37d2eb2c2f80e63, - 0, 2, i_d37d2eb2c2f80e63, nullptr, nullptr, { &s_d37d2eb2c2f80e63, nullptr, nullptr, 0, 0, nullptr } + 0xd37d2eb2c2f80e63, b_d37d2eb2c2f80e63.words, 69, nullptr, m_d37d2eb2c2f80e63, + 0, 3, i_d37d2eb2c2f80e63, nullptr, nullptr, { &s_d37d2eb2c2f80e63, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<64> b_bbc29655fa89086e = { @@ -840,7 +909,7 @@ static const uint16_t m_bbc29655fa89086e[] = {1, 2, 0}; static const uint16_t i_bbc29655fa89086e[] = {1, 2, 0}; const ::capnp::_::RawSchema s_bbc29655fa89086e = { 0xbbc29655fa89086e, b_bbc29655fa89086e.words, 64, d_bbc29655fa89086e, m_bbc29655fa89086e, - 2, 3, i_bbc29655fa89086e, nullptr, nullptr, { &s_bbc29655fa89086e, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_bbc29655fa89086e, nullptr, nullptr, { &s_bbc29655fa89086e, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<48> b_ad1a6c0d7dd07497 = { @@ -899,7 +968,7 @@ static const uint16_t m_ad1a6c0d7dd07497[] = {0, 1}; static const uint16_t i_ad1a6c0d7dd07497[] = {0, 1}; const ::capnp::_::RawSchema s_ad1a6c0d7dd07497 = { 0xad1a6c0d7dd07497, b_ad1a6c0d7dd07497.words, 48, nullptr, m_ad1a6c0d7dd07497, - 0, 2, i_ad1a6c0d7dd07497, nullptr, nullptr, { &s_ad1a6c0d7dd07497, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_ad1a6c0d7dd07497, nullptr, nullptr, { &s_ad1a6c0d7dd07497, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<41> b_f964368b0fbd3711 = { @@ -955,7 +1024,7 @@ static const uint16_t m_f964368b0fbd3711[] = {1, 0}; static const uint16_t i_f964368b0fbd3711[] = {0, 1}; const ::capnp::_::RawSchema s_f964368b0fbd3711 = { 0xf964368b0fbd3711, b_f964368b0fbd3711.words, 41, d_f964368b0fbd3711, m_f964368b0fbd3711, - 2, 2, i_f964368b0fbd3711, nullptr, nullptr, { &s_f964368b0fbd3711, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_f964368b0fbd3711, nullptr, nullptr, { &s_f964368b0fbd3711, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<81> b_d562b4df655bdd4d = { @@ -1050,7 +1119,7 @@ static const uint16_t m_d562b4df655bdd4d[] = {2, 3, 1, 0}; static const uint16_t i_d562b4df655bdd4d[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_d562b4df655bdd4d = { 0xd562b4df655bdd4d, b_d562b4df655bdd4d.words, 81, d_d562b4df655bdd4d, m_d562b4df655bdd4d, - 1, 4, i_d562b4df655bdd4d, nullptr, nullptr, { &s_d562b4df655bdd4d, nullptr, nullptr, 0, 0, nullptr } + 1, 4, i_d562b4df655bdd4d, nullptr, nullptr, { &s_d562b4df655bdd4d, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<64> b_9c6a046bfbc1ac5a = { @@ -1128,7 +1197,7 @@ static const uint16_t m_9c6a046bfbc1ac5a[] = {0, 2, 1}; static const uint16_t i_9c6a046bfbc1ac5a[] = {0, 1, 2}; const ::capnp::_::RawSchema s_9c6a046bfbc1ac5a = { 0x9c6a046bfbc1ac5a, b_9c6a046bfbc1ac5a.words, 64, d_9c6a046bfbc1ac5a, m_9c6a046bfbc1ac5a, - 1, 3, i_9c6a046bfbc1ac5a, nullptr, nullptr, { &s_9c6a046bfbc1ac5a, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_9c6a046bfbc1ac5a, nullptr, nullptr, { &s_9c6a046bfbc1ac5a, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<64> b_d4c9b56290554016 = { @@ -1203,7 +1272,7 @@ static const uint16_t m_d4c9b56290554016[] = {2, 1, 0}; static const uint16_t i_d4c9b56290554016[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d4c9b56290554016 = { 0xd4c9b56290554016, b_d4c9b56290554016.words, 64, nullptr, m_d4c9b56290554016, - 0, 3, i_d4c9b56290554016, nullptr, nullptr, { &s_d4c9b56290554016, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_d4c9b56290554016, nullptr, nullptr, { &s_d4c9b56290554016, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<63> b_fbe1980490e001af = { @@ -1280,7 +1349,7 @@ static const uint16_t m_fbe1980490e001af[] = {2, 0, 1}; static const uint16_t i_fbe1980490e001af[] = {0, 1, 2}; const ::capnp::_::RawSchema s_fbe1980490e001af = { 0xfbe1980490e001af, b_fbe1980490e001af.words, 63, d_fbe1980490e001af, m_fbe1980490e001af, - 1, 3, i_fbe1980490e001af, nullptr, nullptr, { &s_fbe1980490e001af, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_fbe1980490e001af, nullptr, nullptr, { &s_fbe1980490e001af, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_95bc14545813fbc1 = { @@ -1344,7 +1413,7 @@ static const uint16_t m_95bc14545813fbc1[] = {0, 1}; static const uint16_t i_95bc14545813fbc1[] = {0, 1}; const ::capnp::_::RawSchema s_95bc14545813fbc1 = { 0x95bc14545813fbc1, b_95bc14545813fbc1.words, 50, d_95bc14545813fbc1, m_95bc14545813fbc1, - 1, 2, i_95bc14545813fbc1, nullptr, nullptr, { &s_95bc14545813fbc1, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_95bc14545813fbc1, nullptr, nullptr, { &s_95bc14545813fbc1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<52> b_9a0e61223d96743b = { @@ -1410,10 +1479,10 @@ static const uint16_t m_9a0e61223d96743b[] = {1, 0}; static const uint16_t i_9a0e61223d96743b[] = {0, 1}; const ::capnp::_::RawSchema s_9a0e61223d96743b = { 0x9a0e61223d96743b, b_9a0e61223d96743b.words, 52, d_9a0e61223d96743b, m_9a0e61223d96743b, - 1, 2, i_9a0e61223d96743b, nullptr, nullptr, { &s_9a0e61223d96743b, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_9a0e61223d96743b, nullptr, nullptr, { &s_9a0e61223d96743b, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { +static const ::capnp::_::AlignedData<130> b_8523ddc40b86b8b0 = { { 0, 0, 0, 0, 5, 0, 6, 0, 176, 184, 134, 11, 196, 221, 35, 133, 16, 0, 0, 0, 1, 0, 1, 0, @@ -1423,7 +1492,7 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { 21, 0, 0, 0, 242, 0, 0, 0, 33, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 29, 0, 0, 0, 87, 1, 0, 0, + 29, 0, 0, 0, 143, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, @@ -1431,49 +1500,56 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { 67, 97, 112, 68, 101, 115, 99, 114, 105, 112, 116, 111, 114, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 24, 0, 0, 0, 3, 0, 4, 0, + 28, 0, 0, 0, 3, 0, 4, 0, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 153, 0, 0, 0, 42, 0, 0, 0, + 181, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 148, 0, 0, 0, 3, 0, 1, 0, - 160, 0, 0, 0, 2, 0, 1, 0, + 176, 0, 0, 0, 3, 0, 1, 0, + 188, 0, 0, 0, 2, 0, 1, 0, 1, 0, 254, 255, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 157, 0, 0, 0, 106, 0, 0, 0, + 185, 0, 0, 0, 106, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 156, 0, 0, 0, 3, 0, 1, 0, - 168, 0, 0, 0, 2, 0, 1, 0, + 184, 0, 0, 0, 3, 0, 1, 0, + 196, 0, 0, 0, 2, 0, 1, 0, 2, 0, 253, 255, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 165, 0, 0, 0, 114, 0, 0, 0, + 193, 0, 0, 0, 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 164, 0, 0, 0, 3, 0, 1, 0, - 176, 0, 0, 0, 2, 0, 1, 0, + 192, 0, 0, 0, 3, 0, 1, 0, + 204, 0, 0, 0, 2, 0, 1, 0, 3, 0, 252, 255, 1, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 173, 0, 0, 0, 122, 0, 0, 0, + 201, 0, 0, 0, 122, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 172, 0, 0, 0, 3, 0, 1, 0, - 184, 0, 0, 0, 2, 0, 1, 0, + 200, 0, 0, 0, 3, 0, 1, 0, + 212, 0, 0, 0, 2, 0, 1, 0, 4, 0, 251, 255, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 181, 0, 0, 0, 122, 0, 0, 0, + 209, 0, 0, 0, 122, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 180, 0, 0, 0, 3, 0, 1, 0, - 192, 0, 0, 0, 2, 0, 1, 0, + 208, 0, 0, 0, 3, 0, 1, 0, + 220, 0, 0, 0, 2, 0, 1, 0, 5, 0, 250, 255, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 189, 0, 0, 0, 138, 0, 0, 0, + 217, 0, 0, 0, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 192, 0, 0, 0, 3, 0, 1, 0, - 204, 0, 0, 0, 2, 0, 1, 0, + 220, 0, 0, 0, 3, 0, 1, 0, + 232, 0, 0, 0, 2, 0, 1, 0, + 6, 0, 0, 0, 2, 0, 0, 0, + 0, 0, 1, 0, 6, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 229, 0, 0, 0, 90, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 228, 0, 0, 0, 3, 0, 1, 0, + 240, 0, 0, 0, 2, 0, 1, 0, 110, 111, 110, 101, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -1526,6 +1602,15 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 97, 116, 116, 97, 99, 104, 101, 100, + 70, 100, 0, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 6, 0, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; @@ -1535,11 +1620,11 @@ static const ::capnp::_::RawSchema* const d_8523ddc40b86b8b0[] = { &s_d37007fde1f0027d, &s_d800b1d6cd6f1ca0, }; -static const uint16_t m_8523ddc40b86b8b0[] = {0, 4, 3, 1, 2, 5}; -static const uint16_t i_8523ddc40b86b8b0[] = {0, 1, 2, 3, 4, 5}; +static const uint16_t m_8523ddc40b86b8b0[] = {6, 0, 4, 3, 1, 2, 5}; +static const uint16_t i_8523ddc40b86b8b0[] = {0, 1, 2, 3, 4, 5, 6}; const ::capnp::_::RawSchema s_8523ddc40b86b8b0 = { - 0x8523ddc40b86b8b0, b_8523ddc40b86b8b0.words, 114, d_8523ddc40b86b8b0, m_8523ddc40b86b8b0, - 2, 6, i_8523ddc40b86b8b0, nullptr, nullptr, { &s_8523ddc40b86b8b0, nullptr, nullptr, 0, 0, nullptr } + 0x8523ddc40b86b8b0, b_8523ddc40b86b8b0.words, 130, d_8523ddc40b86b8b0, m_8523ddc40b86b8b0, + 2, 7, i_8523ddc40b86b8b0, nullptr, nullptr, { &s_8523ddc40b86b8b0, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<57> b_d800b1d6cd6f1ca0 = { @@ -1610,7 +1695,7 @@ static const uint16_t m_d800b1d6cd6f1ca0[] = {0, 1}; static const uint16_t i_d800b1d6cd6f1ca0[] = {0, 1}; const ::capnp::_::RawSchema s_d800b1d6cd6f1ca0 = { 0xd800b1d6cd6f1ca0, b_d800b1d6cd6f1ca0.words, 57, d_d800b1d6cd6f1ca0, m_d800b1d6cd6f1ca0, - 1, 2, i_d800b1d6cd6f1ca0, nullptr, nullptr, { &s_d800b1d6cd6f1ca0, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_d800b1d6cd6f1ca0, nullptr, nullptr, { &s_d800b1d6cd6f1ca0, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_f316944415569081 = { @@ -1671,7 +1756,7 @@ static const uint16_t m_f316944415569081[] = {1, 0}; static const uint16_t i_f316944415569081[] = {0, 1}; const ::capnp::_::RawSchema s_f316944415569081 = { 0xf316944415569081, b_f316944415569081.words, 50, nullptr, m_f316944415569081, - 0, 2, i_f316944415569081, nullptr, nullptr, { &s_f316944415569081, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_f316944415569081, nullptr, nullptr, { &s_f316944415569081, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_d37007fde1f0027d = { @@ -1731,20 +1816,20 @@ static const uint16_t m_d37007fde1f0027d[] = {0, 1}; static const uint16_t i_d37007fde1f0027d[] = {0, 1}; const ::capnp::_::RawSchema s_d37007fde1f0027d = { 0xd37007fde1f0027d, b_d37007fde1f0027d.words, 49, nullptr, m_d37007fde1f0027d, - 0, 2, i_d37007fde1f0027d, nullptr, nullptr, { &s_d37007fde1f0027d, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_d37007fde1f0027d, nullptr, nullptr, { &s_d37007fde1f0027d, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<85> b_d625b7063acf691a = { +static const ::capnp::_::AlignedData<100> b_d625b7063acf691a = { { 0, 0, 0, 0, 5, 0, 6, 0, 26, 105, 207, 58, 6, 183, 37, 214, 16, 0, 0, 0, 1, 0, 1, 0, 80, 162, 82, 37, 27, 152, 18, 179, - 1, 0, 7, 0, 0, 0, 0, 0, + 2, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 210, 0, 0, 0, 33, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 41, 0, 0, 0, 231, 0, 0, 0, + 41, 0, 0, 0, 31, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 114, 112, @@ -1755,35 +1840,42 @@ static const ::capnp::_::AlignedData<85> b_d625b7063acf691a = { 88, 189, 76, 63, 226, 150, 140, 178, 1, 0, 0, 0, 42, 0, 0, 0, 84, 121, 112, 101, 0, 0, 0, 0, - 16, 0, 0, 0, 3, 0, 4, 0, + 20, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 97, 0, 0, 0, 58, 0, 0, 0, + 125, 0, 0, 0, 58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 92, 0, 0, 0, 3, 0, 1, 0, - 104, 0, 0, 0, 2, 0, 1, 0, + 120, 0, 0, 0, 3, 0, 1, 0, + 132, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 101, 0, 0, 0, 186, 0, 0, 0, + 129, 0, 0, 0, 186, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 104, 0, 0, 0, 3, 0, 1, 0, - 116, 0, 0, 0, 2, 0, 1, 0, + 132, 0, 0, 0, 3, 0, 1, 0, + 144, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 113, 0, 0, 0, 154, 0, 0, 0, + 141, 0, 0, 0, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 116, 0, 0, 0, 3, 0, 1, 0, - 128, 0, 0, 0, 2, 0, 1, 0, + 144, 0, 0, 0, 3, 0, 1, 0, + 156, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 125, 0, 0, 0, 42, 0, 0, 0, + 153, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 120, 0, 0, 0, 3, 0, 1, 0, - 132, 0, 0, 0, 2, 0, 1, 0, + 148, 0, 0, 0, 3, 0, 1, 0, + 160, 0, 0, 0, 2, 0, 1, 0, + 4, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 157, 0, 0, 0, 50, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 152, 0, 0, 0, 3, 0, 1, 0, + 164, 0, 0, 0, 2, 0, 1, 0, 114, 101, 97, 115, 111, 110, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -1818,6 +1910,14 @@ static const ::capnp::_::AlignedData<85> b_d625b7063acf691a = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 116, 114, 97, 99, 101, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; @@ -1826,11 +1926,11 @@ static const ::capnp::_::AlignedData<85> b_d625b7063acf691a = { static const ::capnp::_::RawSchema* const d_d625b7063acf691a[] = { &s_b28c96e23f4cbd58, }; -static const uint16_t m_d625b7063acf691a[] = {2, 1, 0, 3}; -static const uint16_t i_d625b7063acf691a[] = {0, 1, 2, 3}; +static const uint16_t m_d625b7063acf691a[] = {2, 1, 0, 4, 3}; +static const uint16_t i_d625b7063acf691a[] = {0, 1, 2, 3, 4}; const ::capnp::_::RawSchema s_d625b7063acf691a = { - 0xd625b7063acf691a, b_d625b7063acf691a.words, 85, d_d625b7063acf691a, m_d625b7063acf691a, - 1, 4, i_d625b7063acf691a, nullptr, nullptr, { &s_d625b7063acf691a, nullptr, nullptr, 0, 0, nullptr } + 0xd625b7063acf691a, b_d625b7063acf691a.words, 100, d_d625b7063acf691a, m_d625b7063acf691a, + 1, 5, i_d625b7063acf691a, nullptr, nullptr, { &s_d625b7063acf691a, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<37> b_b28c96e23f4cbd58 = { @@ -1877,7 +1977,7 @@ static const ::capnp::_::AlignedData<37> b_b28c96e23f4cbd58 = { static const uint16_t m_b28c96e23f4cbd58[] = {2, 0, 1, 3}; const ::capnp::_::RawSchema s_b28c96e23f4cbd58 = { 0xb28c96e23f4cbd58, b_b28c96e23f4cbd58.words, 37, nullptr, m_b28c96e23f4cbd58, - 0, 4, nullptr, nullptr, nullptr, { &s_b28c96e23f4cbd58, nullptr, nullptr, 0, 0, nullptr } + 0, 4, nullptr, nullptr, nullptr, { &s_b28c96e23f4cbd58, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE CAPNP_DEFINE_ENUM(Type_b28c96e23f4cbd58, b28c96e23f4cbd58); @@ -1890,163 +1990,243 @@ namespace capnp { namespace rpc { // Message +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Message::_capnpPrivate::dataWordSize; constexpr uint16_t Message::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Message::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Message::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Bootstrap +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Bootstrap::_capnpPrivate::dataWordSize; constexpr uint16_t Bootstrap::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Bootstrap::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Bootstrap::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Call +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Call::_capnpPrivate::dataWordSize; constexpr uint16_t Call::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Call::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Call::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Call::SendResultsTo +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Call::SendResultsTo::_capnpPrivate::dataWordSize; constexpr uint16_t Call::SendResultsTo::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Call::SendResultsTo::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Call::SendResultsTo::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Return +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Return::_capnpPrivate::dataWordSize; constexpr uint16_t Return::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Return::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Return::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Finish +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Finish::_capnpPrivate::dataWordSize; constexpr uint16_t Finish::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Finish::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Finish::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Resolve +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Resolve::_capnpPrivate::dataWordSize; constexpr uint16_t Resolve::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Resolve::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Resolve::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Release +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Release::_capnpPrivate::dataWordSize; constexpr uint16_t Release::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Release::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Release::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Disembargo +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Disembargo::_capnpPrivate::dataWordSize; constexpr uint16_t Disembargo::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Disembargo::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Disembargo::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Disembargo::Context +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Disembargo::Context::_capnpPrivate::dataWordSize; constexpr uint16_t Disembargo::Context::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Disembargo::Context::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Disembargo::Context::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Provide +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Provide::_capnpPrivate::dataWordSize; constexpr uint16_t Provide::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Provide::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Provide::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Accept +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Accept::_capnpPrivate::dataWordSize; constexpr uint16_t Accept::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Accept::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Accept::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Join +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Join::_capnpPrivate::dataWordSize; constexpr uint16_t Join::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Join::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Join::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // MessageTarget +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t MessageTarget::_capnpPrivate::dataWordSize; constexpr uint16_t MessageTarget::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind MessageTarget::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* MessageTarget::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Payload +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Payload::_capnpPrivate::dataWordSize; constexpr uint16_t Payload::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Payload::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Payload::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CapDescriptor +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CapDescriptor::_capnpPrivate::dataWordSize; constexpr uint16_t CapDescriptor::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CapDescriptor::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CapDescriptor::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // PromisedAnswer +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t PromisedAnswer::_capnpPrivate::dataWordSize; constexpr uint16_t PromisedAnswer::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind PromisedAnswer::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* PromisedAnswer::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // PromisedAnswer::Op +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t PromisedAnswer::Op::_capnpPrivate::dataWordSize; constexpr uint16_t PromisedAnswer::Op::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind PromisedAnswer::Op::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* PromisedAnswer::Op::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // ThirdPartyCapDescriptor +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t ThirdPartyCapDescriptor::_capnpPrivate::dataWordSize; constexpr uint16_t ThirdPartyCapDescriptor::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind ThirdPartyCapDescriptor::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* ThirdPartyCapDescriptor::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Exception +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Exception::_capnpPrivate::dataWordSize; constexpr uint16_t Exception::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Exception::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Exception::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/c++/src/capnp/rpc.capnp.h b/c++/src/capnp/rpc.capnp.h index 0a440397fc..05ed0124ae 100644 --- a/c++/src/capnp/rpc.capnp.h +++ b/c++/src/capnp/rpc.capnp.h @@ -1,16 +1,20 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: rpc.capnp -#ifndef CAPNP_INCLUDED_b312981b2552a250_ -#define CAPNP_INCLUDED_b312981b2552a250_ +#pragma once #include +#include -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { @@ -402,7 +406,7 @@ struct Exception { struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(d625b7063acf691a, 1, 1) + CAPNP_DECLARE_STRUCT_HEADER(d625b7063acf691a, 1, 2) #if !CAPNP_LITE static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } #endif // !CAPNP_LITE @@ -762,6 +766,10 @@ class Call::Reader { inline bool getAllowThirdPartyTailCall() const; + inline bool getNoPromisePipelining() const; + + inline bool getOnlyPromisePipeline() const; + private: ::capnp::_::StructReader _reader; template @@ -819,6 +827,12 @@ class Call::Builder { inline bool getAllowThirdPartyTailCall(); inline void setAllowThirdPartyTailCall(bool value); + inline bool getNoPromisePipelining(); + inline void setNoPromisePipelining(bool value); + + inline bool getOnlyPromisePipeline(); + inline void setOnlyPromisePipeline(bool value); + private: ::capnp::_::StructBuilder _builder; template @@ -987,6 +1001,8 @@ class Return::Reader { inline bool hasAcceptFromThirdParty() const; inline ::capnp::AnyPointer::Reader getAcceptFromThirdParty() const; + inline bool getNoFinishNeeded() const; + private: ::capnp::_::StructReader _reader; template @@ -1055,6 +1071,9 @@ class Return::Builder { inline ::capnp::AnyPointer::Builder getAcceptFromThirdParty(); inline ::capnp::AnyPointer::Builder initAcceptFromThirdParty(); + inline bool getNoFinishNeeded(); + inline void setNoFinishNeeded(bool value); + private: ::capnp::_::StructBuilder _builder; template @@ -1102,6 +1121,8 @@ class Finish::Reader { inline bool getReleaseResultCaps() const; + inline bool getRequireEarlyCancellationWorkaround() const; + private: ::capnp::_::StructReader _reader; template @@ -1136,6 +1157,9 @@ class Finish::Builder { inline bool getReleaseResultCaps(); inline void setReleaseResultCaps(bool value); + inline bool getRequireEarlyCancellationWorkaround(); + inline void setRequireEarlyCancellationWorkaround(bool value); + private: ::capnp::_::StructBuilder _builder; template @@ -1923,7 +1947,7 @@ class Payload::Reader { inline ::capnp::AnyPointer::Reader getContent() const; inline bool hasCapTable() const; - inline ::capnp::List< ::capnp::rpc::CapDescriptor>::Reader getCapTable() const; + inline ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Reader getCapTable() const; private: ::capnp::_::StructReader _reader; @@ -1958,11 +1982,11 @@ class Payload::Builder { inline ::capnp::AnyPointer::Builder initContent(); inline bool hasCapTable(); - inline ::capnp::List< ::capnp::rpc::CapDescriptor>::Builder getCapTable(); - inline void setCapTable( ::capnp::List< ::capnp::rpc::CapDescriptor>::Reader value); - inline ::capnp::List< ::capnp::rpc::CapDescriptor>::Builder initCapTable(unsigned int size); - inline void adoptCapTable(::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor>> disownCapTable(); + inline ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Builder getCapTable(); + inline void setCapTable( ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Builder initCapTable(unsigned int size); + inline void adoptCapTable(::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>> disownCapTable(); private: ::capnp::_::StructBuilder _builder; @@ -2028,6 +2052,8 @@ class CapDescriptor::Reader { inline bool hasThirdPartyHosted() const; inline ::capnp::rpc::ThirdPartyCapDescriptor::Reader getThirdPartyHosted() const; + inline ::uint8_t getAttachedFd() const; + private: ::capnp::_::StructReader _reader; template @@ -2089,6 +2115,9 @@ class CapDescriptor::Builder { inline void adoptThirdPartyHosted(::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor>&& value); inline ::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor> disownThirdPartyHosted(); + inline ::uint8_t getAttachedFd(); + inline void setAttachedFd( ::uint8_t value); + private: ::capnp::_::StructBuilder _builder; template @@ -2135,7 +2164,7 @@ class PromisedAnswer::Reader { inline ::uint32_t getQuestionId() const; inline bool hasTransform() const; - inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Reader getTransform() const; + inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Reader getTransform() const; private: ::capnp::_::StructReader _reader; @@ -2169,11 +2198,11 @@ class PromisedAnswer::Builder { inline void setQuestionId( ::uint32_t value); inline bool hasTransform(); - inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Builder getTransform(); - inline void setTransform( ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Reader value); - inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Builder initTransform(unsigned int size); - inline void adoptTransform(::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>> disownTransform(); + inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Builder getTransform(); + inline void setTransform( ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Builder initTransform(unsigned int size); + inline void adoptTransform(::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>> disownTransform(); private: ::capnp::_::StructBuilder _builder; @@ -2397,6 +2426,9 @@ class Exception::Reader { inline ::capnp::rpc::Exception::Type getType() const; + inline bool hasTrace() const; + inline ::capnp::Text::Reader getTrace() const; + private: ::capnp::_::StructReader _reader; template @@ -2441,6 +2473,13 @@ class Exception::Builder { inline ::capnp::rpc::Exception::Type getType(); inline void setType( ::capnp::rpc::Exception::Type value); + inline bool hasTrace(); + inline ::capnp::Text::Builder getTrace(); + inline void setTrace( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initTrace(unsigned int size); + inline void adoptTrace(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownTrace(); + private: ::capnp::_::StructBuilder _builder; template @@ -3387,6 +3426,34 @@ inline void Call::Builder::setAllowThirdPartyTailCall(bool value) { ::capnp::bounded<128>() * ::capnp::ELEMENTS, value); } +inline bool Call::Reader::getNoPromisePipelining() const { + return _reader.getDataField( + ::capnp::bounded<129>() * ::capnp::ELEMENTS); +} + +inline bool Call::Builder::getNoPromisePipelining() { + return _builder.getDataField( + ::capnp::bounded<129>() * ::capnp::ELEMENTS); +} +inline void Call::Builder::setNoPromisePipelining(bool value) { + _builder.setDataField( + ::capnp::bounded<129>() * ::capnp::ELEMENTS, value); +} + +inline bool Call::Reader::getOnlyPromisePipeline() const { + return _reader.getDataField( + ::capnp::bounded<130>() * ::capnp::ELEMENTS); +} + +inline bool Call::Builder::getOnlyPromisePipeline() { + return _builder.getDataField( + ::capnp::bounded<130>() * ::capnp::ELEMENTS); +} +inline void Call::Builder::setOnlyPromisePipeline(bool value) { + _builder.setDataField( + ::capnp::bounded<130>() * ::capnp::ELEMENTS, value); +} + inline ::capnp::rpc::Call::SendResultsTo::Which Call::SendResultsTo::Reader::which() const { return _reader.getDataField( ::capnp::bounded<3>() * ::capnp::ELEMENTS); @@ -3745,6 +3812,20 @@ inline ::capnp::AnyPointer::Builder Return::Builder::initAcceptFromThirdParty() return result; } +inline bool Return::Reader::getNoFinishNeeded() const { + return _reader.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS); +} + +inline bool Return::Builder::getNoFinishNeeded() { + return _builder.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS); +} +inline void Return::Builder::setNoFinishNeeded(bool value) { + _builder.setDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, value); +} + inline ::uint32_t Finish::Reader::getQuestionId() const { return _reader.getDataField< ::uint32_t>( ::capnp::bounded<0>() * ::capnp::ELEMENTS); @@ -3773,6 +3854,20 @@ inline void Finish::Builder::setReleaseResultCaps(bool value) { ::capnp::bounded<32>() * ::capnp::ELEMENTS, value, true); } +inline bool Finish::Reader::getRequireEarlyCancellationWorkaround() const { + return _reader.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, true); +} + +inline bool Finish::Builder::getRequireEarlyCancellationWorkaround() { + return _builder.getDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, true); +} +inline void Finish::Builder::setRequireEarlyCancellationWorkaround(bool value) { + _builder.setDataField( + ::capnp::bounded<33>() * ::capnp::ELEMENTS, value, true); +} + inline ::capnp::rpc::Resolve::Which Resolve::Reader::which() const { return _reader.getDataField( ::capnp::bounded<2>() * ::capnp::ELEMENTS); @@ -4423,29 +4518,29 @@ inline bool Payload::Builder::hasCapTable() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::rpc::CapDescriptor>::Reader Payload::Reader::getCapTable() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Reader Payload::Reader::getCapTable() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::rpc::CapDescriptor>::Builder Payload::Builder::getCapTable() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Builder Payload::Builder::getCapTable() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Payload::Builder::setCapTable( ::capnp::List< ::capnp::rpc::CapDescriptor>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor>>::set(_builder.getPointerField( +inline void Payload::Builder::setCapTable( ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::rpc::CapDescriptor>::Builder Payload::Builder::initCapTable(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>::Builder Payload::Builder::initCapTable(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Payload::Builder::adoptCapTable( - ::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor>> Payload::Builder::disownCapTable() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>> Payload::Builder::disownCapTable() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::CapDescriptor, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -4670,6 +4765,20 @@ inline ::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor> CapDescriptor::Bu ::capnp::bounded<0>() * ::capnp::POINTERS)); } +inline ::uint8_t CapDescriptor::Reader::getAttachedFd() const { + return _reader.getDataField< ::uint8_t>( + ::capnp::bounded<2>() * ::capnp::ELEMENTS, 255u); +} + +inline ::uint8_t CapDescriptor::Builder::getAttachedFd() { + return _builder.getDataField< ::uint8_t>( + ::capnp::bounded<2>() * ::capnp::ELEMENTS, 255u); +} +inline void CapDescriptor::Builder::setAttachedFd( ::uint8_t value) { + _builder.setDataField< ::uint8_t>( + ::capnp::bounded<2>() * ::capnp::ELEMENTS, value, 255u); +} + inline ::uint32_t PromisedAnswer::Reader::getQuestionId() const { return _reader.getDataField< ::uint32_t>( ::capnp::bounded<0>() * ::capnp::ELEMENTS); @@ -4692,29 +4801,29 @@ inline bool PromisedAnswer::Builder::hasTransform() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Reader PromisedAnswer::Reader::getTransform() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Reader PromisedAnswer::Reader::getTransform() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Builder PromisedAnswer::Builder::getTransform() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Builder PromisedAnswer::Builder::getTransform() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void PromisedAnswer::Builder::setTransform( ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>::set(_builder.getPointerField( +inline void PromisedAnswer::Builder::setTransform( ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>::Builder PromisedAnswer::Builder::initTransform(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>::Builder PromisedAnswer::Builder::initTransform(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void PromisedAnswer::Builder::adoptTransform( - ::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>> PromisedAnswer::Builder::disownTransform() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>> PromisedAnswer::Builder::disownTransform() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::rpc::PromisedAnswer::Op, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -4892,7 +5001,42 @@ inline void Exception::Builder::setType( ::capnp::rpc::Exception::Type value) { ::capnp::bounded<2>() * ::capnp::ELEMENTS, value); } +inline bool Exception::Reader::hasTrace() const { + return !_reader.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); +} +inline bool Exception::Builder::hasTrace() { + return !_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader Exception::Reader::getTrace() const { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} +inline ::capnp::Text::Builder Exception::Builder::getTrace() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} +inline void Exception::Builder::setTrace( ::capnp::Text::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder Exception::Builder::initTrace(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), size); +} +inline void Exception::Builder::adoptTrace( + ::capnp::Orphan< ::capnp::Text>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> Exception::Builder::disownTrace() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} + } // namespace } // namespace -#endif // CAPNP_INCLUDED_b312981b2552a250_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/rpc.h b/c++/src/capnp/rpc.h index d84ed982e7..c4df2f04e2 100644 --- a/c++/src/capnp/rpc.h +++ b/c++/src/capnp/rpc.h @@ -19,16 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_RPC_H_ -#define CAPNP_RPC_H_ +#pragma once -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif - -#include "capability.h" +#include #include "rpc-prelude.h" +CAPNP_BEGIN_HEADER + +namespace kj { class AutoCloseFd; } + namespace capnp { template class SturdyRefRestorer; +class MessageReader; + template class BootstrapFactory: public _::BootstrapFactoryBase { // Interface that constructs per-client bootstrap interfaces. Use this if you want each client @@ -76,15 +77,13 @@ class RpcSystem: public _::RpcSystemBase { typename ThirdPartyCapId, typename JoinResult> RpcSystem( VatNetwork& network, - kj::Maybe bootstrapInterface, - kj::Maybe::Client> gateway = nullptr); + kj::Maybe bootstrapInterface); template RpcSystem( VatNetwork& network, - BootstrapFactory& bootstrapFactory, - kj::Maybe::Client> gateway = nullptr); + BootstrapFactory& bootstrapFactory); template func); + // + // (Inherited from _::RpcSystemBase) + // + // Set a function to call to encode exception stack traces for transmission to remote parties. + // By default, traces are not transmitted at all. If a callback is provided, then the returned + // string will be sent with the exception. If the remote end is KJ/C++ based, then this trace + // text ends up being accessible as kj::Exception::getRemoteTrace(). + // + // Stack traces can sometimes contain sensitive information, so you should think carefully about + // what information you are willing to reveal to the remote party. + + kj::Promise run() { return RpcSystemBase::run(); } + // Listens for incoming RPC connections and handles them. Never returns normally, but could throw + // an exception if the system becomes unable to accept new connections (e.g. because the + // underlying listen socket becomes broken somehow). + // + // For historical reasons, the RpcSystem will actually run itself even if you do not call this. + // However, if an exception is thrown, the RpcSystem will log the exception to the console and + // then cease accepting new connections. In this case, your server may be in a broken state, but + // without restarting. All servers should therefore call run() and handle failures in some way. }; template makeRpcServer( // See also ez-rpc.h, which has simpler instructions for the common case of a two-party // client-server RPC connection. -template , - typename ExternalRef = _::ExternalRefFromRealmGatewayClient> -RpcSystem makeRpcServer( - VatNetwork& network, - Capability::Client bootstrapInterface, RealmGatewayClient gateway); -// Make an RPC server for a VatNetwork that resides in a different realm from the application. -// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format -// and the network's ("external") format. - template RpcSystem makeRpcServer( @@ -180,23 +190,12 @@ RpcSystem makeRpcServer( // Make an RPC server that can serve different bootstrap interfaces to different clients via a // BootstrapInterface. -template , - typename ExternalRef = _::ExternalRefFromRealmGatewayClient> -RpcSystem makeRpcServer( - VatNetwork& network, - BootstrapFactory& bootstrapFactory, RealmGatewayClient gateway); -// Make an RPC server that can serve different bootstrap interfaces to different clients via a -// BootstrapInterface and communicates with a different realm than the application is in via a -// RealmGateway. - template RpcSystem makeRpcServer( VatNetwork& network, SturdyRefRestorer& restorer) - KJ_DEPRECATED("Please transition to using a bootstrap interface instead."); + CAPNP_DEPRECATED("Please transition to using a bootstrap interface instead."); // ** DEPRECATED ** // // Create an RPC server which exports multiple main interfaces by object ID. The `restorer` object @@ -226,17 +225,6 @@ RpcSystem makeRpcClient( // See also ez-rpc.h, which has simpler instructions for the common case of a two-party // client-server RPC connection. -template , - typename ExternalRef = _::ExternalRefFromRealmGatewayClient> -RpcSystem makeRpcClient( - VatNetwork& network, - RealmGatewayClient gateway); -// Make an RPC client for a VatNetwork that resides in a different realm from the application. -// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format -// and the network's ("external") format. - template class SturdyRefRestorer: public _::SturdyRefRestorerBase { // ** DEPRECATED ** @@ -254,9 +242,8 @@ class SturdyRefRestorer: public _::SturdyRefRestorerBase { // string names. public: - virtual Capability::Client restore(typename SturdyRefObjectId::Reader ref) - KJ_DEPRECATED( - "Please transition to using bootstrap interfaces instead of SturdyRefRestorer.") = 0; + virtual Capability::Client restore(typename SturdyRefObjectId::Reader ref) CAPNP_DEPRECATED( + "Please transition to using bootstrap interfaces instead of SturdyRefRestorer.") = 0; // Restore the given object, returning a capability representing it. private: @@ -274,9 +261,18 @@ class OutgoingRpcMessage { // Get the message body, which the caller may fill in any way it wants. (The standard RPC // implementation initializes it as a Message as defined in rpc.capnp.) + virtual void setFds(kj::Array fds) {} + // Set the list of file descriptors to send along with this message, if FD passing is supported. + // An implementation may ignore this. + virtual void send() = 0; // Send the message, or at least put it in a queue to be sent later. Note that the builder // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. + + virtual size_t sizeInWords() = 0; + // Get the total size of the message, for flow control purposes. Although the caller could + // also call getBody().targetSize(), doing that would walk the message tree, whereas typical + // implementations can compute the size more cheaply by summing segment sizes. }; class IncomingRpcMessage { @@ -286,6 +282,81 @@ class IncomingRpcMessage { virtual AnyPointer::Reader getBody() = 0; // Get the message body, to be interpreted by the caller. (The standard RPC implementation // interprets it as a Message as defined in rpc.capnp.) + + virtual kj::ArrayPtr getAttachedFds() { return nullptr; } + // If the transport supports attached file descriptors and some were attached to this message, + // returns them. Otherwise returns an empty array. It is intended that the caller will move the + // FDs out of this table when they are consumed, possibly leaving behind a null slot. Callers + // should be careful to check if an FD was already consumed by comparing the slot with `nullptr`. + // (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only + // add confusion. Moving from an AutoCloseFd does in fact make it null.) + + virtual size_t sizeInWords() = 0; + // Get the total size of the message, for flow control purposes. Although the caller could + // also call getBody().targetSize(), doing that would walk the message tree, whereas typical + // implementations can compute the size more cheaply by summing segment sizes. + + static bool isShortLivedRpcMessage(AnyPointer::Reader body); + // Helper function which computes whether the standard RpcSystem implementation would consider + // the given message body to be short-lived, meaning it will be dropped before the next message + // is read. This is useful to implement BufferedMessageStream::IsShortLivedCallback. + + static kj::Function getShortLivedCallback(); + // Returns a function that wraps isShortLivedRpcMessage(). The returned function type matches + // `BufferedMessageStream::IsShortLivedCallback` (defined in serialize-async.h), but we don't + // include that header here. +}; + +class RpcFlowController { + // Tracks a particular RPC stream in order to implement a flow control algorithm. + +public: + virtual kj::Promise send(kj::Own message, kj::Promise ack) = 0; + // Like calling message->send(), but the promise resolves when it's a good time to send the + // next message. + // + // `ack` is a promise that resolves when the message has been acknowledged from the other side. + // In practice, `message` is typically a `Call` message and `ack` is a `Return`. Note that this + // means `ack` counts not only time to transmit the message but also time for the remote + // application to process the message. The flow controller is expected to apply backpressure if + // the remote application responds slowly. If `ack` rejects, then all outstanding and future + // sends will propagate the exception. + // + // Note that messages sent with this method must still be delivered in the same order as if they + // had been sent with `message->send()`; they cannot be delayed until later. This is important + // because the message may introduce state changes in the RPC system that later messages rely on, + // such as introducing a new Question ID that a later message may reference. Thus, the controller + // can only create backpressure by having the returned promise resolve slowly. + // + // Dropping the returned promise does not cancel the send. Once send() is called, there's no way + // to stop it. + + virtual kj::Promise waitAllAcked() = 0; + // Wait for all `ack`s previously passed to send() to finish. It is an error to call send() again + // after this. + + // --------------------------------------------------------------------------- + // Common implementations. + + static kj::Own newFixedWindowController(size_t windowSize); + // Constructs a flow controller that implements a strict fixed window of the given size. In other + // words, the controller will throttle the stream when the total bytes in-flight exceeds the + // window. + + class WindowGetter { + public: + virtual size_t getWindow() = 0; + }; + + static kj::Own newVariableWindowController(WindowGetter& getter); + // Like newFixedWindowController(), but the window size is allowed to vary over time. Useful if + // you have a technique for estimating one good window size for the connection as a whole but not + // for individual streams. Keep in mind, though, that in situations where the other end of the + // connection is merely proxying capabilities from a variety of final destinations across a + // variety of networks, no single window will be appropriate for all streams. + + static constexpr size_t DEFAULT_WINDOW_SIZE = 65536; + // The window size used by the default implementation of Connection::newStream(). }; template newStream() override + { return RpcFlowController::newFixedWindowController(65536); } + // Construct a flow controller for a new stream on this connection. The controller can be + // passed into OutgoingRpcMessage::sendStreaming(). + // + // The default implementation returns a dummy stream controller that just applies a fixed + // window of 64k to everything. This always works but may constrain throughput on networks + // where the bandwidth-delay product is high, while conversely providing too much buffer when + // the bandwidth-delay product is low. + // + // WARNING: The RPC system may keep the `RpcFlowController` object alive past the lifetime of + // the `Connection` itself. However, it will not call `send()` any more after the + // `Connection` is destroyed. + // + // TODO(perf): We should introduce a flow controller implementation that uses a clock to + // measure RTT and bandwidth and dynamically update the window size, like BBR. + // Level 0 features ---------------------------------------------- virtual typename VatId::Reader getPeerVatId() = 0; @@ -345,10 +433,17 @@ class VatNetwork: public _::VatNetworkBase { // If `firstSegmentWordSize` is non-zero, it should be treated as a hint suggesting how large // to make the first segment. This is entirely a hint and the connection may adjust it up or // down. If it is zero, the connection should choose the size itself. + // + // WARNING: The RPC system may keep the `OutgoingRpcMessage` object alive past the lifetime of + // the `Connection` itself. However, it will not call `send()` any more after the + // `Connection` is destroyed. virtual kj::Promise>> receiveIncomingMessage() override = 0; // Wait for a message to be received and return it. If the read stream cleanly terminates, // return null. If any other problem occurs, throw an exception. + // + // WARNING: The RPC system may keep the `IncomingRpcMessage` object alive past the lifetime of + // the `Connection` itself. virtual kj::Promise shutdown() override KJ_WARN_UNUSED_RESULT = 0; // Waits until all outgoing messages have been sent, then shuts down the outgoing stream. The @@ -435,18 +530,16 @@ template RpcSystem::RpcSystem( VatNetwork& network, - kj::Maybe bootstrap, - kj::Maybe::Client> gateway) - : _::RpcSystemBase(network, kj::mv(bootstrap), kj::mv(gateway)) {} + kj::Maybe bootstrap) + : _::RpcSystemBase(network, kj::mv(bootstrap)) {} template template RpcSystem::RpcSystem( VatNetwork& network, - BootstrapFactory& bootstrapFactory, - kj::Maybe::Client> gateway) - : _::RpcSystemBase(network, bootstrapFactory, kj::mv(gateway)) {} + BootstrapFactory& bootstrapFactory) + : _::RpcSystemBase(network, bootstrapFactory) {} template template makeRpcServer( return RpcSystem(network, kj::mv(bootstrapInterface)); } -template -RpcSystem makeRpcServer( - VatNetwork& network, - Capability::Client bootstrapInterface, RealmGatewayClient gateway) { - return RpcSystem(network, kj::mv(bootstrapInterface), - gateway.template castAs>()); -} - template RpcSystem makeRpcServer( @@ -499,15 +582,6 @@ RpcSystem makeRpcServer( return RpcSystem(network, bootstrapFactory); } -template -RpcSystem makeRpcServer( - VatNetwork& network, - BootstrapFactory& bootstrapFactory, RealmGatewayClient gateway) { - return RpcSystem(network, bootstrapFactory, gateway.template castAs>()); -} - template RpcSystem makeRpcServer( @@ -523,15 +597,6 @@ RpcSystem makeRpcClient( return RpcSystem(network, nullptr); } -template -RpcSystem makeRpcClient( - VatNetwork& network, - RealmGatewayClient gateway) { - return RpcSystem(network, nullptr, gateway.template castAs>()); -} - } // namespace capnp -#endif // CAPNP_RPC_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/schema-lite.h b/c++/src/capnp/schema-lite.h index 58a8c14c05..0d7b915679 100644 --- a/c++/src/capnp/schema-lite.h +++ b/c++/src/capnp/schema-lite.h @@ -19,16 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SCHEMA_LITE_H_ -#define CAPNP_SCHEMA_LITE_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include "message.h" +CAPNP_BEGIN_HEADER + namespace capnp { template @@ -45,4 +42,4 @@ inline schema::Node::Reader schemaProto() { } // namespace capnp -#endif // CAPNP_SCHEMA_LITE_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/schema-loader-test.c++ b/c++/src/capnp/schema-loader-test.c++ index 7b8691b01c..c2b2651bf9 100644 --- a/c++/src/capnp/schema-loader-test.c++ +++ b/c++/src/capnp/schema-loader-test.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define CAPNP_TESTING_CAPNP 1 + #include "schema-loader.h" #include #include "test-util.h" @@ -387,6 +389,47 @@ TEST(SchemaLoader, Generics) { } } +TEST(SchemaLoader, LoadStreaming) { + SchemaLoader loader; + + InterfaceSchema schema = + loader.load(Schema::from().getProto()).asInterface(); + + auto results = schema.getMethodByName("doStreamI").getResultType(); + KJ_EXPECT(results.isStreamResult()); + KJ_EXPECT(results.getShortDisplayName() == "StreamResult", results.getShortDisplayName()); +} + +KJ_TEST("SchemaLoader placeholders are assumed to have caps") { + // Load TestCycle*NoCaps, but don't load its dependency TestAllTypes, so the loader has to assume + // there may be caps. + { + SchemaLoader loader; + Schema schemaA = loader.load(Schema::from().getProto()); + Schema schemaB = loader.load(Schema::from().getProto()); + loader.computeOptimizationHints(); + + KJ_EXPECT(schemaA.asStruct().mayContainCapabilities()); + KJ_EXPECT(schemaB.asStruct().mayContainCapabilities()); + } + + // Try again, but actually load TestAllTypes. Now we recognize there's no caps. + { + SchemaLoader loader; + Schema schemaA = loader.load(Schema::from().getProto()); + Schema schemaB = loader.load(Schema::from().getProto()); + loader.load(Schema::from().getProto()); + loader.computeOptimizationHints(); + + KJ_EXPECT(!schemaA.asStruct().mayContainCapabilities()); + KJ_EXPECT(!schemaB.asStruct().mayContainCapabilities()); + } + + // NOTE: computeOptimizationHints() is also tested in `schema-test.c++` where we test that + // various compiled types have the correct hints, which relies on the code generator having + // computed the hints. +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/c++/src/capnp/schema-loader.c++ b/c++/src/capnp/schema-loader.c++ index 1576433b6f..7c056c5041 100644 --- a/c++/src/capnp/schema-loader.c++ +++ b/c++/src/capnp/schema-loader.c++ @@ -21,9 +21,6 @@ #define CAPNP_PRIVATE #include "schema-loader.h" -#include -#include -#include #include "message.h" #include "arena.h" #include @@ -31,8 +28,10 @@ #include #include #include +#include +#include -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #include #endif @@ -40,27 +39,6 @@ namespace capnp { namespace { -struct ByteArrayHash { - size_t operator()(kj::ArrayPtr bytes) const { - // FNV hash. Probably sucks, but the code is simple. - // - // TODO(perf): Add CityHash or something to KJ and use it here. - - uint64_t hash = 0xcbf29ce484222325ull; - for (byte b: bytes) { - hash = hash * 0x100000001b3ull; - hash ^= b; - } - return hash; - } -}; - -struct ByteArrayEq { - bool operator()(kj::ArrayPtr a, kj::ArrayPtr b) const { - return a.size() == b.size() && memcmp(a.begin(), b.begin(), a.size()) == 0; - } -}; - struct SchemaBindingsPair { const _::RawSchema* schema; const _::RawBrandedSchema::Scope* scopeBindings; @@ -68,12 +46,8 @@ struct SchemaBindingsPair { inline bool operator==(const SchemaBindingsPair& other) const { return schema == other.schema && scopeBindings == other.scopeBindings; } -}; - -struct SchemaBindingsPairHash { - size_t operator()(SchemaBindingsPair pair) const { - return 31 * reinterpret_cast(pair.schema) + - reinterpret_cast(pair.scopeBindings); + inline uint hashCode() const { + return kj::hashCode(schema, scopeBindings); } }; @@ -140,6 +114,8 @@ public: kj::Array getAllLoaded() const; + void computeOptimizationHints(); + void requireStructSize(uint64_t id, uint dataWordCount, uint pointerCount); // Require any struct nodes loaded with this ID -- in the past and in the future -- to have at // least the given sizes. Struct nodes that don't comply will simply be rewritten to comply. @@ -150,19 +126,19 @@ public: kj::Arena arena; private: - std::unordered_set, ByteArrayHash, ByteArrayEq> dedupTable; + kj::HashSet> dedupTable; // Records raw segments of memory in the arena against which we my want to de-dupe later // additions. Specifically, RawBrandedSchema binding tables are de-duped. - std::unordered_map schemas; - std::unordered_map brands; - std::unordered_map unboundBrands; + kj::HashMap schemas; + kj::HashMap brands; + kj::HashMap unboundBrands; struct RequiredSize { uint16_t dataWordCount; uint16_t pointerCount; }; - std::unordered_map structSizeRequirements; + kj::HashMap structSizeRequirements; InitializerImpl initializer; BrandedInitializerImpl brandedInitializer; @@ -285,7 +261,7 @@ public: loader.arena.allocateArray(*count); uint pos = 0; for (auto& dep: dependencies) { - result[pos++] = dep.second; + result[pos++] = dep.value; } KJ_DASSERT(pos == *count); return result.begin(); @@ -296,7 +272,7 @@ public: kj::ArrayPtr result = loader.arena.allocateArray(*count); uint pos = 0; for (auto& member: members) { - result[pos++] = member.second; + result[pos++] = member.value; } KJ_DASSERT(pos == *count); return result.begin(); @@ -310,10 +286,14 @@ private: SchemaLoader::Impl& loader; Text::Reader nodeName; bool isValid; - std::map dependencies; + + // Maps type IDs -> compiled schemas for each dependency. + // Order is important because makeDependencyArray() compiles a sorted array. + kj::TreeMap dependencies; // Maps name -> index for each member. - std::map members; + // Order is important because makeMemberInfoArray() compiles a sorted array. + kj::TreeMap members; kj::ArrayPtr membersByDiscriminant; @@ -323,8 +303,9 @@ private: KJ_FAIL_REQUIRE(__VA_ARGS__) { isValid = false; return; } void validateMemberName(kj::StringPtr name, uint index) { - bool isNewName = members.insert(std::make_pair(name, index)).second; - VALIDATE_SCHEMA(isNewName, "duplicate name", name); + members.upsert(name, index, [&](auto&, auto&&) { + FAIL_VALIDATE_SCHEMA("duplicate name", name); + }); } void validate(const schema::Node::Struct::Reader& structNode, uint64_t scopeId) { @@ -625,12 +606,13 @@ private: VALIDATE_SCHEMA(node.which() == expectedKind, "expected a different kind of node for this ID", id, (uint)expectedKind, (uint)node.which(), node.getDisplayName()); - dependencies.insert(std::make_pair(id, existing)); + dependencies.upsert(id, existing, [](auto&,auto&&) { /* ignore dupe */ }); return; } - dependencies.insert(std::make_pair(id, loader.loadEmpty( - id, kj::str("(unknown type used by ", nodeName , ")"), expectedKind, true))); + dependencies.upsert(id, loader.loadEmpty( + id, kj::str("(unknown type used by ", nodeName , ")"), expectedKind, true), + [](auto&,auto&&) { /* ignore dupe */ }); } #undef VALIDATE_SCHEMA @@ -1263,144 +1245,144 @@ _::RawSchema* SchemaLoader::Impl::load(const schema::Node::Reader& reader, bool } // Check if we already have a schema for this ID. - _::RawSchema*& slot = schemas[validatedReader.getId()]; + _::RawSchema* schema; bool shouldReplace; bool shouldClearInitializer; - if (slot == nullptr) { - // Nope, allocate a new RawSchema. - slot = &arena.allocate<_::RawSchema>(); - memset(&slot->defaultBrand, 0, sizeof(slot->defaultBrand)); - slot->id = validatedReader.getId(); - slot->canCastTo = nullptr; - slot->defaultBrand.generic = slot; - slot->lazyInitializer = isPlaceholder ? &initializer : nullptr; - slot->defaultBrand.lazyInitializer = isPlaceholder ? &brandedInitializer : nullptr; - shouldReplace = true; - shouldClearInitializer = false; - } else { + KJ_IF_MAYBE(match, schemas.find(validatedReader.getId())) { // Yes, check if it is compatible and figure out which schema is newer. - // If the existing slot is a placeholder, but we're upgrading it to a non-placeholder, we + schema = *match; + + // If the existing schema is a placeholder, but we're upgrading it to a non-placeholder, we // need to clear the initializer later. - shouldClearInitializer = slot->lazyInitializer != nullptr && !isPlaceholder; + shouldClearInitializer = schema->lazyInitializer != nullptr && !isPlaceholder; - auto existing = readMessageUnchecked(slot->encodedNode); + auto existing = readMessageUnchecked(schema->encodedNode); CompatibilityChecker checker(*this); // Prefer to replace the existing schema if the existing schema is a placeholder. Otherwise, // prefer to keep the existing schema. shouldReplace = checker.shouldReplace( - existing, validatedReader, slot->lazyInitializer != nullptr); + existing, validatedReader, schema->lazyInitializer != nullptr); + } else { + // Nope, allocate a new RawSchema. + schema = &arena.allocate<_::RawSchema>(); + memset(&schema->defaultBrand, 0, sizeof(schema->defaultBrand)); + schema->id = validatedReader.getId(); + schema->canCastTo = nullptr; + schema->defaultBrand.generic = schema; + schema->lazyInitializer = isPlaceholder ? &initializer : nullptr; + schema->defaultBrand.lazyInitializer = isPlaceholder ? &brandedInitializer : nullptr; + shouldReplace = true; + shouldClearInitializer = false; + schemas.insert(validatedReader.getId(), schema); } if (shouldReplace) { // Initialize the RawSchema. - slot->encodedNode = validated.begin(); - slot->encodedSize = validated.size(); - slot->dependencies = validator.makeDependencyArray(&slot->dependencyCount); - slot->membersByName = validator.makeMemberInfoArray(&slot->memberCount); - slot->membersByDiscriminant = validator.makeMembersByDiscriminantArray(); + schema->encodedNode = validated.begin(); + schema->encodedSize = validated.size(); + schema->dependencies = validator.makeDependencyArray(&schema->dependencyCount); + schema->membersByName = validator.makeMemberInfoArray(&schema->memberCount); + schema->membersByDiscriminant = validator.makeMembersByDiscriminantArray(); // Even though this schema isn't itself branded, it may have dependencies that are. So, we // need to set up the "dependencies" map under defaultBrand. - auto deps = makeBrandedDependencies(slot, kj::ArrayPtr()); - slot->defaultBrand.dependencies = deps.begin(); - slot->defaultBrand.dependencyCount = deps.size(); + auto deps = makeBrandedDependencies(schema, kj::ArrayPtr()); + schema->defaultBrand.dependencies = deps.begin(); + schema->defaultBrand.dependencyCount = deps.size(); } if (shouldClearInitializer) { // If this schema is not newly-allocated, it may already be in the wild, specifically in the // dependency list of other schemas. Once the initializer is null, it is live, so we must do // a release-store here. -#if __GNUC__ - __atomic_store_n(&slot->lazyInitializer, nullptr, __ATOMIC_RELEASE); - __atomic_store_n(&slot->defaultBrand.lazyInitializer, nullptr, __ATOMIC_RELEASE); +#if __GNUC__ || defined(__clang__) + __atomic_store_n(&schema->lazyInitializer, nullptr, __ATOMIC_RELEASE); + __atomic_store_n(&schema->defaultBrand.lazyInitializer, nullptr, __ATOMIC_RELEASE); #elif _MSC_VER std::atomic_thread_fence(std::memory_order_release); - *static_cast<_::RawSchema::Initializer const* volatile*>(&slot->lazyInitializer) = nullptr; + *static_cast<_::RawSchema::Initializer const* volatile*>(&schema->lazyInitializer) = nullptr; *static_cast<_::RawBrandedSchema::Initializer const* volatile*>( - &slot->defaultBrand.lazyInitializer) = nullptr; + &schema->defaultBrand.lazyInitializer) = nullptr; #else #error "Platform not supported" #endif } - return slot; + return schema; } _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) { - _::RawSchema*& slot = schemas[nativeSchema->id]; + _::RawSchema* schema; bool shouldReplace; bool shouldClearInitializer; - if (slot == nullptr) { - slot = &arena.allocate<_::RawSchema>(); - memset(&slot->defaultBrand, 0, sizeof(slot->defaultBrand)); - slot->defaultBrand.generic = slot; - slot->lazyInitializer = nullptr; - slot->defaultBrand.lazyInitializer = nullptr; + KJ_IF_MAYBE(match, schemas.find(nativeSchema->id)) { + schema = *match; + if (schema->canCastTo != nullptr) { + // Already loaded natively, or we're currently in the process of loading natively and there + // was a dependency cycle. + KJ_REQUIRE(schema->canCastTo == nativeSchema, + "two different compiled-in type have the same type ID", + nativeSchema->id, + readMessageUnchecked(nativeSchema->encodedNode).getDisplayName(), + readMessageUnchecked(schema->canCastTo->encodedNode).getDisplayName()); + return schema; + } else { + auto existing = readMessageUnchecked(schema->encodedNode); + auto native = readMessageUnchecked(nativeSchema->encodedNode); + CompatibilityChecker checker(*this); + shouldReplace = checker.shouldReplace(existing, native, true); + shouldClearInitializer = schema->lazyInitializer != nullptr; + } + } else { + schema = &arena.allocate<_::RawSchema>(); + memset(&schema->defaultBrand, 0, sizeof(schema->defaultBrand)); + schema->defaultBrand.generic = schema; + schema->lazyInitializer = nullptr; + schema->defaultBrand.lazyInitializer = nullptr; shouldReplace = true; shouldClearInitializer = false; // already cleared above - } else if (slot->canCastTo != nullptr) { - // Already loaded natively, or we're currently in the process of loading natively and there - // was a dependency cycle. - KJ_REQUIRE(slot->canCastTo == nativeSchema, - "two different compiled-in type have the same type ID", - nativeSchema->id, - readMessageUnchecked(nativeSchema->encodedNode).getDisplayName(), - readMessageUnchecked(slot->canCastTo->encodedNode).getDisplayName()); - return slot; - } else { - auto existing = readMessageUnchecked(slot->encodedNode); - auto native = readMessageUnchecked(nativeSchema->encodedNode); - CompatibilityChecker checker(*this); - shouldReplace = checker.shouldReplace(existing, native, true); - shouldClearInitializer = slot->lazyInitializer != nullptr; + schemas.insert(nativeSchema->id, schema); } - // Since we recurse below, the slot in the hash map could move around. Copy out the pointer - // for subsequent use. - // TODO(cleanup): Above comment is actually not true of unordered_map. Leaving here to explain - // code pattern below. - _::RawSchema* result = slot; - if (shouldReplace) { // Set the schema to a copy of the native schema, but make sure not to null out lazyInitializer // yet. _::RawSchema temp = *nativeSchema; - temp.lazyInitializer = result->lazyInitializer; - *result = temp; + temp.lazyInitializer = schema->lazyInitializer; + *schema = temp; - result->defaultBrand.generic = result; + schema->defaultBrand.generic = schema; // Indicate that casting is safe. Note that it's important to set this before recursively // loading dependencies, so that cycles don't cause infinite loops! - result->canCastTo = nativeSchema; + schema->canCastTo = nativeSchema; // We need to set the dependency list to point at other loader-owned RawSchemas. kj::ArrayPtr dependencies = - arena.allocateArray(result->dependencyCount); + arena.allocateArray(schema->dependencyCount); for (uint i = 0; i < nativeSchema->dependencyCount; i++) { dependencies[i] = loadNative(nativeSchema->dependencies[i]); } - result->dependencies = dependencies.begin(); + schema->dependencies = dependencies.begin(); // Also need to re-do the branded dependencies. - auto deps = makeBrandedDependencies(slot, kj::ArrayPtr()); - slot->defaultBrand.dependencies = deps.begin(); - slot->defaultBrand.dependencyCount = deps.size(); + auto deps = makeBrandedDependencies(schema, kj::ArrayPtr()); + schema->defaultBrand.dependencies = deps.begin(); + schema->defaultBrand.dependencyCount = deps.size(); // If there is a struct size requirement, we need to make sure that it is satisfied. - auto reqIter = structSizeRequirements.find(nativeSchema->id); - if (reqIter != structSizeRequirements.end()) { - applyStructSizeRequirement(result, reqIter->second.dataWordCount, - reqIter->second.pointerCount); + KJ_IF_MAYBE(sizeReq, structSizeRequirements.find(nativeSchema->id)) { + applyStructSizeRequirement(schema, sizeReq->dataWordCount, + sizeReq->pointerCount); } } else { // The existing schema is newer. // Indicate that casting is safe. Note that it's important to set this before recursively // loading dependencies, so that cycles don't cause infinite loops! - result->canCastTo = nativeSchema; + schema->canCastTo = nativeSchema; // Make sure the dependencies are loaded and compatible. for (uint i = 0; i < nativeSchema->dependencyCount; i++) { @@ -1412,20 +1394,20 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) { // If this schema is not newly-allocated, it may already be in the wild, specifically in the // dependency list of other schemas. Once the initializer is null, it is live, so we must do // a release-store here. -#if __GNUC__ - __atomic_store_n(&result->lazyInitializer, nullptr, __ATOMIC_RELEASE); - __atomic_store_n(&result->defaultBrand.lazyInitializer, nullptr, __ATOMIC_RELEASE); +#if __GNUC__ || defined(__clang__) + __atomic_store_n(&schema->lazyInitializer, nullptr, __ATOMIC_RELEASE); + __atomic_store_n(&schema->defaultBrand.lazyInitializer, nullptr, __ATOMIC_RELEASE); #elif _MSC_VER std::atomic_thread_fence(std::memory_order_release); - *static_cast<_::RawSchema::Initializer const* volatile*>(&result->lazyInitializer) = nullptr; + *static_cast<_::RawSchema::Initializer const* volatile*>(&schema->lazyInitializer) = nullptr; *static_cast<_::RawBrandedSchema::Initializer const* volatile*>( - &result->defaultBrand.lazyInitializer) = nullptr; + &schema->defaultBrand.lazyInitializer) = nullptr; #else #error "Platform not supported" #endif } - return result; + return schema; } _::RawSchema* SchemaLoader::Impl::loadEmpty( @@ -1528,31 +1510,25 @@ const _::RawBrandedSchema* SchemaLoader::Impl::makeBranded( const _::RawBrandedSchema* SchemaLoader::Impl::makeBranded( const _::RawSchema* schema, kj::ArrayPtr bindings) { - // Note that even if `bindings` is empty, we never want to return defaultBrand here because - // defaultBrand has special status. Normally, the lack of bindings means all parameters are - // "unspecified", which means their bindings are unknown and should be treated as AnyPointer. - // But defaultBrand represents a special case where all parameters are still parameters -- they - // haven't been bound in the first place. defaultBrand is used to represent the unbranded generic - // type, while a no-binding brand is equivalent to binding all parameters to AnyPointer. - if (bindings.size() == 0) { + // `defaultBrand` is the version where all type parameters are bound to `AnyPointer`. return &schema->defaultBrand; } - auto& slot = brands[SchemaBindingsPair { schema, bindings.begin() }]; - - if (slot == nullptr) { + SchemaBindingsPair key { schema, bindings.begin() }; + KJ_IF_MAYBE(existing, brands.find(key)) { + return *existing; + } else { auto& brand = arena.allocate<_::RawBrandedSchema>(); memset(&brand, 0, sizeof(brand)); - slot = &brand; + brands.insert(key, &brand); brand.generic = schema; brand.scopes = bindings.begin(); brand.scopeCount = bindings.size(); brand.lazyInitializer = &brandedInitializer; + return &brand; } - - return slot; } kj::ArrayPtr @@ -1749,9 +1725,17 @@ void SchemaLoader::Impl::makeDep(_::RawBrandedSchema::Binding& result, uint64_t typeId, schema::Type::Which whichType, schema::Node::Which expectedKind, schema::Brand::Reader brand, kj::StringPtr scopeName, kj::Maybe> brandBindings) { - const _::RawSchema* schema = loadEmpty(typeId, - kj::str("(unknown type; seen as dependency of ", scopeName, ")"), - expectedKind, true); + const _::RawSchema* schema; + if (typeId == capnp::typeId()) { + // StreamResult is a very special type that is used to mark when a method is declared as + // streaming ("foo @0 () -> stream;"). We like to auto-load it if we see it as someone's + // dependency. + schema = loadNative(&_::rawSchema()); + } else { + schema = loadEmpty(typeId, + kj::str("(unknown type; seen as dependency of ", scopeName, ")"), + expectedKind, true); + } result.which = static_cast(whichType); result.schema = makeBranded(schema, brand, brandBindings); } @@ -1783,16 +1767,15 @@ kj::ArrayPtr SchemaLoader::Impl::copyDeduped(kj::ArrayPtr valu auto bytes = values.asBytes(); - auto iter = dedupTable.find(bytes); - if (iter != dedupTable.end()) { - return kj::arrayPtr(reinterpret_cast(iter->begin()), values.size()); + KJ_IF_MAYBE(dupe, dedupTable.find(bytes)) { + return kj::arrayPtr(reinterpret_cast(dupe->begin()), values.size()); } // Need to make a new copy. auto copy = arena.allocateArray(values.size()); memcpy(copy.begin(), values.begin(), values.size() * sizeof(T)); - KJ_ASSERT(dedupTable.insert(copy.asBytes()).second); + dedupTable.insert(copy.asBytes()); return copy; } @@ -1803,11 +1786,10 @@ kj::ArrayPtr SchemaLoader::Impl::copyDeduped(kj::ArrayPtr values) { } SchemaLoader::Impl::TryGetResult SchemaLoader::Impl::tryGet(uint64_t typeId) const { - auto iter = schemas.find(typeId); - if (iter == schemas.end()) { - return {nullptr, initializer.getCallback()}; + KJ_IF_MAYBE(schema, schemas.find(typeId)) { + return {*schema, initializer.getCallback()}; } else { - return {iter->second, initializer.getCallback()}; + return {nullptr, initializer.getCallback()}; } } @@ -1817,43 +1799,214 @@ const _::RawBrandedSchema* SchemaLoader::Impl::getUnbound(const _::RawSchema* sc return &schema->defaultBrand; } - auto& slot = unboundBrands[schema]; - if (slot == nullptr) { - slot = &arena.allocate<_::RawBrandedSchema>(); + KJ_IF_MAYBE(existing, unboundBrands.find(schema)) { + return *existing; + } else { + auto slot = &arena.allocate<_::RawBrandedSchema>(); memset(slot, 0, sizeof(*slot)); slot->generic = schema; auto deps = makeBrandedDependencies(schema, nullptr); slot->dependencies = deps.begin(); slot->dependencyCount = deps.size(); + unboundBrands.insert(schema, slot); + return slot; } - - return slot; } kj::Array SchemaLoader::Impl::getAllLoaded() const { size_t count = 0; for (auto& schema: schemas) { - if (schema.second->lazyInitializer == nullptr) ++count; + if (schema.value->lazyInitializer == nullptr) ++count; } kj::Array result = kj::heapArray(count); size_t i = 0; for (auto& schema: schemas) { - if (schema.second->lazyInitializer == nullptr) { - result[i++] = Schema(&schema.second->defaultBrand); + if (schema.value->lazyInitializer == nullptr) { + result[i++] = Schema(&schema.value->defaultBrand); } } return result; } +void SchemaLoader::Impl::computeOptimizationHints() { + kj::HashMap<_::RawSchema*, kj::Vector<_::RawSchema*>> undecided; + // This map contains schemas for which we haven't yet decided if they might have capabilities. + // They at least do not directly contain capabilities, but they can't be fully decided until + // the dependents are decided. + // + // Each entry maps to a list of other schemas whose decisions depend on this schema. When a + // schema in the map is discovered to contain capabilities, then all these dependents must also + // be presumed to contain capabilities. + + // First pass: Decide on the easy cases and populate the `undecided` map with hard cases. + for (auto& entry: schemas) { + _::RawSchema* schema = entry.value; + + // Default to assuming everything could contain caps. + schema->mayContainCapabilities = true; + + if (schema->lazyInitializer != nullptr) { + // Not initialized yet, so we have to be conservative and assume there could be capabilities. + continue; + } + + auto node = readMessageUnchecked(schema->encodedNode); + + if (!node.isStruct()) { + // Non-structs are irrelevant. + continue; + } + + auto structSchema = node.getStruct(); + + bool foundAnyCaps = false; + bool foundAnyStructs = false; + for (auto field: structSchema.getFields()) { + switch (field.which()) { + case schema::Field::GROUP: + foundAnyStructs = true; + break; + case schema::Field::SLOT: { + auto type = field.getSlot().getType(); + while (type.isList()) { + type = type.getList().getElementType(); + } + + switch (type.which()) { + case schema::Type::VOID: + case schema::Type::BOOL: + case schema::Type::INT8: + case schema::Type::INT16: + case schema::Type::INT32: + case schema::Type::INT64: + case schema::Type::UINT8: + case schema::Type::UINT16: + case schema::Type::UINT32: + case schema::Type::UINT64: + case schema::Type::FLOAT32: + case schema::Type::FLOAT64: + case schema::Type::TEXT: + case schema::Type::DATA: + case schema::Type::ENUM: + // Not a capability. + break; + + case schema::Type::STRUCT: + foundAnyStructs = true; + break; + + case schema::Type::ANY_POINTER: // could be a capability, or transitively contain one + case schema::Type::INTERFACE: // definitely a capability + foundAnyCaps = true; + break; + + case schema::Type::LIST: + KJ_UNREACHABLE; // handled above + } + break; + } + } + + if (foundAnyCaps) break; // no point continuing + } + + if (foundAnyCaps) { + // Definitely has capabilities, don't add to `undecided`. + } else if (!foundAnyStructs) { + // Definitely does NOT have capabilities. Go ahead and set the hint and don't add to + // `undecided`. + schema->mayContainCapabilities = false; + } else { + // Don't know yet. Mark as no-capabilities for now, but place in `undecided` set to review + // later. + schema->mayContainCapabilities = false; + undecided.insert(schema, {}); + } + } + + // Second pass: For all undecided schemas, check dependencies and register as dependents where + // needed. + kj::Vector<_::RawSchema*> decisions; // Schemas that have become decided. + for (auto& entry: undecided) { + auto schema = entry.key; + + auto node = readMessageUnchecked(schema->encodedNode).getStruct(); + + for (auto field: node.getFields()) { + kj::Maybe depId; + + switch (field.which()) { + case schema::Field::GROUP: + depId = field.getGroup().getTypeId(); + break; + case schema::Field::SLOT: { + auto type = field.getSlot().getType(); + while (type.isList()) { + type = type.getList().getElementType(); + } + if (type.isStruct()) { + depId = type.getStruct().getTypeId(); + } + break; + } + } + + KJ_IF_MAYBE(d, depId) { + _::RawSchema* dep = KJ_ASSERT_NONNULL(schemas.find(*d)); + + if (dep->mayContainCapabilities) { + // Oops, this dependency is already known to have capabilities. So that means the current + // schema also has capabilities, transitively. Mark it as such. + schema->mayContainCapabilities = true; + + // Schedule this schema for removal later. + decisions.add(schema); + + // Might as well end the loop early. + break; + } else KJ_IF_MAYBE(undecidedEntry, undecided.find(dep)) { + // This dependency is in the undecided set. Register interest in it. + undecidedEntry->add(schema); + } else { + // This dependency is decided, and the decision is that it has no capabilities. So it + // has no impact on the dependent. + } + } + } + } + + // Third pass: For each decision we made, remove it and propagate to its dependents. + while (!decisions.empty()) { + _::RawSchema* decision = decisions.back(); + decisions.removeLast(); + + auto& entry = KJ_ASSERT_NONNULL(undecided.findEntry(decision)); + for (auto& dependent: entry.value) { + if (!dependent->mayContainCapabilities) { + // The dependent was not previously decided. But, we now know it has a dependency which has + // capabilities, therefore we can decide the dependent. + dependent->mayContainCapabilities = true; + decisions.add(dependent); + } + } + undecided.erase(entry); + } + + // Everything that is left in `undecided` must only be waiting on other undecided schemas. We + // can therefore decide that none of them have any capabilities. We marked them as such + // earlier so now we're all done. +} + void SchemaLoader::Impl::requireStructSize(uint64_t id, uint dataWordCount, uint pointerCount) { - auto& slot = structSizeRequirements[id]; - slot.dataWordCount = kj::max(slot.dataWordCount, dataWordCount); - slot.pointerCount = kj::max(slot.pointerCount, pointerCount); + structSizeRequirements.upsert(id, { uint16_t(dataWordCount), uint16_t(pointerCount) }, + [&](RequiredSize& existingValue, RequiredSize&& newValue) { + existingValue.dataWordCount = kj::max(existingValue.dataWordCount, newValue.dataWordCount); + existingValue.pointerCount = kj::max(existingValue.pointerCount, newValue.pointerCount); + }); - auto iter = schemas.find(id); - if (iter != schemas.end()) { - applyStructSizeRequirement(iter->second, dataWordCount, pointerCount); + KJ_IF_MAYBE(schema, schemas.find(id)) { + applyStructSizeRequirement(*schema, dataWordCount, pointerCount); } } @@ -1868,14 +2021,12 @@ kj::ArrayPtr SchemaLoader::Impl::makeUncheckedNode(schema::Node::Reader no kj::ArrayPtr SchemaLoader::Impl::makeUncheckedNodeEnforcingSizeRequirements( schema::Node::Reader node) { if (node.isStruct()) { - auto iter = structSizeRequirements.find(node.getId()); - if (iter != structSizeRequirements.end()) { - auto requirement = iter->second; + KJ_IF_MAYBE(requirement, structSizeRequirements.find(node.getId())) { auto structNode = node.getStruct(); - if (structNode.getDataWordCount() < requirement.dataWordCount || - structNode.getPointerCount() < requirement.pointerCount) { - return rewriteStructNodeWithSizes(node, requirement.dataWordCount, - requirement.pointerCount); + if (structNode.getDataWordCount() < requirement->dataWordCount || + structNode.getPointerCount() < requirement->pointerCount) { + return rewriteStructNodeWithSizes(node, requirement->dataWordCount, + requirement->pointerCount); } } } @@ -1932,7 +2083,7 @@ void SchemaLoader::InitializerImpl::init(const _::RawSchema* schema) const { "A schema not belonging to this loader used its initializer."); // Disable the initializer. -#if __GNUC__ +#if __GNUC__ || defined(__clang__) __atomic_store_n(&mutableSchema->lazyInitializer, nullptr, __ATOMIC_RELEASE); __atomic_store_n(&mutableSchema->defaultBrand.lazyInitializer, nullptr, __ATOMIC_RELEASE); #elif _MSC_VER @@ -1958,10 +2109,8 @@ void SchemaLoader::BrandedInitializerImpl::init(const _::RawBrandedSchema* schem } // Get the mutable version. - auto iter = lock->get()->brands.find(SchemaBindingsPair { schema->generic, schema->scopes }); - KJ_ASSERT(iter != lock->get()->brands.end()); - - _::RawBrandedSchema* mutableSchema = iter->second; + _::RawBrandedSchema* mutableSchema = KJ_ASSERT_NONNULL( + lock->get()->brands.find(SchemaBindingsPair { schema->generic, schema->scopes })); KJ_ASSERT(mutableSchema == schema); // Construct its dependency map. @@ -1971,7 +2120,7 @@ void SchemaLoader::BrandedInitializerImpl::init(const _::RawBrandedSchema* schem mutableSchema->dependencyCount = deps.size(); // It's initialized now, so disable the initializer. -#if __GNUC__ +#if __GNUC__ || defined(__clang__) __atomic_store_n(&mutableSchema->lazyInitializer, nullptr, __ATOMIC_RELEASE); #elif _MSC_VER std::atomic_thread_fence(std::memory_order_release); @@ -2011,7 +2160,10 @@ kj::Maybe SchemaLoader::tryGet( if (getResult.schema != nullptr && getResult.schema->lazyInitializer == nullptr) { if (brand.getScopes().size() > 0) { auto brandedSchema = impl.lockExclusive()->get()->makeBranded( - getResult.schema, brand, kj::arrayPtr(scope.raw->scopes, scope.raw->scopeCount)); + getResult.schema, brand, + scope.raw->isUnbound() + ? kj::Maybe>(nullptr) + : kj::arrayPtr(scope.raw->scopes, scope.raw->scopeCount)); brandedSchema->ensureInitialized(); return Schema(brandedSchema); } else { @@ -2105,6 +2257,10 @@ kj::Array SchemaLoader::getAllLoaded() const { return impl.lockShared()->get()->getAllLoaded(); } +void SchemaLoader::computeOptimizationHints() { + impl.lockExclusive()->get()->computeOptimizationHints(); +} + void SchemaLoader::loadNative(const _::RawSchema* nativeSchema) { impl.lockExclusive()->get()->loadNative(nativeSchema); } diff --git a/c++/src/capnp/schema-loader.h b/c++/src/capnp/schema-loader.h index 0e34cba77f..5db8364c42 100644 --- a/c++/src/capnp/schema-loader.h +++ b/c++/src/capnp/schema-loader.h @@ -19,17 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SCHEMA_LOADER_H_ -#define CAPNP_SCHEMA_LOADER_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "schema.h" #include #include +CAPNP_BEGIN_HEADER + namespace capnp { class SchemaLoader { @@ -67,7 +64,7 @@ class SchemaLoader { // that isn't already loaded. ~SchemaLoader() noexcept(false); - KJ_DISALLOW_COPY(SchemaLoader); + KJ_DISALLOW_COPY_AND_MOVE(SchemaLoader); Schema get(uint64_t id, schema::Brand::Reader brand = schema::Brand::Reader(), Schema scope = Schema()) const; @@ -152,6 +149,19 @@ class SchemaLoader { // loadCompiledTypeAndDependencies() in order to get a flat list of all of T's transitive // dependencies. + void computeOptimizationHints(); + // Call after all interesting schemas have been loaded to compute optimization hints. In + // particular, this initializes `hasNoCapabilities` for every struct type. Before this is called, + // that value is initialized to false for all types (which ensures correct behavior but does not + // allow the optimization). + // + // If any loaded struct types contain fields of types for which no schema has been loaded, they + // will be presumed to possibly contain capabilities. `LazyLoadCallback` will NOT be invoked to + // load any types that haven't been loaded yet. + // + // TODO(someday): Perhaps we could dynamically initialize the hints on-demand, but it would be + // much more work to implement. + private: class Validator; class CompatibilityChecker; @@ -170,4 +180,4 @@ inline void SchemaLoader::loadCompiledTypeAndDependencies() { } // namespace capnp -#endif // CAPNP_SCHEMA_LOADER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/schema-parser-test.c++ b/c++/src/capnp/schema-parser-test.c++ index f0442becc0..5bc618859c 100644 --- a/c++/src/capnp/schema-parser-test.c++ +++ b/c++/src/capnp/schema-parser-test.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define CAPNP_TESTING_CAPNP 1 + #include "schema-parser.h" #include #include "test-util.h" @@ -28,26 +30,28 @@ namespace capnp { namespace { -class FakeFileReader final: public SchemaFile::FileReader { +#if _WIN32 +#define ABS(x) "C:\\" x +#else +#define ABS(x) "/" x +#endif + +class FakeFileReader final: public kj::Filesystem { public: void add(kj::StringPtr name, kj::StringPtr content) { - files[name] = content; + root->openFile(cwd.evalNative(name), kj::WriteMode::CREATE | kj::WriteMode::CREATE_PARENT) + ->writeAll(content); } - bool exists(kj::StringPtr path) const override { - return files.count(path) > 0; - } - - kj::Array read(kj::StringPtr path) const override { - auto iter = files.find(path); - KJ_ASSERT(iter != files.end(), "FakeFileReader has no such file.", path); - auto result = kj::heapArray(iter->second.size()); - memcpy(result.begin(), iter->second.begin(), iter->second.size()); - return kj::mv(result); - } + const kj::Directory& getRoot() const override { return *root; } + const kj::Directory& getCurrent() const override { return *current; } + kj::PathPtr getCurrentPath() const override { return cwd; } private: - std::map files; + kj::Own root = kj::newInMemoryDirectory(kj::nullClock()); + kj::Path cwd = kj::Path({}).evalNative(ABS("path/to/current/dir")); + kj::Own current = root->openSubdir(cwd, + kj::WriteMode::CREATE | kj::WriteMode::CREATE_PARENT); }; static uint64_t getFieldTypeFileId(StructSchema::Field field) { @@ -57,8 +61,9 @@ static uint64_t getFieldTypeFileId(StructSchema::Field field) { } TEST(SchemaParser, Basic) { - SchemaParser parser; FakeFileReader reader; + SchemaParser parser; + parser.setDiskFilesystem(reader); reader.add("src/foo/bar.capnp", "@0x8123456789abcdef;\n" @@ -74,22 +79,22 @@ TEST(SchemaParser, Basic) { reader.add("src/qux/corge.capnp", "@0x83456789abcdef12;\n" "struct Corge {}\n"); - reader.add("/usr/include/grault.capnp", + reader.add(ABS("usr/include/grault.capnp"), "@0x8456789abcdef123;\n" "struct Grault {}\n"); - reader.add("/opt/include/grault.capnp", + reader.add(ABS("opt/include/grault.capnp"), "@0x8000000000000001;\n" "struct WrongGrault {}\n"); - reader.add("/usr/local/include/garply.capnp", + reader.add(ABS("usr/local/include/garply.capnp"), "@0x856789abcdef1234;\n" "struct Garply {}\n"); kj::StringPtr importPath[] = { - "/usr/include", "/usr/local/include", "/opt/include" + ABS("usr/include"), ABS("usr/local/include"), ABS("opt/include") }; - ParsedSchema barSchema = parser.parseFile(SchemaFile::newDiskFile( - "foo2/bar2.capnp", "src/foo/bar.capnp", importPath, reader)); + ParsedSchema barSchema = parser.parseDiskFile( + "foo2/bar2.capnp", "src/foo/bar.capnp", importPath); auto barProto = barSchema.getProto(); EXPECT_EQ(0x8123456789abcdefull, barProto.getId()); @@ -107,25 +112,33 @@ TEST(SchemaParser, Basic) { EXPECT_EQ("garply", barFields[3].getProto().getName()); EXPECT_EQ(0x856789abcdef1234ull, getFieldTypeFileId(barFields[3])); - auto bazSchema = parser.parseFile(SchemaFile::newDiskFile( + auto barStructs = barSchema.getAllNested(); + ASSERT_EQ(1, barStructs.size()); + EXPECT_EQ("Bar", barStructs[0].getUnqualifiedName()); + barFields = barStructs[0].asStruct().getFields(); + ASSERT_EQ(4u, barFields.size()); + EXPECT_EQ("baz", barFields[0].getProto().getName()); + EXPECT_EQ(0x823456789abcdef1ull, getFieldTypeFileId(barFields[0])); + + auto bazSchema = parser.parseDiskFile( "not/used/because/already/loaded", - "src/foo/baz.capnp", importPath, reader)); + "src/foo/baz.capnp", importPath); EXPECT_EQ(0x823456789abcdef1ull, bazSchema.getProto().getId()); EXPECT_EQ("foo2/baz.capnp", bazSchema.getProto().getDisplayName()); auto bazStruct = bazSchema.getNested("Baz").asStruct(); EXPECT_EQ(bazStruct, barStruct.getDependency(bazStruct.getProto().getId())); - auto corgeSchema = parser.parseFile(SchemaFile::newDiskFile( + auto corgeSchema = parser.parseDiskFile( "not/used/because/already/loaded", - "src/qux/corge.capnp", importPath, reader)); + "src/qux/corge.capnp", importPath); EXPECT_EQ(0x83456789abcdef12ull, corgeSchema.getProto().getId()); EXPECT_EQ("qux/corge.capnp", corgeSchema.getProto().getDisplayName()); auto corgeStruct = corgeSchema.getNested("Corge").asStruct(); EXPECT_EQ(corgeStruct, barStruct.getDependency(corgeStruct.getProto().getId())); - auto graultSchema = parser.parseFile(SchemaFile::newDiskFile( + auto graultSchema = parser.parseDiskFile( "not/used/because/already/loaded", - "/usr/include/grault.capnp", importPath, reader)); + ABS("usr/include/grault.capnp"), importPath); EXPECT_EQ(0x8456789abcdef123ull, graultSchema.getProto().getId()); EXPECT_EQ("grault.capnp", graultSchema.getProto().getDisplayName()); auto graultStruct = graultSchema.getNested("Grault").asStruct(); @@ -133,9 +146,9 @@ TEST(SchemaParser, Basic) { // Try importing the other grault.capnp directly. It'll get the display name we specify since // it wasn't imported before. - auto wrongGraultSchema = parser.parseFile(SchemaFile::newDiskFile( + auto wrongGraultSchema = parser.parseDiskFile( "weird/display/name.capnp", - "/opt/include/grault.capnp", importPath, reader)); + ABS("opt/include/grault.capnp"), importPath); EXPECT_EQ(0x8000000000000001ull, wrongGraultSchema.getProto().getId()); EXPECT_EQ("weird/display/name.capnp", wrongGraultSchema.getProto().getDisplayName()); } @@ -145,8 +158,9 @@ TEST(SchemaParser, Constants) { // constants are not actually accessible from the generated code API, so the only way to ever // get a ConstSchema is by parsing it. - SchemaParser parser; FakeFileReader reader; + SchemaParser parser; + parser.setDiskFilesystem(reader); reader.add("const.capnp", "@0x8123456789abcdef;\n" @@ -162,8 +176,8 @@ TEST(SchemaParser, Constants) { " value @0 :T;\n" "}\n"); - ParsedSchema fileSchema = parser.parseFile(SchemaFile::newDiskFile( - "const.capnp", "const.capnp", nullptr, reader)); + ParsedSchema fileSchema = parser.parseDiskFile( + "const.capnp", "const.capnp", nullptr); EXPECT_EQ(1234, fileSchema.getNested("uint32Const").asConst().as()); @@ -181,5 +195,104 @@ TEST(SchemaParser, Constants) { EXPECT_EQ("text", genericConst.get("value").as()); } +void expectSourceInfo(schema::Node::SourceInfo::Reader sourceInfo, + uint64_t expectedId, kj::StringPtr expectedComment, + std::initializer_list expectedMembers) { + KJ_EXPECT(sourceInfo.getId() == expectedId, sourceInfo, expectedId); + KJ_EXPECT(sourceInfo.getDocComment() == expectedComment, sourceInfo, expectedComment); + + auto members = sourceInfo.getMembers(); + KJ_ASSERT(members.size() == expectedMembers.size()); + for (auto i: kj::indices(expectedMembers)) { + KJ_EXPECT(members[i].getDocComment() == expectedMembers.begin()[i], + members[i], expectedMembers.begin()[i]); + } +} + +TEST(SchemaParser, SourceInfo) { + FakeFileReader reader; + SchemaParser parser; + parser.setDiskFilesystem(reader); + + reader.add("foo.capnp", + "@0x84a2c6051e1061ed;\n" + "# file doc comment\n" + "\n" + "struct Foo @0xc6527d0a670dc4c3 {\n" + " # struct doc comment\n" + " # second line\n" + "\n" + " bar @0 :UInt32;\n" + " # field doc comment\n" + " baz :group {\n" + " # group doc comment\n" + " qux @1 :Text;\n" + " # group field doc comment\n" + " }\n" + "}\n" + "\n" + "enum Corge @0xae08878f1a016f14 {\n" + " # enum doc comment\n" + " grault @0;\n" + " # enumerant doc comment\n" + " garply @1;\n" + "}\n" + "\n" + "interface Waldo @0xc0f1b0aff62b761e {\n" + " # interface doc comment\n" + " fred @0 (plugh :Int32) -> (xyzzy :Text);\n" + " # method doc comment\n" + "}\n" + "\n" + "struct Thud @0xcca9972702b730b4 {}\n" + "# post-comment\n"); + + ParsedSchema file = parser.parseDiskFile( + "foo.capnp", "foo.capnp", nullptr); + ParsedSchema foo = file.getNested("Foo"); + + expectSourceInfo(file.getSourceInfo(), 0x84a2c6051e1061edull, "file doc comment\n", {}); + + expectSourceInfo(foo.getSourceInfo(), 0xc6527d0a670dc4c3ull, "struct doc comment\nsecond line\n", + { "field doc comment\n", "group doc comment\n" }); + + auto group = foo.asStruct().getFieldByName("baz").getType().asStruct(); + expectSourceInfo(KJ_ASSERT_NONNULL(parser.getSourceInfo(group)), + group.getProto().getId(), "group doc comment\n", { "group field doc comment\n" }); + + ParsedSchema corge = file.getNested("Corge"); + expectSourceInfo(corge.getSourceInfo(), 0xae08878f1a016f14, "enum doc comment\n", + { "enumerant doc comment\n", "" }); + + ParsedSchema waldo = file.getNested("Waldo"); + expectSourceInfo(waldo.getSourceInfo(), 0xc0f1b0aff62b761e, "interface doc comment\n", + { "method doc comment\n" }); + + ParsedSchema thud = file.getNested("Thud"); + expectSourceInfo(thud.getSourceInfo(), 0xcca9972702b730b4, "post-comment\n", {}); +} + +TEST(SchemaParser, SetFileIdsRequired) { + FakeFileReader reader; + reader.add("no-file-id.capnp", + "const foo :Int32 = 123;\n"); + + { + SchemaParser parser; + parser.setDiskFilesystem(reader); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("File does not declare an ID.", + parser.parseDiskFile("no-file-id.capnp", "no-file-id.capnp", nullptr)); + } + { + SchemaParser parser; + parser.setDiskFilesystem(reader); + parser.setFileIdsRequired(false); + + auto fileSchema = parser.parseDiskFile("no-file-id.capnp", "no-file-id.capnp", nullptr); + KJ_EXPECT(fileSchema.getNested("foo").asConst().as() == 123); + } +} + } // namespace } // namespace capnp diff --git a/c++/src/capnp/schema-parser.c++ b/c++/src/capnp/schema-parser.c++ index 3f20494da5..b5aeec12da 100644 --- a/c++/src/capnp/schema-parser.c++ +++ b/c++/src/capnp/schema-parser.c++ @@ -31,17 +31,7 @@ #include #include #include -#include -#include -#include -#include -#include - -#if _WIN32 -#include -#else -#include -#endif +#include namespace capnp { @@ -98,7 +88,7 @@ public: compiler::lex(content, statements, *this); auto parsed = orphanage.newOrphan(); - compiler::parseFile(statements.getStatements(), parsed.get(), *this); + compiler::parseFile(statements.getStatements(), parsed.get(), *this, parser.fileIdsRequired); return parsed; } @@ -158,7 +148,7 @@ private: namespace { struct SchemaFileHash { - inline bool operator()(const SchemaFile* f) const { + inline size_t operator()(const SchemaFile* f) const { return f->hashCode(); } }; @@ -171,31 +161,143 @@ struct SchemaFileEq { } // namespace +struct SchemaParser::DiskFileCompat { + // Stuff we only create if parseDiskFile() is ever called, in order to translate that call into + // KJ filesystem API calls. + + kj::Own ownFs; + kj::Filesystem& fs; + + struct ImportDir { + kj::String pathStr; + kj::Path path; + kj::Own dir; + }; + std::map cachedImportDirs; + + std::map, kj::Array> + cachedImportPaths; + + DiskFileCompat(): ownFs(kj::newDiskFilesystem()), fs(*ownFs) {} + DiskFileCompat(kj::Filesystem& fs): fs(fs) {} +}; + struct SchemaParser::Impl { typedef std::unordered_map< const SchemaFile*, kj::Own, SchemaFileHash, SchemaFileEq> FileMap; kj::MutexGuarded fileMap; compiler::Compiler compiler; + + kj::MutexGuarded> compat; }; SchemaParser::SchemaParser(): impl(kj::heap()) {} SchemaParser::~SchemaParser() noexcept(false) {} +ParsedSchema SchemaParser::parseFromDirectory( + const kj::ReadableDirectory& baseDir, kj::Path path, + kj::ArrayPtr importPath) const { + return parseFile(SchemaFile::newFromDirectory(baseDir, kj::mv(path), importPath)); +} + ParsedSchema SchemaParser::parseDiskFile( kj::StringPtr displayName, kj::StringPtr diskPath, kj::ArrayPtr importPath) const { - return parseFile(SchemaFile::newDiskFile(displayName, diskPath, importPath)); + auto lock = impl->compat.lockExclusive(); + DiskFileCompat* compat; + KJ_IF_MAYBE(c, *lock) { + compat = c; + } else { + compat = &lock->emplace(); + } + + auto& root = compat->fs.getRoot(); + auto cwd = compat->fs.getCurrentPath(); + + const kj::ReadableDirectory* baseDir = &root; + kj::Path path = cwd.evalNative(diskPath); + + kj::ArrayPtr translatedImportPath = nullptr; + + if (importPath.size() > 0) { + auto importPathKey = std::make_pair(importPath.begin(), importPath.size()); + auto& slot = compat->cachedImportPaths[importPathKey]; + + if (slot == nullptr) { + slot = KJ_MAP(path, importPath) -> const kj::ReadableDirectory* { + auto iter = compat->cachedImportDirs.find(path); + if (iter != compat->cachedImportDirs.end()) { + return iter->second.dir; + } + + auto parsed = cwd.evalNative(path); + kj::Own dir; + KJ_IF_MAYBE(d, root.tryOpenSubdir(parsed)) { + dir = kj::mv(*d); + } else { + // Ignore paths that don't exist. + dir = kj::newInMemoryDirectory(kj::nullClock()); + } + + const kj::ReadableDirectory* result = dir; + + kj::StringPtr pathRef = path; + KJ_ASSERT(compat->cachedImportDirs.insert(std::make_pair(pathRef, + DiskFileCompat::ImportDir { kj::str(path), kj::mv(parsed), kj::mv(dir) })).second); + + return result; + }; + } + + translatedImportPath = slot; + + // Check if `path` appears to be inside any of the import path directories. If so, adjust + // to be relative to that directory rather than absolute. + kj::Maybe matchedImportDir; + size_t bestMatchLength = 0; + for (auto importDir: importPath) { + auto iter = compat->cachedImportDirs.find(importDir); + KJ_ASSERT(iter != compat->cachedImportDirs.end()); + + if (path.startsWith(iter->second.path)) { + // Looks like we're trying to load a file from inside this import path. Treat the import + // path as the base directory. + if (iter->second.path.size() > bestMatchLength) { + bestMatchLength = iter->second.path.size(); + matchedImportDir = iter->second; + } + } + } + + KJ_IF_MAYBE(match, matchedImportDir) { + baseDir = match->dir; + path = path.slice(match->path.size(), path.size()).clone(); + } + } + + return parseFile(SchemaFile::newFromDirectory( + *baseDir, kj::mv(path), translatedImportPath, kj::str(displayName))); +} + +void SchemaParser::setDiskFilesystem(kj::Filesystem& fs) { + auto lock = impl->compat.lockExclusive(); + KJ_REQUIRE(*lock == nullptr, "already called parseDiskFile() or setDiskFilesystem()"); + lock->emplace(fs); } ParsedSchema SchemaParser::parseFile(kj::Own&& file) const { KJ_DEFER(impl->compiler.clearWorkspace()); - uint64_t id = impl->compiler.add(getModuleImpl(kj::mv(file))); + uint64_t id = impl->compiler.add(getModuleImpl(kj::mv(file))).getId(); impl->compiler.eagerlyCompile(id, compiler::Compiler::NODE | compiler::Compiler::CHILDREN | compiler::Compiler::DEPENDENCIES | compiler::Compiler::DEPENDENCY_DEPENDENCIES); return ParsedSchema(impl->compiler.getLoader().get(id), *this); } +kj::Maybe SchemaParser::getSourceInfo(Schema schema) const { + return impl->compiler.getSourceInfo(schema.getProto().getId()); +} + SchemaParser::ModuleImpl& SchemaParser::getModuleImpl(kj::Own&& file) const { auto lock = impl->fileMap.lockExclusive(); @@ -211,7 +313,14 @@ SchemaLoader& SchemaParser::getLoader() { return impl->compiler.getLoader(); } +const SchemaLoader& SchemaParser::getLoader() const { + return impl->compiler.getLoader(); +} + kj::Maybe ParsedSchema::findNested(kj::StringPtr name) const { + // TODO(someday): lookup() doesn't handle generics correctly. Use the ModuleScope/CompiledType + // interface instead. We can also add an applybrand() method to ParsedSchema using those + // interfaces, which would allow us to expose generics more explicitly to e.g. Python. return parser->impl->compiler.lookup(getProto().getId(), name).map( [this](uint64_t childId) { return ParsedSchema(parser->impl->compiler.getLoader().get(childId), *parser); @@ -226,226 +335,74 @@ ParsedSchema ParsedSchema::getNested(kj::StringPtr nestedName) const { } } -// ======================================================================================= - -namespace { - -class MmapDisposer: public kj::ArrayDisposer { -protected: - void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, - size_t capacity, void (*destroyElement)(void*)) const { -#if _WIN32 - KJ_ASSERT(UnmapViewOfFile(firstElement)); -#else - munmap(firstElement, elementSize * elementCount); -#endif - } -}; - -KJ_CONSTEXPR(static const) MmapDisposer mmapDisposer = MmapDisposer(); - -static char* canonicalizePath(char* path) { - // Taken from some old C code of mine. - - // Preconditions: - // - path has already been determined to be relative, perhaps because the pointer actually points - // into the middle of some larger path string, in which case it must point to the character - // immediately after a '/'. - - // Invariants: - // - src points to the beginning of a path component. - // - dst points to the location where the path component should end up, if it is not special. - // - src == path or src[-1] == '/'. - // - dst == path or dst[-1] == '/'. - - char* src = path; - char* dst = path; - char* locked = dst; // dst cannot backtrack past this - char* partEnd; - bool hasMore; - - for (;;) { - while (*src == '/') { - // Skip duplicate slash. - ++src; - } - - partEnd = strchr(src, '/'); - hasMore = partEnd != NULL; - if (hasMore) { - *partEnd = '\0'; - } else { - partEnd = src + strlen(src); - } - - if (strcmp(src, ".") == 0) { - // Skip it. - } else if (strcmp(src, "..") == 0) { - if (dst > locked) { - // Backtrack over last path component. - --dst; - while (dst > locked && dst[-1] != '/') --dst; - } else { - locked += 3; - goto copy; - } - } else { - // Copy if needed. - copy: - if (dst < src) { - memmove(dst, src, partEnd - src); - dst += partEnd - src; - } else { - dst = partEnd; - } - *dst++ = '/'; - } - - if (hasMore) { - src = partEnd + 1; - } else { - // Oops, we have to remove the trailing '/'. - if (dst == path) { - // Oops, there is no trailing '/'. We have to return ".". - strcpy(path, "."); - return path + 1; - } else { - // Remove the trailing '/'. Note that this means that opening the file will work even - // if it is not a directory, where normally it should fail on non-directories when a - // trailing '/' is present. If this is a problem, we need to add some sort of special - // handling for this case where we stat() it separately to check if it is a directory, - // because Ekam findInput will not accept a trailing '/'. - --dst; - *dst = '\0'; - return dst; - } - } - } -} - -kj::String canonicalizePath(kj::StringPtr path) { - KJ_STACK_ARRAY(char, result, path.size() + 1, 128, 512); - strcpy(result.begin(), path.begin()); - - char* start = path.startsWith("/") ? result.begin() + 1 : result.begin(); - char* end = canonicalizePath(start); - return kj::heapString(result.slice(0, end - result.begin())); -} - -kj::String relativePath(kj::StringPtr base, kj::StringPtr add) { - if (add.size() > 0 && add[0] == '/') { - return kj::heapString(add); - } - - const char* pos = base.end(); - while (pos > base.begin() && pos[-1] != '/') { - --pos; - } - - return kj::str(base.slice(0, pos - base.begin()), add); +ParsedSchema::ParsedSchemaList ParsedSchema::getAllNested() const { + return ParsedSchemaList(*this, getProto().getNestedNodes()); } -kj::String joinPath(kj::StringPtr base, kj::StringPtr add) { - KJ_REQUIRE(!add.startsWith("/")); - - return kj::str(base, '/', add); +schema::Node::SourceInfo::Reader ParsedSchema::getSourceInfo() const { + return KJ_ASSERT_NONNULL(parser->getSourceInfo(*this)); } -} // namespace - -const SchemaFile::DiskFileReader SchemaFile::DiskFileReader::instance = - SchemaFile::DiskFileReader(); - -bool SchemaFile::DiskFileReader::exists(kj::StringPtr path) const { - return access(path.cStr(), F_OK) == 0; -} - -kj::Array SchemaFile::DiskFileReader::read(kj::StringPtr path) const { - int fd; - // We already established that the file exists, so this should not fail. - KJ_SYSCALL(fd = open(path.cStr(), O_RDONLY), path); - kj::AutoCloseFd closer(fd); - - struct stat stats; - KJ_SYSCALL(fstat(fd, &stats)); - - if (S_ISREG(stats.st_mode)) { - if (stats.st_size == 0) { - // mmap()ing zero bytes will fail. - return nullptr; - } - - // Regular file. Just mmap() it. -#if _WIN32 - HANDLE handle = reinterpret_cast(_get_osfhandle(fd)); - KJ_ASSERT(handle != INVALID_HANDLE_VALUE); - HANDLE mappingHandle = CreateFileMapping( - handle, NULL, PAGE_READONLY, 0, stats.st_size, NULL); - KJ_ASSERT(mappingHandle != INVALID_HANDLE_VALUE); - KJ_DEFER(KJ_ASSERT(CloseHandle(mappingHandle))); - const void* mapping = MapViewOfFile(mappingHandle, FILE_MAP_READ, 0, 0, stats.st_size); -#else // _WIN32 - const void* mapping = mmap(NULL, stats.st_size, PROT_READ, MAP_SHARED, fd, 0); - if (mapping == MAP_FAILED) { - KJ_FAIL_SYSCALL("mmap", errno, path); - } -#endif // !_WIN32 - - return kj::Array( - reinterpret_cast(mapping), stats.st_size, mmapDisposer); - } else { - // This could be a stream of some sort, like a pipe. Fall back to read(). - // TODO(cleanup): This does a lot of copies. Not sure I care. - kj::Vector data(8192); - - char buffer[4096]; - for (;;) { - kj::miniposix::ssize_t n; - KJ_SYSCALL(n = ::read(fd, buffer, sizeof(buffer))); - if (n == 0) break; - data.addAll(buffer, buffer + n); - } +// ------------------------------------------------------------------- - return data.releaseAsArray(); - } +ParsedSchema ParsedSchema::ParsedSchemaList::operator[](uint index) const { + return ParsedSchema( + parent.parser->impl->compiler.getLoader().get(list[index].getId()), + *parent.parser); } // ------------------------------------------------------------------- class SchemaFile::DiskSchemaFile final: public SchemaFile { public: - DiskSchemaFile(const FileReader& fileReader, kj::String displayName, - kj::String diskPath, kj::ArrayPtr importPath) - : fileReader(fileReader), - displayName(kj::mv(displayName)), - diskPath(kj::mv(diskPath)), - importPath(importPath) {} + DiskSchemaFile(const kj::ReadableDirectory& baseDir, kj::Path pathParam, + kj::ArrayPtr importPath, + kj::Own file, + kj::Maybe displayNameOverride) + : baseDir(baseDir), path(kj::mv(pathParam)), importPath(importPath), file(kj::mv(file)) { + KJ_IF_MAYBE(dn, displayNameOverride) { + displayName = kj::mv(*dn); + displayNameOverridden = true; + } else { + displayName = path.toString(); + displayNameOverridden = false; + } + } kj::StringPtr getDisplayName() const override { return displayName; } kj::Array readContent() const override { - return fileReader.read(diskPath); + return file->mmap(0, file->stat().size).releaseAsChars(); } - kj::Maybe> import(kj::StringPtr path) const override { - if (path.startsWith("/")) { + kj::Maybe> import(kj::StringPtr target) const override { + if (target.startsWith("/")) { + auto parsed = kj::Path::parse(target.slice(1)); for (auto candidate: importPath) { - kj::String newDiskPath = canonicalizePath(joinPath(candidate, path.slice(1))); - if (fileReader.exists(newDiskPath)) { + KJ_IF_MAYBE(newFile, candidate->tryOpenFile(parsed)) { return kj::implicitCast>(kj::heap( - fileReader, canonicalizePath(path.slice(1)), - kj::mv(newDiskPath), importPath)); + *candidate, kj::mv(parsed), importPath, kj::mv(*newFile), nullptr)); } } return nullptr; } else { - kj::String newDiskPath = canonicalizePath(relativePath(diskPath, path)); - if (fileReader.exists(newDiskPath)) { + auto parsed = path.parent().eval(target); + + kj::Maybe displayNameOverride; + if (displayNameOverridden) { + // Try to create a consistent display name override for the imported file. This is for + // backwards-compatibility only -- display names are only overridden when using the + // deprecated parseDiskFile() interface. + kj::runCatchingExceptions([&]() { + displayNameOverride = kj::Path::parse(displayName).parent().eval(target).toString(); + }); + } + + KJ_IF_MAYBE(newFile, baseDir.tryOpenFile(parsed)) { return kj::implicitCast>(kj::heap( - fileReader, canonicalizePath(relativePath(displayName, path)), - kj::mv(newDiskPath), importPath)); + baseDir, kj::mv(parsed), importPath, kj::mv(*newFile), kj::mv(displayNameOverride))); } else { return nullptr; } @@ -453,40 +410,46 @@ public: } bool operator==(const SchemaFile& other) const override { - return diskPath == kj::downcast(other).diskPath; + auto& other2 = kj::downcast(other); + return &baseDir == &other2.baseDir && path == other2.path; } bool operator!=(const SchemaFile& other) const override { - return diskPath != kj::downcast(other).diskPath; + return !operator==(other); } size_t hashCode() const override { // djb hash with xor // TODO(someday): Add hashing library to KJ. - size_t result = 5381; - for (char c: diskPath) { - result = (result * 33) ^ c; + size_t result = reinterpret_cast(&baseDir); + for (auto& part: path) { + for (char c: part) { + result = (result * 33) ^ c; + } + result = (result * 33) ^ '/'; } return result; } void reportError(SourcePos start, SourcePos end, kj::StringPtr message) const override { kj::getExceptionCallback().onRecoverableException(kj::Exception( - kj::Exception::Type::FAILED, kj::heapString(diskPath), start.line, + kj::Exception::Type::FAILED, path.toString(), start.line, kj::heapString(message))); } private: - const FileReader& fileReader; + const kj::ReadableDirectory& baseDir; + kj::Path path; + kj::ArrayPtr importPath; + kj::Own file; kj::String displayName; - kj::String diskPath; - kj::ArrayPtr importPath; + bool displayNameOverridden; }; -kj::Own SchemaFile::newDiskFile( - kj::StringPtr displayName, kj::StringPtr diskPath, - kj::ArrayPtr importPath, - const FileReader& fileReader) { - return kj::heap(fileReader, canonicalizePath(displayName), - canonicalizePath(diskPath), importPath); +kj::Own SchemaFile::newFromDirectory( + const kj::ReadableDirectory& baseDir, kj::Path path, + kj::ArrayPtr importPath, + kj::Maybe displayNameOverride) { + return kj::heap(baseDir, kj::mv(path), importPath, baseDir.openFile(path), + kj::mv(displayNameOverride)); } } // namespace capnp diff --git a/c++/src/capnp/schema-parser.h b/c++/src/capnp/schema-parser.h index 3322bbfbfb..6c48763771 100644 --- a/c++/src/capnp/schema-parser.h +++ b/c++/src/capnp/schema-parser.h @@ -19,15 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SCHEMA_PARSER_H_ -#define CAPNP_SCHEMA_PARSER_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "schema-loader.h" #include +#include + +CAPNP_BEGIN_HEADER namespace capnp { @@ -43,31 +41,86 @@ class SchemaParser { SchemaParser(); ~SchemaParser() noexcept(false); - ParsedSchema parseDiskFile(kj::StringPtr displayName, kj::StringPtr diskPath, - kj::ArrayPtr importPath) const; - // Parse a file located on disk. Throws an exception if the file dosen't exist. - // - // Parameters: - // * `displayName`: The name that will appear in the file's schema node. (If the file has - // already been parsed, this will be ignored and the display name from the first time it was - // parsed will be kept.) - // * `diskPath`: The path to the file on disk. - // * `importPath`: Directories to search when resolving absolute imports within this file - // (imports that start with a `/`). Must remain valid until the SchemaParser is destroyed. - // (If the file has already been parsed, this will be ignored and the import path from the - // first time it was parsed will be kept.) + ParsedSchema parseFromDirectory( + const kj::ReadableDirectory& baseDir, kj::Path path, + kj::ArrayPtr importPath) const; + // Parse a file from the KJ filesystem API. Throws an exception if the file doesn't exist. + // + // `baseDir` and `path` are used together to resolve relative imports. `path` is the source + // file's path within `baseDir`. Relative imports will be interpreted relative to `path` and + // will be opened using `baseDir`. Note that the KJ filesystem API prohibits "breaking out" of + // a directory using "..", so relative imports will be restricted to children of `baseDir`. + // + // `importPath` is used for absolute imports (imports that start with a '/'). Each directory in + // the array will be searched in order until a file is found. + // + // All `ReadableDirectory` objects must remain valid until the `SchemaParser` is destroyed. Also, + // the `importPath` array must remain valid. `path` will be copied; it need not remain valid. // // This method is a shortcut, equivalent to: - // parser.parseFile(SchemaFile::newDiskFile(displayName, diskPath, importPath))`; + // parser.parseFile(SchemaFile::newDiskFile(baseDir, path, importPath))`; // // This method throws an exception if any errors are encountered in the file or in anything the // file depends on. Note that merely importing another file does not count as a dependency on // anything in the imported file -- only the imported types which are actually used are // "dependencies". + // + // Hint: Use kj::newDiskFilesystem() to initialize the KJ filesystem API. Usually you should do + // this at a high level in your program, e.g. the main() function, and then pass down the + // appropriate File/Directory objects to the components that need them. Example: + // + // auto fs = kj::newDiskFilesystem(); + // SchemaParser parser; + // auto schema = parser.parseFromDirectory(fs->getCurrent(), + // kj::Path::parse("foo/bar.capnp"), nullptr); + // + // Hint: To use in-memory data rather than real disk, you can use kj::newInMemoryDirectory(), + // write the files you want, then pass it to SchemaParser. Example: + // + // auto dir = kj::newInMemoryDirectory(kj::nullClock()); + // auto path = kj::Path::parse("foo/bar.capnp"); + // dir->openFile(path, kj::WriteMode::CREATE | kj::WriteMode::CREATE_PARENT) + // ->writeAll("struct Foo {}"); + // auto schema = parser.parseFromDirectory(*dir, path, nullptr); + // + // Hint: You can create an in-memory directory but then populate it with real files from disk, + // in order to control what is visible while also avoiding reading files yourself or making + // extra copies. Example: + // + // auto fs = kj::newDiskFilesystem(); + // auto dir = kj::newInMemoryDirectory(kj::nullClock()); + // auto fakePath = kj::Path::parse("foo/bar.capnp"); + // auto realPath = kj::Path::parse("path/to/some/file.capnp"); + // dir->transfer(fakePath, kj::WriteMode::CREATE | kj::WriteMode::CREATE_PARENT, + // fs->getCurrent(), realPath, kj::TransferMode::LINK); + // auto schema = parser.parseFromDirectory(*dir, fakePath, nullptr); + // + // In this example, note that any imports in the file will fail, since the in-memory directory + // you created contains no files except the specific one you linked in. + + ParsedSchema parseDiskFile(kj::StringPtr displayName, kj::StringPtr diskPath, + kj::ArrayPtr importPath) const + CAPNP_DEPRECATED("Use parseFromDirectory() instead."); + // Creates a private kj::Filesystem and uses it to parse files from the real disk. + // + // DO NOT USE in new code. Use parseFromDirectory() instead. + // + // This API has a serious problem: the file can import and embed files located anywhere on disk + // using relative paths. Even if you specify no `importPath`, relative imports still work. By + // using `parseFromDirectory()`, you can arrange so that imports are only allowed within a + // particular directory, or even set up a dummy filesystem where other files are not visible. + + void setDiskFilesystem(kj::Filesystem& fs) + CAPNP_DEPRECATED("Use parseFromDirectory() instead."); + // Call before calling parseDiskFile() to choose an alternative disk filesystem implementation. + // This exists mostly for testing purposes; new code should use parseFromDirectory() instead. + // + // If parseDiskFile() is called without having called setDiskFilesystem(), then + // kj::newDiskFilesystem() will be used instead. ParsedSchema parseFile(kj::Own&& file) const; // Advanced interface for parsing a file that may or may not be located in any global namespace. - // Most users will prefer `parseDiskFile()`. + // Most users will prefer `parseFromDirectory()`. // // If the file has already been parsed (that is, a SchemaFile that compares equal to this one // was parsed previously), the existing schema will be returned again. @@ -77,19 +130,46 @@ class SchemaParser { // normally. In this case, the result is a best-effort attempt to compile the schema, but it // may be invalid or corrupt, and using it for anything may cause exceptions to be thrown. + kj::Maybe getSourceInfo(Schema schema) const; + // Look up source info (e.g. doc comments) for the given schema, which must have come from this + // SchemaParser. Note that this will also work for implicit group and param types that don't have + // a type name hence don't have a `ParsedSchema`. + template inline void loadCompiledTypeAndDependencies() { // See SchemaLoader::loadCompiledTypeAndDependencies(). getLoader().loadCompiledTypeAndDependencies(); } + kj::Array getAllLoaded() const { + // Gets an array of all schema nodes that have been parsed so far. + return getLoader().getAllLoaded(); + } + + void setFileIdsRequired(bool value) { fileIdsRequired = value; } + // By befault, capnp files must declare a file-level type ID (like `@0xbe702824338d3f7f;`). + // Use `setFileIdsReqired(false)` to lift this requirement. + // + // If no ID is specified, a random one will be assigned. This will cause all types declared in + // the file to have randomized IDs as well (unless they declare an ID explicitly), which means + // that parsing the same file twice will appear to to produce a totally new, incompatible set of + // types. In particular, this means that you will not be able to use any interface types in the + // file for RPC, since the RPC protocol uses type IDs to identify methods. + // + // Setting this false is particularly useful when using Cap'n Proto as a config format. Typically + // type IDs are irrelevant for config files, and the requirement to specify one is cumbersome. + // For this reason, `capnp eval` does not require type ID to be present. + private: struct Impl; + struct DiskFileCompat; class ModuleImpl; kj::Own impl; mutable bool hadErrors = false; + bool fileIdsRequired = true; ModuleImpl& getModuleImpl(kj::Own&& file) const; + const SchemaLoader& getLoader() const; SchemaLoader& getLoader(); friend class ParsedSchema; @@ -99,6 +179,9 @@ class ParsedSchema: public Schema { // ParsedSchema is an extension of Schema which also has the ability to look up nested nodes // by name. See `SchemaParser`. + class ParsedSchemaList; + friend class ParsedSchemaList; + public: inline ParsedSchema(): parser(nullptr) {} @@ -110,6 +193,12 @@ class ParsedSchema: public Schema { // Gets the nested node with the given name, or throws an exception if there is no such nested // declaration. + ParsedSchemaList getAllNested() const; + // Get all the nested nodes + + schema::Node::SourceInfo::Reader getSourceInfo() const; + // Get the source info for this schema. + private: inline ParsedSchema(Schema inner, const SchemaParser& parser): Schema(inner), parser(&parser) {} @@ -117,6 +206,27 @@ class ParsedSchema: public Schema { friend class SchemaParser; }; +class ParsedSchema::ParsedSchemaList { +public: + ParsedSchemaList() = default; // empty list + + inline uint size() const { return list.size(); } + ParsedSchema operator[](uint index) const; + + typedef _::IndexingIterator Iterator; + inline Iterator begin() const { return Iterator(this, 0); } + inline Iterator end() const { return Iterator(this, size()); } + +private: + ParsedSchema parent; + List::Reader list; + + inline ParsedSchemaList(ParsedSchema parent, List::Reader list) + : parent(parent), list(list) {} + + friend class ParsedSchema; +}; + // ======================================================================================= // Advanced API @@ -127,44 +237,20 @@ class SchemaFile { // `SchemaFile::newDiskFile()`. public: - class FileReader { - public: - virtual bool exists(kj::StringPtr path) const = 0; - virtual kj::Array read(kj::StringPtr path) const = 0; - }; - - class DiskFileReader final: public FileReader { - // Implementation of FileReader that uses the local disk. Files are read using mmap() if - // possible. - - public: - static const DiskFileReader instance; - - bool exists(kj::StringPtr path) const override; - kj::Array read(kj::StringPtr path) const override; - }; - - static kj::Own newDiskFile( - kj::StringPtr displayName, kj::StringPtr diskPath, - kj::ArrayPtr importPath, - const FileReader& fileReader = DiskFileReader::instance); - // Construct a SchemaFile representing a file on disk (or located in the filesystem-like - // namespace represented by `fileReader`). - // - // Parameters: - // * `displayName`: The name that will appear in the file's schema node. - // * `diskPath`: The path to the file on disk. - // * `importPath`: Directories to search when resolving absolute imports within this file - // (imports that start with a `/`). The array content must remain valid as long as the - // SchemaFile exists (which is at least as long as the SchemaParser that parses it exists). - // * `fileReader`: Allows you to use a filesystem other than the actual local disk. Although, - // if you find yourself using this, it may make more sense for you to implement SchemaFile - // yourself. - // - // The SchemaFile compares equal to any other SchemaFile that has exactly the same disk path, - // after canonicalization. + // Note: Cap'n Proto 0.6.x and below had classes FileReader and DiskFileReader and a method + // newDiskFile() defined here. These were removed when SchemaParser was transitioned to use the + // KJ filesystem API. You should be able to get the same effect by subclassing + // kj::ReadableDirectory, or using kj::newInMemoryDirectory(). + + static kj::Own newFromDirectory( + const kj::ReadableDirectory& baseDir, kj::Path path, + kj::ArrayPtr importPath, + kj::Maybe displayNameOverride = nullptr); + // Construct a SchemaFile representing a file in a kj::ReadableDirectory. This is used to + // implement SchemaParser::parseFromDirectory(); see there for details. // - // The SchemaFile will throw an exception if any errors are reported. + // The SchemaFile compares equal to any other SchemaFile that has exactly the same `baseDir` + // object (by identity) and `path` (by value). // ----------------------------------------------------------------- // For more control, you can implement this interface. @@ -204,4 +290,4 @@ class SchemaFile { } // namespace capnp -#endif // CAPNP_SCHEMA_PARSER_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/schema-test.c++ b/c++/src/capnp/schema-test.c++ index a65d2fcde5..94b3c9a470 100644 --- a/c++/src/capnp/schema-test.c++ +++ b/c++/src/capnp/schema-test.c++ @@ -19,6 +19,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define CAPNP_TESTING_CAPNP 1 + #include "schema.h" #include #include "test-util.h" @@ -50,6 +52,7 @@ TEST(Schema, Structs) { EXPECT_TRUE(schema.asStruct() == schema); EXPECT_NONFATAL_FAILURE(schema.asEnum()); EXPECT_NONFATAL_FAILURE(schema.asInterface()); + ASSERT_EQ("TestAllTypes", schema.getUnqualifiedName()); ASSERT_EQ(schema.getFields().size(), schema.getProto().getStruct().getFields().size()); StructSchema::Field field = schema.getFields()[0]; @@ -368,6 +371,43 @@ TEST(Schema, Generics) { } } +KJ_TEST("StructSchema::hasNoCapabilites()") { + // At present, TestAllTypes doesn't actually cover interfaces or AnyPointer. + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(Schema::from().mayContainCapabilities()); + KJ_EXPECT(Schema::from().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + // Generic arguments could be capabilities. + KJ_EXPECT(Schema::from::Inner>().mayContainCapabilities()); + + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + KJ_EXPECT(!Schema::from().mayContainCapabilities()); + + KJ_EXPECT(Schema::from().mayContainCapabilities()); + KJ_EXPECT(Schema::from().mayContainCapabilities()); +} + +KJ_TEST("list-of-enum as generic type parameter has working schema") { + // Tests for a bug where when a list-of-enum type was used as a type parameter to a generic, + // the schema would be constructed wrong. + auto field = Schema::from() + .getFieldByName("bindEnumList").getType().asStruct() + .getFieldByName("foo"); + auto type = field.getType(); + KJ_ASSERT(type.isList()); + auto elementType = type.asList().getElementType(); + KJ_ASSERT(elementType.isEnum()); + KJ_ASSERT(elementType.asEnum() == Schema::from()); +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/c++/src/capnp/schema.c++ b/c++/src/capnp/schema.c++ index 36bf832b99..9ec5df62ea 100644 --- a/c++/src/capnp/schema.c++ +++ b/c++/src/capnp/schema.c++ @@ -22,9 +22,16 @@ #include "schema.h" #include "message.h" #include +#include namespace capnp { +namespace schema { + uint KJ_HASHCODE(Type::Which w) { return kj::hashCode(static_cast(w)); } + // TODO(cleanup): Cap'n Proto does not declare stringifiers nor hashers for `Which` enums, unlike + // all other enums. Fix that and remove this. +} + namespace _ { // private // Null schemas generated using the below schema file with: @@ -258,6 +265,19 @@ Schema::BrandArgumentList Schema::getBrandArgumentsAtScope(uint64_t scopeId) con return BrandArgumentList(scopeId, raw->isUnbound()); } +kj::Array Schema::getGenericScopeIds() const { + if (!getProto().getIsGeneric()) + return nullptr; + + auto result = kj::heapArray(raw->scopeCount); + for (auto iScope: kj::indices(result)) { + result[iScope] = raw->scopes[iScope].typeId; + } + + return result; +} + + StructSchema Schema::asStruct() const { KJ_REQUIRE(getProto().isStruct(), "Tried to use non-struct schema as a struct.", getProto().getDisplayName()) { @@ -295,6 +315,11 @@ kj::StringPtr Schema::getShortDisplayName() const { return proto.getDisplayName().slice(proto.getDisplayNamePrefixLength()); } +const kj::StringPtr Schema::getUnqualifiedName() const { + auto proto = getProto(); + return proto.getDisplayName().slice(proto.getDisplayNamePrefixLength()); +} + void Schema::requireUsableAs(const _::RawSchema* expected) const { KJ_REQUIRE(raw->generic == expected || (expected != nullptr && raw->generic->canCastTo == expected), @@ -497,6 +522,11 @@ kj::Maybe StructSchema::getFieldByDiscriminant(uint16_t dis } } +bool StructSchema::isStreamResult() const { + auto& streamRaw = _::rawSchema(); + return raw->generic == &streamRaw || raw->generic->canCastTo == &streamRaw; +} + Type StructSchema::Field::getType() const { auto proto = getProto(); uint location = _::RawBrandedSchema::makeDepLocation(_::RawBrandedSchema::DepKind::FIELD, index); @@ -868,7 +898,7 @@ bool Type::operator==(const Type& other) const { KJ_UNREACHABLE; } -size_t Type::hashCode() const { +uint Type::hashCode() const { switch (baseType) { case schema::Type::VOID: case schema::Type::BOOL: @@ -884,12 +914,24 @@ size_t Type::hashCode() const { case schema::Type::FLOAT64: case schema::Type::TEXT: case schema::Type::DATA: - return (static_cast(baseType) << 3) + listDepth; + if (listDepth == 0) { + // Make sure that hashCode(Type(baseType)) == hashCode(baseType), otherwise HashMap lookups + // keyed by `Type` won't work when the caller passes `baseType` as the key. + return kj::hashCode(baseType); + } else { + return kj::hashCode(baseType, listDepth); + } case schema::Type::STRUCT: case schema::Type::ENUM: case schema::Type::INTERFACE: - return reinterpret_cast(schema) + listDepth; + if (listDepth == 0) { + // Make sure that hashCode(Type(schema)) == hashCode(schema), otherwise HashMap lookups + // keyed by `Type` won't work when the caller passes `schema` as the key. + return kj::hashCode(schema); + } else { + return kj::hashCode(schema, listDepth); + } case schema::Type::LIST: KJ_UNREACHABLE; @@ -897,9 +939,9 @@ size_t Type::hashCode() const { case schema::Type::ANY_POINTER: { // Trying to comply with strict aliasing rules. Hopefully the compiler realizes that // both branches compile to the same instructions and can optimize it away. - size_t val = scopeId != 0 || isImplicitParam ? + uint16_t val = scopeId != 0 || isImplicitParam ? paramIndex : static_cast(anyPointerKind); - return (val << 1 | isImplicitParam) ^ scopeId; + return kj::hashCode(val, isImplicitParam, scopeId, listDepth); } } diff --git a/c++/src/capnp/schema.capnp b/c++/src/capnp/schema.capnp index 4bef693f6c..a47c1517c0 100644 --- a/c++/src/capnp/schema.capnp +++ b/c++/src/capnp/schema.capnp @@ -169,6 +169,33 @@ struct Node { targetsAnnotation @30 :Bool; } } + + struct SourceInfo { + # Additional information about a node which is not needed at runtime, but may be useful for + # documentation or debugging purposes. This is kept in a separate struct to make sure it + # doesn't accidentally get included in contexts where it is not needed. The + # `CodeGeneratorRequest` includes this information in a separate array. + + id @0 :Id; + # ID of the Node which this info describes. + + docComment @1 :Text; + # The top-level doc comment for the Node. + + members @2 :List(Member); + # Information about each member -- i.e. fields (for structs), enumerants (for enums), or + # methods (for interfaces). + # + # This list is the same length and order as the corresponding list in the Node, i.e. + # Node.struct.fields, Node.enum.enumerants, or Node.interface.methods. + + struct Member { + docComment @0 :Text; + # Doc comment on the member. + } + + # TODO(someday): Record location of the declaration in the original source code. + } } struct Field { @@ -371,8 +398,21 @@ struct Brand { # List of parameter bindings. inherit @2 :Void; - # The place where this Brand appears is actually within this scope or a sub-scope, - # and the bindings for this scope should be inherited from the reference point. + # The place where the Brand appears is within this scope or a sub-scope, and bindings + # for this scope are deferred to later Brand applications. This is equivalent to a + # pass-through binding list, where each of this scope's parameters is bound to itself. + # For example: + # + # struct Outer(T) { + # struct Inner { + # value @0 :T; + # } + # innerInherit @0 :Inner; # Outer Brand.Scope is `inherit`. + # innerBindSelf @1 :Outer(T).Inner; # Outer Brand.Scope explicitly binds T to T. + # } + # + # The innerInherit and innerBindSelf fields have equivalent types, but different Brand + # styles. } } @@ -468,6 +508,10 @@ struct CodeGeneratorRequest { # All nodes parsed by the compiler, including for the files on the command line and their # imports. + sourceInfo @3 :List(Node.SourceInfo); + # Information about the original source code for each node, where available. This array may be + # omitted or may be missing some nodes if no info is available for them. + requestedFiles @1 :List(RequestedFile); # Files which were listed on the command line. diff --git a/c++/src/capnp/schema.capnp.c++ b/c++/src/capnp/schema.capnp.c++ index 5d1a0711d6..1b7c7c2ef8 100644 --- a/c++/src/capnp/schema.capnp.c++ +++ b/c++/src/capnp/schema.capnp.c++ @@ -5,7 +5,7 @@ namespace capnp { namespace schemas { -static const ::capnp::_::AlignedData<221> b_e682ab4cf923a417 = { +static const ::capnp::_::AlignedData<225> b_e682ab4cf923a417 = { { 0, 0, 0, 0, 5, 0, 6, 0, 23, 164, 35, 249, 76, 171, 130, 230, 19, 0, 0, 0, 1, 0, 5, 0, @@ -13,23 +13,27 @@ static const ::capnp::_::AlignedData<221> b_e682ab4cf923a417 = { 6, 0, 7, 0, 0, 0, 6, 0, 6, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 194, 0, 0, 0, - 29, 0, 0, 0, 39, 0, 0, 0, + 29, 0, 0, 0, 55, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 57, 0, 0, 0, 23, 3, 0, 0, + 73, 0, 0, 0, 23, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 115, 99, 104, 101, 109, 97, 46, 99, 97, 112, 110, 112, 58, 78, 111, 100, 101, 0, - 8, 0, 0, 0, 1, 0, 1, 0, + 12, 0, 0, 0, 1, 0, 1, 0, 177, 163, 15, 241, 204, 27, 82, 185, - 9, 0, 0, 0, 82, 0, 0, 0, + 17, 0, 0, 0, 82, 0, 0, 0, 66, 194, 15, 250, 187, 85, 191, 222, - 9, 0, 0, 0, 90, 0, 0, 0, + 17, 0, 0, 0, 90, 0, 0, 0, + 174, 87, 19, 4, 227, 29, 142, 243, + 17, 0, 0, 0, 90, 0, 0, 0, 80, 97, 114, 97, 109, 101, 116, 101, 114, 0, 0, 0, 0, 0, 0, 0, 78, 101, 115, 116, 101, 100, 78, 111, 100, 101, 0, 0, 0, 0, 0, 0, + 83, 111, 117, 114, 99, 101, 73, 110, + 102, 111, 0, 0, 0, 0, 0, 0, 56, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, @@ -243,8 +247,8 @@ static const ::capnp::_::RawSchema* const d_e682ab4cf923a417[] = { static const uint16_t m_e682ab4cf923a417[] = {11, 5, 10, 1, 2, 8, 6, 0, 9, 13, 4, 12, 3, 7}; static const uint16_t i_e682ab4cf923a417[] = {6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13}; const ::capnp::_::RawSchema s_e682ab4cf923a417 = { - 0xe682ab4cf923a417, b_e682ab4cf923a417.words, 221, d_e682ab4cf923a417, m_e682ab4cf923a417, - 8, 14, i_e682ab4cf923a417, nullptr, nullptr, { &s_e682ab4cf923a417, nullptr, nullptr, 0, 0, nullptr } + 0xe682ab4cf923a417, b_e682ab4cf923a417.words, 225, d_e682ab4cf923a417, m_e682ab4cf923a417, + 8, 14, i_e682ab4cf923a417, nullptr, nullptr, { &s_e682ab4cf923a417, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<34> b_b9521bccf10fa3b1 = { @@ -289,7 +293,7 @@ static const uint16_t m_b9521bccf10fa3b1[] = {0}; static const uint16_t i_b9521bccf10fa3b1[] = {0}; const ::capnp::_::RawSchema s_b9521bccf10fa3b1 = { 0xb9521bccf10fa3b1, b_b9521bccf10fa3b1.words, 34, nullptr, m_b9521bccf10fa3b1, - 0, 1, i_b9521bccf10fa3b1, nullptr, nullptr, { &s_b9521bccf10fa3b1, nullptr, nullptr, 0, 0, nullptr } + 0, 1, i_b9521bccf10fa3b1, nullptr, nullptr, { &s_b9521bccf10fa3b1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_debf55bbfa0fc242 = { @@ -349,7 +353,140 @@ static const uint16_t m_debf55bbfa0fc242[] = {1, 0}; static const uint16_t i_debf55bbfa0fc242[] = {0, 1}; const ::capnp::_::RawSchema s_debf55bbfa0fc242 = { 0xdebf55bbfa0fc242, b_debf55bbfa0fc242.words, 49, nullptr, m_debf55bbfa0fc242, - 0, 2, i_debf55bbfa0fc242, nullptr, nullptr, { &s_debf55bbfa0fc242, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_debf55bbfa0fc242, nullptr, nullptr, { &s_debf55bbfa0fc242, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<72> b_f38e1de3041357ae = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 174, 87, 19, 4, 227, 29, 142, 243, + 24, 0, 0, 0, 1, 0, 1, 0, + 23, 164, 35, 249, 76, 171, 130, 230, + 2, 0, 7, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 26, 1, 0, 0, + 37, 0, 0, 0, 23, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 45, 0, 0, 0, 175, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 115, 99, + 104, 101, 109, 97, 46, 99, 97, 112, + 110, 112, 58, 78, 111, 100, 101, 46, + 83, 111, 117, 114, 99, 101, 73, 110, + 102, 111, 0, 0, 0, 0, 0, 0, + 4, 0, 0, 0, 1, 0, 1, 0, + 162, 31, 142, 137, 56, 144, 186, 194, + 1, 0, 0, 0, 58, 0, 0, 0, + 77, 101, 109, 98, 101, 114, 0, 0, + 12, 0, 0, 0, 3, 0, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 69, 0, 0, 0, 26, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 64, 0, 0, 0, 3, 0, 1, 0, + 76, 0, 0, 0, 2, 0, 1, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 73, 0, 0, 0, 90, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 72, 0, 0, 0, 3, 0, 1, 0, + 84, 0, 0, 0, 2, 0, 1, 0, + 2, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 0, 2, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 81, 0, 0, 0, 66, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 76, 0, 0, 0, 3, 0, 1, 0, + 104, 0, 0, 0, 2, 0, 1, 0, + 105, 100, 0, 0, 0, 0, 0, 0, + 9, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 9, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 100, 111, 99, 67, 111, 109, 109, 101, + 110, 116, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 109, 101, 109, 98, 101, 114, 115, 0, + 14, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 3, 0, 1, 0, + 16, 0, 0, 0, 0, 0, 0, 0, + 162, 31, 142, 137, 56, 144, 186, 194, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 14, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_f38e1de3041357ae = b_f38e1de3041357ae.words; +#if !CAPNP_LITE +static const ::capnp::_::RawSchema* const d_f38e1de3041357ae[] = { + &s_c2ba9038898e1fa2, +}; +static const uint16_t m_f38e1de3041357ae[] = {1, 0, 2}; +static const uint16_t i_f38e1de3041357ae[] = {0, 1, 2}; +const ::capnp::_::RawSchema s_f38e1de3041357ae = { + 0xf38e1de3041357ae, b_f38e1de3041357ae.words, 72, d_f38e1de3041357ae, m_f38e1de3041357ae, + 1, 3, i_f38e1de3041357ae, nullptr, nullptr, { &s_f38e1de3041357ae, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +static const ::capnp::_::AlignedData<36> b_c2ba9038898e1fa2 = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 162, 31, 142, 137, 56, 144, 186, 194, + 35, 0, 0, 0, 1, 0, 0, 0, + 174, 87, 19, 4, 227, 29, 142, 243, + 1, 0, 7, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 82, 1, 0, 0, + 41, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 37, 0, 0, 0, 63, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 115, 99, + 104, 101, 109, 97, 46, 99, 97, 112, + 110, 112, 58, 78, 111, 100, 101, 46, + 83, 111, 117, 114, 99, 101, 73, 110, + 102, 111, 46, 77, 101, 109, 98, 101, + 114, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 1, 0, + 4, 0, 0, 0, 3, 0, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 13, 0, 0, 0, 90, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 3, 0, 1, 0, + 24, 0, 0, 0, 2, 0, 1, 0, + 100, 111, 99, 67, 111, 109, 109, 101, + 110, 116, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, } +}; +::capnp::word const* const bp_c2ba9038898e1fa2 = b_c2ba9038898e1fa2.words; +#if !CAPNP_LITE +static const uint16_t m_c2ba9038898e1fa2[] = {0}; +static const uint16_t i_c2ba9038898e1fa2[] = {0}; +const ::capnp::_::RawSchema s_c2ba9038898e1fa2 = { + 0xc2ba9038898e1fa2, b_c2ba9038898e1fa2.words, 36, nullptr, m_c2ba9038898e1fa2, + 0, 1, i_c2ba9038898e1fa2, nullptr, nullptr, { &s_c2ba9038898e1fa2, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<134> b_9ea0b19b37fb4435 = { @@ -499,7 +636,7 @@ static const uint16_t m_9ea0b19b37fb4435[] = {0, 4, 5, 6, 3, 1, 2}; static const uint16_t i_9ea0b19b37fb4435[] = {0, 1, 2, 3, 4, 5, 6}; const ::capnp::_::RawSchema s_9ea0b19b37fb4435 = { 0x9ea0b19b37fb4435, b_9ea0b19b37fb4435.words, 134, d_9ea0b19b37fb4435, m_9ea0b19b37fb4435, - 3, 7, i_9ea0b19b37fb4435, nullptr, nullptr, { &s_9ea0b19b37fb4435, nullptr, nullptr, 0, 0, nullptr } + 3, 7, i_9ea0b19b37fb4435, nullptr, nullptr, { &s_9ea0b19b37fb4435, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<37> b_b54ab3364333f598 = { @@ -551,7 +688,7 @@ static const uint16_t m_b54ab3364333f598[] = {0}; static const uint16_t i_b54ab3364333f598[] = {0}; const ::capnp::_::RawSchema s_b54ab3364333f598 = { 0xb54ab3364333f598, b_b54ab3364333f598.words, 37, d_b54ab3364333f598, m_b54ab3364333f598, - 2, 1, i_b54ab3364333f598, nullptr, nullptr, { &s_b54ab3364333f598, nullptr, nullptr, 0, 0, nullptr } + 2, 1, i_b54ab3364333f598, nullptr, nullptr, { &s_b54ab3364333f598, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<57> b_e82753cff0c2218f = { @@ -624,7 +761,7 @@ static const uint16_t m_e82753cff0c2218f[] = {0, 1}; static const uint16_t i_e82753cff0c2218f[] = {0, 1}; const ::capnp::_::RawSchema s_e82753cff0c2218f = { 0xe82753cff0c2218f, b_e82753cff0c2218f.words, 57, d_e82753cff0c2218f, m_e82753cff0c2218f, - 3, 2, i_e82753cff0c2218f, nullptr, nullptr, { &s_e82753cff0c2218f, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_e82753cff0c2218f, nullptr, nullptr, { &s_e82753cff0c2218f, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<47> b_b18aa5ac7a0d9420 = { @@ -687,7 +824,7 @@ static const uint16_t m_b18aa5ac7a0d9420[] = {0, 1}; static const uint16_t i_b18aa5ac7a0d9420[] = {0, 1}; const ::capnp::_::RawSchema s_b18aa5ac7a0d9420 = { 0xb18aa5ac7a0d9420, b_b18aa5ac7a0d9420.words, 47, d_b18aa5ac7a0d9420, m_b18aa5ac7a0d9420, - 3, 2, i_b18aa5ac7a0d9420, nullptr, nullptr, { &s_b18aa5ac7a0d9420, nullptr, nullptr, 0, 0, nullptr } + 3, 2, i_b18aa5ac7a0d9420, nullptr, nullptr, { &s_b18aa5ac7a0d9420, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<228> b_ec1619d4400a0290 = { @@ -930,7 +1067,7 @@ static const uint16_t m_ec1619d4400a0290[] = {12, 2, 3, 4, 6, 1, 8, 9, 10, 11, 5 static const uint16_t i_ec1619d4400a0290[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; const ::capnp::_::RawSchema s_ec1619d4400a0290 = { 0xec1619d4400a0290, b_ec1619d4400a0290.words, 228, d_ec1619d4400a0290, m_ec1619d4400a0290, - 2, 13, i_ec1619d4400a0290, nullptr, nullptr, { &s_ec1619d4400a0290, nullptr, nullptr, 0, 0, nullptr } + 2, 13, i_ec1619d4400a0290, nullptr, nullptr, { &s_ec1619d4400a0290, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<114> b_9aad50a41f4af45f = { @@ -1061,7 +1198,7 @@ static const uint16_t m_9aad50a41f4af45f[] = {2, 1, 3, 5, 0, 6, 4}; static const uint16_t i_9aad50a41f4af45f[] = {4, 5, 0, 1, 2, 3, 6}; const ::capnp::_::RawSchema s_9aad50a41f4af45f = { 0x9aad50a41f4af45f, b_9aad50a41f4af45f.words, 114, d_9aad50a41f4af45f, m_9aad50a41f4af45f, - 4, 7, i_9aad50a41f4af45f, nullptr, nullptr, { &s_9aad50a41f4af45f, nullptr, nullptr, 0, 0, nullptr } + 4, 7, i_9aad50a41f4af45f, nullptr, nullptr, { &s_9aad50a41f4af45f, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<25> b_97b14cbe7cfec712 = { @@ -1095,7 +1232,7 @@ static const ::capnp::_::AlignedData<25> b_97b14cbe7cfec712 = { #if !CAPNP_LITE const ::capnp::_::RawSchema s_97b14cbe7cfec712 = { 0x97b14cbe7cfec712, b_97b14cbe7cfec712.words, 25, nullptr, nullptr, - 0, 0, nullptr, nullptr, nullptr, { &s_97b14cbe7cfec712, nullptr, nullptr, 0, 0, nullptr } + 0, 0, nullptr, nullptr, nullptr, { &s_97b14cbe7cfec712, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<80> b_c42305476bb4746f = { @@ -1191,7 +1328,7 @@ static const uint16_t m_c42305476bb4746f[] = {2, 3, 0, 1}; static const uint16_t i_c42305476bb4746f[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_c42305476bb4746f = { 0xc42305476bb4746f, b_c42305476bb4746f.words, 80, d_c42305476bb4746f, m_c42305476bb4746f, - 3, 4, i_c42305476bb4746f, nullptr, nullptr, { &s_c42305476bb4746f, nullptr, nullptr, 0, 0, nullptr } + 3, 4, i_c42305476bb4746f, nullptr, nullptr, { &s_c42305476bb4746f, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<32> b_cafccddb68db1d11 = { @@ -1237,7 +1374,7 @@ static const uint16_t m_cafccddb68db1d11[] = {0}; static const uint16_t i_cafccddb68db1d11[] = {0}; const ::capnp::_::RawSchema s_cafccddb68db1d11 = { 0xcafccddb68db1d11, b_cafccddb68db1d11.words, 32, d_cafccddb68db1d11, m_cafccddb68db1d11, - 1, 1, i_cafccddb68db1d11, nullptr, nullptr, { &s_cafccddb68db1d11, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_cafccddb68db1d11, nullptr, nullptr, { &s_cafccddb68db1d11, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_bb90d5c287870be6 = { @@ -1301,7 +1438,7 @@ static const uint16_t m_bb90d5c287870be6[] = {1, 0}; static const uint16_t i_bb90d5c287870be6[] = {0, 1}; const ::capnp::_::RawSchema s_bb90d5c287870be6 = { 0xbb90d5c287870be6, b_bb90d5c287870be6.words, 50, d_bb90d5c287870be6, m_bb90d5c287870be6, - 1, 2, i_bb90d5c287870be6, nullptr, nullptr, { &s_bb90d5c287870be6, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_bb90d5c287870be6, nullptr, nullptr, { &s_bb90d5c287870be6, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<69> b_978a7cebdc549a4d = { @@ -1384,7 +1521,7 @@ static const uint16_t m_978a7cebdc549a4d[] = {2, 1, 0}; static const uint16_t i_978a7cebdc549a4d[] = {0, 1, 2}; const ::capnp::_::RawSchema s_978a7cebdc549a4d = { 0x978a7cebdc549a4d, b_978a7cebdc549a4d.words, 69, d_978a7cebdc549a4d, m_978a7cebdc549a4d, - 1, 3, i_978a7cebdc549a4d, nullptr, nullptr, { &s_978a7cebdc549a4d, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_978a7cebdc549a4d, nullptr, nullptr, { &s_978a7cebdc549a4d, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<48> b_a9962a9ed0a4d7f8 = { @@ -1446,7 +1583,7 @@ static const uint16_t m_a9962a9ed0a4d7f8[] = {1, 0}; static const uint16_t i_a9962a9ed0a4d7f8[] = {0, 1}; const ::capnp::_::RawSchema s_a9962a9ed0a4d7f8 = { 0xa9962a9ed0a4d7f8, b_a9962a9ed0a4d7f8.words, 48, d_a9962a9ed0a4d7f8, m_a9962a9ed0a4d7f8, - 1, 2, i_a9962a9ed0a4d7f8, nullptr, nullptr, { &s_a9962a9ed0a4d7f8, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_a9962a9ed0a4d7f8, nullptr, nullptr, { &s_a9962a9ed0a4d7f8, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<155> b_9500cce23b334d80 = { @@ -1617,7 +1754,7 @@ static const uint16_t m_9500cce23b334d80[] = {4, 1, 7, 0, 5, 2, 6, 3}; static const uint16_t i_9500cce23b334d80[] = {0, 1, 2, 3, 4, 5, 6, 7}; const ::capnp::_::RawSchema s_9500cce23b334d80 = { 0x9500cce23b334d80, b_9500cce23b334d80.words, 155, d_9500cce23b334d80, m_9500cce23b334d80, - 3, 8, i_9500cce23b334d80, nullptr, nullptr, { &s_9500cce23b334d80, nullptr, nullptr, 0, 0, nullptr } + 3, 8, i_9500cce23b334d80, nullptr, nullptr, { &s_9500cce23b334d80, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<269> b_d07378ede1f9cc60 = { @@ -1904,7 +2041,7 @@ static const uint16_t m_d07378ede1f9cc60[] = {18, 1, 13, 15, 10, 11, 3, 4, 5, 2, static const uint16_t i_d07378ede1f9cc60[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}; const ::capnp::_::RawSchema s_d07378ede1f9cc60 = { 0xd07378ede1f9cc60, b_d07378ede1f9cc60.words, 269, d_d07378ede1f9cc60, m_d07378ede1f9cc60, - 5, 19, i_d07378ede1f9cc60, nullptr, nullptr, { &s_d07378ede1f9cc60, nullptr, nullptr, 0, 0, nullptr } + 5, 19, i_d07378ede1f9cc60, nullptr, nullptr, { &s_d07378ede1f9cc60, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<33> b_87e739250a60ea97 = { @@ -1951,7 +2088,7 @@ static const uint16_t m_87e739250a60ea97[] = {0}; static const uint16_t i_87e739250a60ea97[] = {0}; const ::capnp::_::RawSchema s_87e739250a60ea97 = { 0x87e739250a60ea97, b_87e739250a60ea97.words, 33, d_87e739250a60ea97, m_87e739250a60ea97, - 1, 1, i_87e739250a60ea97, nullptr, nullptr, { &s_87e739250a60ea97, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_87e739250a60ea97, nullptr, nullptr, { &s_87e739250a60ea97, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<47> b_9e0e78711a7f87a9 = { @@ -2013,7 +2150,7 @@ static const uint16_t m_9e0e78711a7f87a9[] = {1, 0}; static const uint16_t i_9e0e78711a7f87a9[] = {0, 1}; const ::capnp::_::RawSchema s_9e0e78711a7f87a9 = { 0x9e0e78711a7f87a9, b_9e0e78711a7f87a9.words, 47, d_9e0e78711a7f87a9, m_9e0e78711a7f87a9, - 2, 2, i_9e0e78711a7f87a9, nullptr, nullptr, { &s_9e0e78711a7f87a9, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_9e0e78711a7f87a9, nullptr, nullptr, { &s_9e0e78711a7f87a9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<47> b_ac3a6f60ef4cc6d3 = { @@ -2075,7 +2212,7 @@ static const uint16_t m_ac3a6f60ef4cc6d3[] = {1, 0}; static const uint16_t i_ac3a6f60ef4cc6d3[] = {0, 1}; const ::capnp::_::RawSchema s_ac3a6f60ef4cc6d3 = { 0xac3a6f60ef4cc6d3, b_ac3a6f60ef4cc6d3.words, 47, d_ac3a6f60ef4cc6d3, m_ac3a6f60ef4cc6d3, - 2, 2, i_ac3a6f60ef4cc6d3, nullptr, nullptr, { &s_ac3a6f60ef4cc6d3, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_ac3a6f60ef4cc6d3, nullptr, nullptr, { &s_ac3a6f60ef4cc6d3, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<48> b_ed8bca69f7fb0cbf = { @@ -2138,7 +2275,7 @@ static const uint16_t m_ed8bca69f7fb0cbf[] = {1, 0}; static const uint16_t i_ed8bca69f7fb0cbf[] = {0, 1}; const ::capnp::_::RawSchema s_ed8bca69f7fb0cbf = { 0xed8bca69f7fb0cbf, b_ed8bca69f7fb0cbf.words, 48, d_ed8bca69f7fb0cbf, m_ed8bca69f7fb0cbf, - 2, 2, i_ed8bca69f7fb0cbf, nullptr, nullptr, { &s_ed8bca69f7fb0cbf, nullptr, nullptr, 0, 0, nullptr } + 2, 2, i_ed8bca69f7fb0cbf, nullptr, nullptr, { &s_ed8bca69f7fb0cbf, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<46> b_c2573fe8a23e49f1 = { @@ -2201,7 +2338,7 @@ static const uint16_t m_c2573fe8a23e49f1[] = {2, 1, 0}; static const uint16_t i_c2573fe8a23e49f1[] = {0, 1, 2}; const ::capnp::_::RawSchema s_c2573fe8a23e49f1 = { 0xc2573fe8a23e49f1, b_c2573fe8a23e49f1.words, 46, d_c2573fe8a23e49f1, m_c2573fe8a23e49f1, - 4, 3, i_c2573fe8a23e49f1, nullptr, nullptr, { &s_c2573fe8a23e49f1, nullptr, nullptr, 0, 0, nullptr } + 4, 3, i_c2573fe8a23e49f1, nullptr, nullptr, { &s_c2573fe8a23e49f1, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<81> b_8e3b5f79fe593656 = { @@ -2296,7 +2433,7 @@ static const uint16_t m_8e3b5f79fe593656[] = {0, 3, 2, 1}; static const uint16_t i_8e3b5f79fe593656[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_8e3b5f79fe593656 = { 0x8e3b5f79fe593656, b_8e3b5f79fe593656.words, 81, d_8e3b5f79fe593656, m_8e3b5f79fe593656, - 1, 4, i_8e3b5f79fe593656, nullptr, nullptr, { &s_8e3b5f79fe593656, nullptr, nullptr, 0, 0, nullptr } + 1, 4, i_8e3b5f79fe593656, nullptr, nullptr, { &s_8e3b5f79fe593656, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<50> b_9dd1f724f4614a85 = { @@ -2360,7 +2497,7 @@ static const uint16_t m_9dd1f724f4614a85[] = {1, 0}; static const uint16_t i_9dd1f724f4614a85[] = {0, 1}; const ::capnp::_::RawSchema s_9dd1f724f4614a85 = { 0x9dd1f724f4614a85, b_9dd1f724f4614a85.words, 50, d_9dd1f724f4614a85, m_9dd1f724f4614a85, - 1, 2, i_9dd1f724f4614a85, nullptr, nullptr, { &s_9dd1f724f4614a85, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_9dd1f724f4614a85, nullptr, nullptr, { &s_9dd1f724f4614a85, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<37> b_baefc9120c56e274 = { @@ -2411,7 +2548,7 @@ static const uint16_t m_baefc9120c56e274[] = {0}; static const uint16_t i_baefc9120c56e274[] = {0}; const ::capnp::_::RawSchema s_baefc9120c56e274 = { 0xbaefc9120c56e274, b_baefc9120c56e274.words, 37, d_baefc9120c56e274, m_baefc9120c56e274, - 1, 1, i_baefc9120c56e274, nullptr, nullptr, { &s_baefc9120c56e274, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_baefc9120c56e274, nullptr, nullptr, { &s_baefc9120c56e274, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<43> b_903455f06065422b = { @@ -2468,7 +2605,7 @@ static const uint16_t m_903455f06065422b[] = {0}; static const uint16_t i_903455f06065422b[] = {0}; const ::capnp::_::RawSchema s_903455f06065422b = { 0x903455f06065422b, b_903455f06065422b.words, 43, d_903455f06065422b, m_903455f06065422b, - 1, 1, i_903455f06065422b, nullptr, nullptr, { &s_903455f06065422b, nullptr, nullptr, 0, 0, nullptr } + 1, 1, i_903455f06065422b, nullptr, nullptr, { &s_903455f06065422b, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<67> b_abd73485a9636bc9 = { @@ -2549,7 +2686,7 @@ static const uint16_t m_abd73485a9636bc9[] = {1, 2, 0}; static const uint16_t i_abd73485a9636bc9[] = {1, 2, 0}; const ::capnp::_::RawSchema s_abd73485a9636bc9 = { 0xabd73485a9636bc9, b_abd73485a9636bc9.words, 67, d_abd73485a9636bc9, m_abd73485a9636bc9, - 1, 3, i_abd73485a9636bc9, nullptr, nullptr, { &s_abd73485a9636bc9, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_abd73485a9636bc9, nullptr, nullptr, { &s_abd73485a9636bc9, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<49> b_c863cd16969ee7fc = { @@ -2612,7 +2749,7 @@ static const uint16_t m_c863cd16969ee7fc[] = {1, 0}; static const uint16_t i_c863cd16969ee7fc[] = {0, 1}; const ::capnp::_::RawSchema s_c863cd16969ee7fc = { 0xc863cd16969ee7fc, b_c863cd16969ee7fc.words, 49, d_c863cd16969ee7fc, m_c863cd16969ee7fc, - 1, 2, i_c863cd16969ee7fc, nullptr, nullptr, { &s_c863cd16969ee7fc, nullptr, nullptr, 0, 0, nullptr } + 1, 2, i_c863cd16969ee7fc, nullptr, nullptr, { &s_c863cd16969ee7fc, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<305> b_ce23dcd2d7b00c9b = { @@ -2928,7 +3065,7 @@ static const uint16_t m_ce23dcd2d7b00c9b[] = {18, 1, 13, 15, 10, 11, 3, 4, 5, 2, static const uint16_t i_ce23dcd2d7b00c9b[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}; const ::capnp::_::RawSchema s_ce23dcd2d7b00c9b = { 0xce23dcd2d7b00c9b, b_ce23dcd2d7b00c9b.words, 305, nullptr, m_ce23dcd2d7b00c9b, - 0, 19, i_ce23dcd2d7b00c9b, nullptr, nullptr, { &s_ce23dcd2d7b00c9b, nullptr, nullptr, 0, 0, nullptr } + 0, 19, i_ce23dcd2d7b00c9b, nullptr, nullptr, { &s_ce23dcd2d7b00c9b, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<63> b_f1c8950dab257542 = { @@ -3006,7 +3143,7 @@ static const uint16_t m_f1c8950dab257542[] = {2, 0, 1}; static const uint16_t i_f1c8950dab257542[] = {0, 1, 2}; const ::capnp::_::RawSchema s_f1c8950dab257542 = { 0xf1c8950dab257542, b_f1c8950dab257542.words, 63, d_f1c8950dab257542, m_f1c8950dab257542, - 2, 3, i_f1c8950dab257542, nullptr, nullptr, { &s_f1c8950dab257542, nullptr, nullptr, 0, 0, nullptr } + 2, 3, i_f1c8950dab257542, nullptr, nullptr, { &s_f1c8950dab257542, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<54> b_d1958f7dba521926 = { @@ -3070,7 +3207,7 @@ static const ::capnp::_::AlignedData<54> b_d1958f7dba521926 = { static const uint16_t m_d1958f7dba521926[] = {1, 2, 5, 0, 4, 7, 6, 3}; const ::capnp::_::RawSchema s_d1958f7dba521926 = { 0xd1958f7dba521926, b_d1958f7dba521926.words, 54, nullptr, m_d1958f7dba521926, - 0, 8, nullptr, nullptr, nullptr, { &s_d1958f7dba521926, nullptr, nullptr, 0, 0, nullptr } + 0, 8, nullptr, nullptr, nullptr, { &s_d1958f7dba521926, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE CAPNP_DEFINE_ENUM(ElementSize_d1958f7dba521926, d1958f7dba521926); @@ -3145,20 +3282,20 @@ static const uint16_t m_d85d305b7d839963[] = {0, 2, 1}; static const uint16_t i_d85d305b7d839963[] = {0, 1, 2}; const ::capnp::_::RawSchema s_d85d305b7d839963 = { 0xd85d305b7d839963, b_d85d305b7d839963.words, 63, nullptr, m_d85d305b7d839963, - 0, 3, i_d85d305b7d839963, nullptr, nullptr, { &s_d85d305b7d839963, nullptr, nullptr, 0, 0, nullptr } + 0, 3, i_d85d305b7d839963, nullptr, nullptr, { &s_d85d305b7d839963, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE -static const ::capnp::_::AlignedData<78> b_bfc546f6210ad7ce = { +static const ::capnp::_::AlignedData<98> b_bfc546f6210ad7ce = { { 0, 0, 0, 0, 5, 0, 6, 0, 206, 215, 10, 33, 246, 70, 197, 191, 19, 0, 0, 0, 1, 0, 0, 0, 217, 114, 76, 98, 9, 197, 63, 169, - 3, 0, 7, 0, 0, 0, 0, 0, + 4, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 66, 1, 0, 0, 37, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 49, 0, 0, 0, 175, 0, 0, 0, + 49, 0, 0, 0, 231, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 99, 97, 112, 110, 112, 47, 115, 99, @@ -3171,28 +3308,35 @@ static const ::capnp::_::AlignedData<78> b_bfc546f6210ad7ce = { 1, 0, 0, 0, 114, 0, 0, 0, 82, 101, 113, 117, 101, 115, 116, 101, 100, 70, 105, 108, 101, 0, 0, 0, - 12, 0, 0, 0, 3, 0, 4, 0, + 16, 0, 0, 0, 3, 0, 4, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 69, 0, 0, 0, 50, 0, 0, 0, + 97, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 64, 0, 0, 0, 3, 0, 1, 0, - 92, 0, 0, 0, 2, 0, 1, 0, - 2, 0, 0, 0, 1, 0, 0, 0, + 92, 0, 0, 0, 3, 0, 1, 0, + 120, 0, 0, 0, 2, 0, 1, 0, + 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 89, 0, 0, 0, 122, 0, 0, 0, + 117, 0, 0, 0, 122, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 88, 0, 0, 0, 3, 0, 1, 0, - 116, 0, 0, 0, 2, 0, 1, 0, + 116, 0, 0, 0, 3, 0, 1, 0, + 144, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 113, 0, 0, 0, 106, 0, 0, 0, + 141, 0, 0, 0, 106, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 140, 0, 0, 0, 3, 0, 1, 0, + 152, 0, 0, 0, 2, 0, 1, 0, + 2, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 1, 0, 3, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 149, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 112, 0, 0, 0, 3, 0, 1, 0, - 124, 0, 0, 0, 2, 0, 1, 0, + 148, 0, 0, 0, 3, 0, 1, 0, + 176, 0, 0, 0, 2, 0, 1, 0, 110, 111, 100, 101, 115, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -3225,6 +3369,19 @@ static const ::capnp::_::AlignedData<78> b_bfc546f6210ad7ce = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 115, 111, 117, 114, 99, 101, 73, 110, + 102, 111, 0, 0, 0, 0, 0, 0, + 14, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 3, 0, 1, 0, + 16, 0, 0, 0, 0, 0, 0, 0, + 174, 87, 19, 4, 227, 29, 142, 243, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } }; @@ -3234,12 +3391,13 @@ static const ::capnp::_::RawSchema* const d_bfc546f6210ad7ce[] = { &s_cfea0eb02e810062, &s_d85d305b7d839963, &s_e682ab4cf923a417, + &s_f38e1de3041357ae, }; -static const uint16_t m_bfc546f6210ad7ce[] = {2, 0, 1}; -static const uint16_t i_bfc546f6210ad7ce[] = {0, 1, 2}; +static const uint16_t m_bfc546f6210ad7ce[] = {2, 0, 1, 3}; +static const uint16_t i_bfc546f6210ad7ce[] = {0, 1, 2, 3}; const ::capnp::_::RawSchema s_bfc546f6210ad7ce = { - 0xbfc546f6210ad7ce, b_bfc546f6210ad7ce.words, 78, d_bfc546f6210ad7ce, m_bfc546f6210ad7ce, - 3, 3, i_bfc546f6210ad7ce, nullptr, nullptr, { &s_bfc546f6210ad7ce, nullptr, nullptr, 0, 0, nullptr } + 0xbfc546f6210ad7ce, b_bfc546f6210ad7ce.words, 98, d_bfc546f6210ad7ce, m_bfc546f6210ad7ce, + 4, 4, i_bfc546f6210ad7ce, nullptr, nullptr, { &s_bfc546f6210ad7ce, nullptr, nullptr, 0, 0, nullptr }, true }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<74> b_cfea0eb02e810062 = { @@ -3327,7 +3485,7 @@ static const uint16_t m_cfea0eb02e810062[] = {1, 0, 2}; static const uint16_t i_cfea0eb02e810062[] = {0, 1, 2}; const ::capnp::_::RawSchema s_cfea0eb02e810062 = { 0xcfea0eb02e810062, b_cfea0eb02e810062.words, 74, d_cfea0eb02e810062, m_cfea0eb02e810062, - 1, 3, i_cfea0eb02e810062, nullptr, nullptr, { &s_cfea0eb02e810062, nullptr, nullptr, 0, 0, nullptr } + 1, 3, i_cfea0eb02e810062, nullptr, nullptr, { &s_cfea0eb02e810062, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE static const ::capnp::_::AlignedData<52> b_ae504193122357e5 = { @@ -3390,7 +3548,7 @@ static const uint16_t m_ae504193122357e5[] = {0, 1}; static const uint16_t i_ae504193122357e5[] = {0, 1}; const ::capnp::_::RawSchema s_ae504193122357e5 = { 0xae504193122357e5, b_ae504193122357e5.words, 52, nullptr, m_ae504193122357e5, - 0, 2, i_ae504193122357e5, nullptr, nullptr, { &s_ae504193122357e5, nullptr, nullptr, 0, 0, nullptr } + 0, 2, i_ae504193122357e5, nullptr, nullptr, { &s_ae504193122357e5, nullptr, nullptr, 0, 0, nullptr }, false }; #endif // !CAPNP_LITE } // namespace schemas @@ -3402,270 +3560,426 @@ namespace capnp { namespace schema { // Node +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::_capnpPrivate::dataWordSize; constexpr uint16_t Node::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Parameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Parameter::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Parameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Parameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Parameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::NestedNode +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::NestedNode::_capnpPrivate::dataWordSize; constexpr uint16_t Node::NestedNode::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::NestedNode::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::NestedNode::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#endif // !CAPNP_LITE + +// Node::SourceInfo +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t Node::SourceInfo::_capnpPrivate::dataWordSize; +constexpr uint16_t Node::SourceInfo::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind Node::SourceInfo::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* Node::SourceInfo::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#endif // !CAPNP_LITE + +// Node::SourceInfo::Member +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t Node::SourceInfo::Member::_capnpPrivate::dataWordSize; +constexpr uint16_t Node::SourceInfo::Member::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind Node::SourceInfo::Member::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* Node::SourceInfo::Member::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Struct +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Struct::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Struct::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Struct::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Struct::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Enum +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Enum::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Enum::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Enum::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Enum::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Interface +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Interface::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Interface::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Interface::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Interface::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Const +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Const::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Const::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Const::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Const::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Node::Annotation +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Node::Annotation::_capnpPrivate::dataWordSize; constexpr uint16_t Node::Annotation::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Node::Annotation::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Node::Annotation::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Field +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::_capnpPrivate::dataWordSize; constexpr uint16_t Field::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE -#ifndef _MSC_VER +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::uint16_t Field::NO_DISCRIMINANT; #endif // Field::Slot +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::Slot::_capnpPrivate::dataWordSize; constexpr uint16_t Field::Slot::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::Slot::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::Slot::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Field::Group +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::Group::_capnpPrivate::dataWordSize; constexpr uint16_t Field::Group::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::Group::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::Group::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Field::Ordinal +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Field::Ordinal::_capnpPrivate::dataWordSize; constexpr uint16_t Field::Ordinal::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Field::Ordinal::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Field::Ordinal::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Enumerant +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Enumerant::_capnpPrivate::dataWordSize; constexpr uint16_t Enumerant::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Enumerant::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Enumerant::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Superclass +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Superclass::_capnpPrivate::dataWordSize; constexpr uint16_t Superclass::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Superclass::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Superclass::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Method +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Method::_capnpPrivate::dataWordSize; constexpr uint16_t Method::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Method::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Method::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::_capnpPrivate::dataWordSize; constexpr uint16_t Type::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::List +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::List::_capnpPrivate::dataWordSize; constexpr uint16_t Type::List::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::List::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::List::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::Enum +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::Enum::_capnpPrivate::dataWordSize; constexpr uint16_t Type::Enum::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::Enum::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::Enum::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::Struct +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::Struct::_capnpPrivate::dataWordSize; constexpr uint16_t Type::Struct::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::Struct::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::Struct::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::Interface +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::Interface::_capnpPrivate::dataWordSize; constexpr uint16_t Type::Interface::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::Interface::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::Interface::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer::Unconstrained +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::Unconstrained::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::Unconstrained::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::Unconstrained::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::Unconstrained::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer::Parameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::Parameter::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::Parameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::Parameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::Parameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Type::AnyPointer::ImplicitMethodParameter +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::dataWordSize; constexpr uint16_t Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Type::AnyPointer::ImplicitMethodParameter::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Brand +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Brand::_capnpPrivate::dataWordSize; constexpr uint16_t Brand::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Brand::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Brand::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Brand::Scope +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Brand::Scope::_capnpPrivate::dataWordSize; constexpr uint16_t Brand::Scope::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Brand::Scope::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Brand::Scope::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Brand::Binding +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Brand::Binding::_capnpPrivate::dataWordSize; constexpr uint16_t Brand::Binding::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Brand::Binding::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Brand::Binding::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Value +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Value::_capnpPrivate::dataWordSize; constexpr uint16_t Value::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Value::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Value::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // Annotation +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t Annotation::_capnpPrivate::dataWordSize; constexpr uint16_t Annotation::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind Annotation::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* Annotation::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CapnpVersion +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CapnpVersion::_capnpPrivate::dataWordSize; constexpr uint16_t CapnpVersion::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CapnpVersion::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CapnpVersion::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CodeGeneratorRequest +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CodeGeneratorRequest::_capnpPrivate::dataWordSize; constexpr uint16_t CodeGeneratorRequest::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CodeGeneratorRequest::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CodeGeneratorRequest::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CodeGeneratorRequest::RequestedFile +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CodeGeneratorRequest::RequestedFile::_capnpPrivate::dataWordSize; constexpr uint16_t CodeGeneratorRequest::RequestedFile::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CodeGeneratorRequest::RequestedFile::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CodeGeneratorRequest::RequestedFile::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE // CodeGeneratorRequest::RequestedFile::Import +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr uint16_t CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::dataWordSize; constexpr uint16_t CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL constexpr ::capnp::Kind CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::kind; constexpr ::capnp::_::RawSchema const* CodeGeneratorRequest::RequestedFile::Import::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL #endif // !CAPNP_LITE diff --git a/c++/src/capnp/schema.capnp.h b/c++/src/capnp/schema.capnp.h index 1f116c9f8f..519bd5274a 100644 --- a/c++/src/capnp/schema.capnp.h +++ b/c++/src/capnp/schema.capnp.h @@ -1,22 +1,28 @@ // Generated by Cap'n Proto compiler, DO NOT EDIT // source: schema.capnp -#ifndef CAPNP_INCLUDED_a93fc509624c72d9_ -#define CAPNP_INCLUDED_a93fc509624c72d9_ +#pragma once #include +#include -#if CAPNP_VERSION != 6001 +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 #error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." #endif +CAPNP_BEGIN_HEADER + namespace capnp { namespace schemas { CAPNP_DECLARE_SCHEMA(e682ab4cf923a417); CAPNP_DECLARE_SCHEMA(b9521bccf10fa3b1); CAPNP_DECLARE_SCHEMA(debf55bbfa0fc242); +CAPNP_DECLARE_SCHEMA(f38e1de3041357ae); +CAPNP_DECLARE_SCHEMA(c2ba9038898e1fa2); CAPNP_DECLARE_SCHEMA(9ea0b19b37fb4435); CAPNP_DECLARE_SCHEMA(b54ab3364333f598); CAPNP_DECLARE_SCHEMA(e82753cff0c2218f); @@ -83,6 +89,7 @@ struct Node { }; struct Parameter; struct NestedNode; + struct SourceInfo; struct Struct; struct Enum; struct Interface; @@ -127,6 +134,37 @@ struct Node::NestedNode { }; }; +struct Node::SourceInfo { + SourceInfo() = delete; + + class Reader; + class Builder; + class Pipeline; + struct Member; + + struct _capnpPrivate { + CAPNP_DECLARE_STRUCT_HEADER(f38e1de3041357ae, 1, 2) + #if !CAPNP_LITE + static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } + #endif // !CAPNP_LITE + }; +}; + +struct Node::SourceInfo::Member { + Member() = delete; + + class Reader; + class Builder; + class Pipeline; + + struct _capnpPrivate { + CAPNP_DECLARE_STRUCT_HEADER(c2ba9038898e1fa2, 0, 1) + #if !CAPNP_LITE + static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } + #endif // !CAPNP_LITE + }; +}; + struct Node::Struct { Struct() = delete; @@ -626,7 +664,7 @@ struct CodeGeneratorRequest { struct RequestedFile; struct _capnpPrivate { - CAPNP_DECLARE_STRUCT_HEADER(bfc546f6210ad7ce, 0, 3) + CAPNP_DECLARE_STRUCT_HEADER(bfc546f6210ad7ce, 0, 4) #if !CAPNP_LITE static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } #endif // !CAPNP_LITE @@ -694,10 +732,10 @@ class Node::Reader { inline ::uint64_t getScopeId() const; inline bool hasNestedNodes() const; - inline ::capnp::List< ::capnp::schema::Node::NestedNode>::Reader getNestedNodes() const; + inline ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Reader getNestedNodes() const; inline bool hasAnnotations() const; - inline ::capnp::List< ::capnp::schema::Annotation>::Reader getAnnotations() const; + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader getAnnotations() const; inline bool isFile() const; inline ::capnp::Void getFile() const; @@ -718,7 +756,7 @@ class Node::Reader { inline typename Annotation::Reader getAnnotation() const; inline bool hasParameters() const; - inline ::capnp::List< ::capnp::schema::Node::Parameter>::Reader getParameters() const; + inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader getParameters() const; inline bool getIsGeneric() const; @@ -768,18 +806,18 @@ class Node::Builder { inline void setScopeId( ::uint64_t value); inline bool hasNestedNodes(); - inline ::capnp::List< ::capnp::schema::Node::NestedNode>::Builder getNestedNodes(); - inline void setNestedNodes( ::capnp::List< ::capnp::schema::Node::NestedNode>::Reader value); - inline ::capnp::List< ::capnp::schema::Node::NestedNode>::Builder initNestedNodes(unsigned int size); - inline void adoptNestedNodes(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode>> disownNestedNodes(); + inline ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Builder getNestedNodes(); + inline void setNestedNodes( ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Builder initNestedNodes(unsigned int size); + inline void adoptNestedNodes(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>> disownNestedNodes(); inline bool hasAnnotations(); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder getAnnotations(); - inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder initAnnotations(unsigned int size); - inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> disownAnnotations(); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder getAnnotations(); + inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder initAnnotations(unsigned int size); + inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> disownAnnotations(); inline bool isFile(); inline ::capnp::Void getFile(); @@ -806,11 +844,11 @@ class Node::Builder { inline typename Annotation::Builder initAnnotation(); inline bool hasParameters(); - inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder getParameters(); - inline void setParameters( ::capnp::List< ::capnp::schema::Node::Parameter>::Reader value); - inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder initParameters(unsigned int size); - inline void adoptParameters(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>> disownParameters(); + inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder getParameters(); + inline void setParameters( ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder initParameters(unsigned int size); + inline void adoptParameters(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>> disownParameters(); inline bool getIsGeneric(); inline void setIsGeneric(bool value); @@ -1008,6 +1046,183 @@ class Node::NestedNode::Pipeline { }; #endif // !CAPNP_LITE +class Node::SourceInfo::Reader { +public: + typedef SourceInfo Reads; + + Reader() = default; + inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} + + inline ::capnp::MessageSize totalSize() const { + return _reader.totalSize().asPublic(); + } + +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { + return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); + } +#endif // !CAPNP_LITE + + inline ::uint64_t getId() const; + + inline bool hasDocComment() const; + inline ::capnp::Text::Reader getDocComment() const; + + inline bool hasMembers() const; + inline ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Reader getMembers() const; + +private: + ::capnp::_::StructReader _reader; + template + friend struct ::capnp::ToDynamic_; + template + friend struct ::capnp::_::PointerHelpers; + template + friend struct ::capnp::List; + friend class ::capnp::MessageBuilder; + friend class ::capnp::Orphanage; +}; + +class Node::SourceInfo::Builder { +public: + typedef SourceInfo Builds; + + Builder() = delete; // Deleted to discourage incorrect usage. + // You can explicitly initialize to nullptr instead. + inline Builder(decltype(nullptr)) {} + inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} + inline operator Reader() const { return Reader(_builder.asReader()); } + inline Reader asReader() const { return *this; } + + inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { return asReader().toString(); } +#endif // !CAPNP_LITE + + inline ::uint64_t getId(); + inline void setId( ::uint64_t value); + + inline bool hasDocComment(); + inline ::capnp::Text::Builder getDocComment(); + inline void setDocComment( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initDocComment(unsigned int size); + inline void adoptDocComment(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownDocComment(); + + inline bool hasMembers(); + inline ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Builder getMembers(); + inline void setMembers( ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Builder initMembers(unsigned int size); + inline void adoptMembers(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>> disownMembers(); + +private: + ::capnp::_::StructBuilder _builder; + template + friend struct ::capnp::ToDynamic_; + friend class ::capnp::Orphanage; + template + friend struct ::capnp::_::PointerHelpers; +}; + +#if !CAPNP_LITE +class Node::SourceInfo::Pipeline { +public: + typedef SourceInfo Pipelines; + + inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} + inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) + : _typeless(kj::mv(typeless)) {} + +private: + ::capnp::AnyPointer::Pipeline _typeless; + friend class ::capnp::PipelineHook; + template + friend struct ::capnp::ToDynamic_; +}; +#endif // !CAPNP_LITE + +class Node::SourceInfo::Member::Reader { +public: + typedef Member Reads; + + Reader() = default; + inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} + + inline ::capnp::MessageSize totalSize() const { + return _reader.totalSize().asPublic(); + } + +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { + return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); + } +#endif // !CAPNP_LITE + + inline bool hasDocComment() const; + inline ::capnp::Text::Reader getDocComment() const; + +private: + ::capnp::_::StructReader _reader; + template + friend struct ::capnp::ToDynamic_; + template + friend struct ::capnp::_::PointerHelpers; + template + friend struct ::capnp::List; + friend class ::capnp::MessageBuilder; + friend class ::capnp::Orphanage; +}; + +class Node::SourceInfo::Member::Builder { +public: + typedef Member Builds; + + Builder() = delete; // Deleted to discourage incorrect usage. + // You can explicitly initialize to nullptr instead. + inline Builder(decltype(nullptr)) {} + inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} + inline operator Reader() const { return Reader(_builder.asReader()); } + inline Reader asReader() const { return *this; } + + inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { return asReader().toString(); } +#endif // !CAPNP_LITE + + inline bool hasDocComment(); + inline ::capnp::Text::Builder getDocComment(); + inline void setDocComment( ::capnp::Text::Reader value); + inline ::capnp::Text::Builder initDocComment(unsigned int size); + inline void adoptDocComment(::capnp::Orphan< ::capnp::Text>&& value); + inline ::capnp::Orphan< ::capnp::Text> disownDocComment(); + +private: + ::capnp::_::StructBuilder _builder; + template + friend struct ::capnp::ToDynamic_; + friend class ::capnp::Orphanage; + template + friend struct ::capnp::_::PointerHelpers; +}; + +#if !CAPNP_LITE +class Node::SourceInfo::Member::Pipeline { +public: + typedef Member Pipelines; + + inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} + inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) + : _typeless(kj::mv(typeless)) {} + +private: + ::capnp::AnyPointer::Pipeline _typeless; + friend class ::capnp::PipelineHook; + template + friend struct ::capnp::ToDynamic_; +}; +#endif // !CAPNP_LITE + class Node::Struct::Reader { public: typedef Struct Reads; @@ -1038,7 +1253,7 @@ class Node::Struct::Reader { inline ::uint32_t getDiscriminantOffset() const; inline bool hasFields() const; - inline ::capnp::List< ::capnp::schema::Field>::Reader getFields() const; + inline ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Reader getFields() const; private: ::capnp::_::StructReader _reader; @@ -1087,11 +1302,11 @@ class Node::Struct::Builder { inline void setDiscriminantOffset( ::uint32_t value); inline bool hasFields(); - inline ::capnp::List< ::capnp::schema::Field>::Builder getFields(); - inline void setFields( ::capnp::List< ::capnp::schema::Field>::Reader value); - inline ::capnp::List< ::capnp::schema::Field>::Builder initFields(unsigned int size); - inline void adoptFields(::capnp::Orphan< ::capnp::List< ::capnp::schema::Field>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Field>> disownFields(); + inline ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Builder getFields(); + inline void setFields( ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Builder initFields(unsigned int size); + inline void adoptFields(::capnp::Orphan< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>> disownFields(); private: ::capnp::_::StructBuilder _builder; @@ -1137,7 +1352,7 @@ class Node::Enum::Reader { #endif // !CAPNP_LITE inline bool hasEnumerants() const; - inline ::capnp::List< ::capnp::schema::Enumerant>::Reader getEnumerants() const; + inline ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Reader getEnumerants() const; private: ::capnp::_::StructReader _reader; @@ -1168,11 +1383,11 @@ class Node::Enum::Builder { #endif // !CAPNP_LITE inline bool hasEnumerants(); - inline ::capnp::List< ::capnp::schema::Enumerant>::Builder getEnumerants(); - inline void setEnumerants( ::capnp::List< ::capnp::schema::Enumerant>::Reader value); - inline ::capnp::List< ::capnp::schema::Enumerant>::Builder initEnumerants(unsigned int size); - inline void adoptEnumerants(::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant>> disownEnumerants(); + inline ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Builder getEnumerants(); + inline void setEnumerants( ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Builder initEnumerants(unsigned int size); + inline void adoptEnumerants(::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>> disownEnumerants(); private: ::capnp::_::StructBuilder _builder; @@ -1218,10 +1433,10 @@ class Node::Interface::Reader { #endif // !CAPNP_LITE inline bool hasMethods() const; - inline ::capnp::List< ::capnp::schema::Method>::Reader getMethods() const; + inline ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Reader getMethods() const; inline bool hasSuperclasses() const; - inline ::capnp::List< ::capnp::schema::Superclass>::Reader getSuperclasses() const; + inline ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Reader getSuperclasses() const; private: ::capnp::_::StructReader _reader; @@ -1252,18 +1467,18 @@ class Node::Interface::Builder { #endif // !CAPNP_LITE inline bool hasMethods(); - inline ::capnp::List< ::capnp::schema::Method>::Builder getMethods(); - inline void setMethods( ::capnp::List< ::capnp::schema::Method>::Reader value); - inline ::capnp::List< ::capnp::schema::Method>::Builder initMethods(unsigned int size); - inline void adoptMethods(::capnp::Orphan< ::capnp::List< ::capnp::schema::Method>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Method>> disownMethods(); + inline ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Builder getMethods(); + inline void setMethods( ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Builder initMethods(unsigned int size); + inline void adoptMethods(::capnp::Orphan< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>> disownMethods(); inline bool hasSuperclasses(); - inline ::capnp::List< ::capnp::schema::Superclass>::Builder getSuperclasses(); - inline void setSuperclasses( ::capnp::List< ::capnp::schema::Superclass>::Reader value); - inline ::capnp::List< ::capnp::schema::Superclass>::Builder initSuperclasses(unsigned int size); - inline void adoptSuperclasses(::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass>> disownSuperclasses(); + inline ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Builder getSuperclasses(); + inline void setSuperclasses( ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Builder initSuperclasses(unsigned int size); + inline void adoptSuperclasses(::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>> disownSuperclasses(); private: ::capnp::_::StructBuilder _builder; @@ -1550,7 +1765,7 @@ class Field::Reader { inline ::uint16_t getCodeOrder() const; inline bool hasAnnotations() const; - inline ::capnp::List< ::capnp::schema::Annotation>::Reader getAnnotations() const; + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader getAnnotations() const; inline ::uint16_t getDiscriminantValue() const; @@ -1602,11 +1817,11 @@ class Field::Builder { inline void setCodeOrder( ::uint16_t value); inline bool hasAnnotations(); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder getAnnotations(); - inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder initAnnotations(unsigned int size); - inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> disownAnnotations(); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder getAnnotations(); + inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder initAnnotations(unsigned int size); + inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> disownAnnotations(); inline ::uint16_t getDiscriminantValue(); inline void setDiscriminantValue( ::uint16_t value); @@ -1938,7 +2153,7 @@ class Enumerant::Reader { inline ::uint16_t getCodeOrder() const; inline bool hasAnnotations() const; - inline ::capnp::List< ::capnp::schema::Annotation>::Reader getAnnotations() const; + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader getAnnotations() const; private: ::capnp::_::StructReader _reader; @@ -1979,11 +2194,11 @@ class Enumerant::Builder { inline void setCodeOrder( ::uint16_t value); inline bool hasAnnotations(); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder getAnnotations(); - inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder initAnnotations(unsigned int size); - inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> disownAnnotations(); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder getAnnotations(); + inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder initAnnotations(unsigned int size); + inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> disownAnnotations(); private: ::capnp::_::StructBuilder _builder; @@ -2125,7 +2340,7 @@ class Method::Reader { inline ::uint64_t getResultStructType() const; inline bool hasAnnotations() const; - inline ::capnp::List< ::capnp::schema::Annotation>::Reader getAnnotations() const; + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader getAnnotations() const; inline bool hasParamBrand() const; inline ::capnp::schema::Brand::Reader getParamBrand() const; @@ -2134,7 +2349,7 @@ class Method::Reader { inline ::capnp::schema::Brand::Reader getResultBrand() const; inline bool hasImplicitParameters() const; - inline ::capnp::List< ::capnp::schema::Node::Parameter>::Reader getImplicitParameters() const; + inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader getImplicitParameters() const; private: ::capnp::_::StructReader _reader; @@ -2181,11 +2396,11 @@ class Method::Builder { inline void setResultStructType( ::uint64_t value); inline bool hasAnnotations(); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder getAnnotations(); - inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value); - inline ::capnp::List< ::capnp::schema::Annotation>::Builder initAnnotations(unsigned int size); - inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> disownAnnotations(); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder getAnnotations(); + inline void setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder initAnnotations(unsigned int size); + inline void adoptAnnotations(::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> disownAnnotations(); inline bool hasParamBrand(); inline ::capnp::schema::Brand::Builder getParamBrand(); @@ -2202,11 +2417,11 @@ class Method::Builder { inline ::capnp::Orphan< ::capnp::schema::Brand> disownResultBrand(); inline bool hasImplicitParameters(); - inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder getImplicitParameters(); - inline void setImplicitParameters( ::capnp::List< ::capnp::schema::Node::Parameter>::Reader value); - inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder initImplicitParameters(unsigned int size); - inline void adoptImplicitParameters(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>> disownImplicitParameters(); + inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder getImplicitParameters(); + inline void setImplicitParameters( ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder initImplicitParameters(unsigned int size); + inline void adoptImplicitParameters(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>> disownImplicitParameters(); private: ::capnp::_::StructBuilder _builder; @@ -3155,7 +3370,7 @@ class Brand::Reader { #endif // !CAPNP_LITE inline bool hasScopes() const; - inline ::capnp::List< ::capnp::schema::Brand::Scope>::Reader getScopes() const; + inline ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Reader getScopes() const; private: ::capnp::_::StructReader _reader; @@ -3186,11 +3401,11 @@ class Brand::Builder { #endif // !CAPNP_LITE inline bool hasScopes(); - inline ::capnp::List< ::capnp::schema::Brand::Scope>::Builder getScopes(); - inline void setScopes( ::capnp::List< ::capnp::schema::Brand::Scope>::Reader value); - inline ::capnp::List< ::capnp::schema::Brand::Scope>::Builder initScopes(unsigned int size); - inline void adoptScopes(::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope>> disownScopes(); + inline ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Builder getScopes(); + inline void setScopes( ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Builder initScopes(unsigned int size); + inline void adoptScopes(::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>> disownScopes(); private: ::capnp::_::StructBuilder _builder; @@ -3240,7 +3455,7 @@ class Brand::Scope::Reader { inline bool isBind() const; inline bool hasBind() const; - inline ::capnp::List< ::capnp::schema::Brand::Binding>::Reader getBind() const; + inline ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Reader getBind() const; inline bool isInherit() const; inline ::capnp::Void getInherit() const; @@ -3279,11 +3494,11 @@ class Brand::Scope::Builder { inline bool isBind(); inline bool hasBind(); - inline ::capnp::List< ::capnp::schema::Brand::Binding>::Builder getBind(); - inline void setBind( ::capnp::List< ::capnp::schema::Brand::Binding>::Reader value); - inline ::capnp::List< ::capnp::schema::Brand::Binding>::Builder initBind(unsigned int size); - inline void adoptBind(::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding>> disownBind(); + inline ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Builder getBind(); + inline void setBind( ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Builder initBind(unsigned int size); + inline void adoptBind(::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>> disownBind(); inline bool isInherit(); inline ::capnp::Void getInherit(); @@ -3831,14 +4046,17 @@ class CodeGeneratorRequest::Reader { #endif // !CAPNP_LITE inline bool hasNodes() const; - inline ::capnp::List< ::capnp::schema::Node>::Reader getNodes() const; + inline ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Reader getNodes() const; inline bool hasRequestedFiles() const; - inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Reader getRequestedFiles() const; + inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Reader getRequestedFiles() const; inline bool hasCapnpVersion() const; inline ::capnp::schema::CapnpVersion::Reader getCapnpVersion() const; + inline bool hasSourceInfo() const; + inline ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Reader getSourceInfo() const; + private: ::capnp::_::StructReader _reader; template @@ -3868,18 +4086,18 @@ class CodeGeneratorRequest::Builder { #endif // !CAPNP_LITE inline bool hasNodes(); - inline ::capnp::List< ::capnp::schema::Node>::Builder getNodes(); - inline void setNodes( ::capnp::List< ::capnp::schema::Node>::Reader value); - inline ::capnp::List< ::capnp::schema::Node>::Builder initNodes(unsigned int size); - inline void adoptNodes(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node>> disownNodes(); + inline ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Builder getNodes(); + inline void setNodes( ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Builder initNodes(unsigned int size); + inline void adoptNodes(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>> disownNodes(); inline bool hasRequestedFiles(); - inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Builder getRequestedFiles(); - inline void setRequestedFiles( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Reader value); - inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Builder initRequestedFiles(unsigned int size); - inline void adoptRequestedFiles(::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>> disownRequestedFiles(); + inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Builder getRequestedFiles(); + inline void setRequestedFiles( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Builder initRequestedFiles(unsigned int size); + inline void adoptRequestedFiles(::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>> disownRequestedFiles(); inline bool hasCapnpVersion(); inline ::capnp::schema::CapnpVersion::Builder getCapnpVersion(); @@ -3888,6 +4106,13 @@ class CodeGeneratorRequest::Builder { inline void adoptCapnpVersion(::capnp::Orphan< ::capnp::schema::CapnpVersion>&& value); inline ::capnp::Orphan< ::capnp::schema::CapnpVersion> disownCapnpVersion(); + inline bool hasSourceInfo(); + inline ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Builder getSourceInfo(); + inline void setSourceInfo( ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Builder initSourceInfo(unsigned int size); + inline void adoptSourceInfo(::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>> disownSourceInfo(); + private: ::capnp::_::StructBuilder _builder; template @@ -3938,7 +4163,7 @@ class CodeGeneratorRequest::RequestedFile::Reader { inline ::capnp::Text::Reader getFilename() const; inline bool hasImports() const; - inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Reader getImports() const; + inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Reader getImports() const; private: ::capnp::_::StructReader _reader; @@ -3979,11 +4204,11 @@ class CodeGeneratorRequest::RequestedFile::Builder { inline ::capnp::Orphan< ::capnp::Text> disownFilename(); inline bool hasImports(); - inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Builder getImports(); - inline void setImports( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Reader value); - inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Builder initImports(unsigned int size); - inline void adoptImports(::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>&& value); - inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>> disownImports(); + inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Builder getImports(); + inline void setImports( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Reader value); + inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Builder initImports(unsigned int size); + inline void adoptImports(::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>&& value); + inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>> disownImports(); private: ::capnp::_::StructBuilder _builder; @@ -4192,29 +4417,29 @@ inline bool Node::Builder::hasNestedNodes() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Node::NestedNode>::Reader Node::Reader::getNestedNodes() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Reader Node::Reader::getNestedNodes() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Node::NestedNode>::Builder Node::Builder::getNestedNodes() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Builder Node::Builder::getNestedNodes() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Node::Builder::setNestedNodes( ::capnp::List< ::capnp::schema::Node::NestedNode>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode>>::set(_builder.getPointerField( +inline void Node::Builder::setNestedNodes( ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Node::NestedNode>::Builder Node::Builder::initNestedNodes(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>::Builder Node::Builder::initNestedNodes(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Node::Builder::adoptNestedNodes( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode>> Node::Builder::disownNestedNodes() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>> Node::Builder::disownNestedNodes() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::NestedNode, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -4226,29 +4451,29 @@ inline bool Node::Builder::hasAnnotations() { return !_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Annotation>::Reader Node::Reader::getAnnotations() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader Node::Reader::getAnnotations() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Node::Builder::getAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Node::Builder::getAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } -inline void Node::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::set(_builder.getPointerField( +inline void Node::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Node::Builder::initAnnotations(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Node::Builder::initAnnotations(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), size); } inline void Node::Builder::adoptAnnotations( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> Node::Builder::disownAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> Node::Builder::disownAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<2>() * ::capnp::POINTERS)); } @@ -4416,29 +4641,29 @@ inline bool Node::Builder::hasParameters() { return !_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Node::Parameter>::Reader Node::Reader::getParameters() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader Node::Reader::getParameters() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder Node::Builder::getParameters() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder Node::Builder::getParameters() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS)); } -inline void Node::Builder::setParameters( ::capnp::List< ::capnp::schema::Node::Parameter>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::set(_builder.getPointerField( +inline void Node::Builder::setParameters( ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder Node::Builder::initParameters(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder Node::Builder::initParameters(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS), size); } inline void Node::Builder::adoptParameters( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>> Node::Builder::disownParameters() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>> Node::Builder::disownParameters() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<5>() * ::capnp::POINTERS)); } @@ -4538,6 +4763,122 @@ inline void Node::NestedNode::Builder::setId( ::uint64_t value) { ::capnp::bounded<0>() * ::capnp::ELEMENTS, value); } +inline ::uint64_t Node::SourceInfo::Reader::getId() const { + return _reader.getDataField< ::uint64_t>( + ::capnp::bounded<0>() * ::capnp::ELEMENTS); +} + +inline ::uint64_t Node::SourceInfo::Builder::getId() { + return _builder.getDataField< ::uint64_t>( + ::capnp::bounded<0>() * ::capnp::ELEMENTS); +} +inline void Node::SourceInfo::Builder::setId( ::uint64_t value) { + _builder.setDataField< ::uint64_t>( + ::capnp::bounded<0>() * ::capnp::ELEMENTS, value); +} + +inline bool Node::SourceInfo::Reader::hasDocComment() const { + return !_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline bool Node::SourceInfo::Builder::hasDocComment() { + return !_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader Node::SourceInfo::Reader::getDocComment() const { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline ::capnp::Text::Builder Node::SourceInfo::Builder::getDocComment() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline void Node::SourceInfo::Builder::setDocComment( ::capnp::Text::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder Node::SourceInfo::Builder::initDocComment(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), size); +} +inline void Node::SourceInfo::Builder::adoptDocComment( + ::capnp::Orphan< ::capnp::Text>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> Node::SourceInfo::Builder::disownDocComment() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} + +inline bool Node::SourceInfo::Reader::hasMembers() const { + return !_reader.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); +} +inline bool Node::SourceInfo::Builder::hasMembers() { + return !_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Reader Node::SourceInfo::Reader::getMembers() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} +inline ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Builder Node::SourceInfo::Builder::getMembers() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} +inline void Node::SourceInfo::Builder::setMembers( ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), value); +} +inline ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>::Builder Node::SourceInfo::Builder::initMembers(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), size); +} +inline void Node::SourceInfo::Builder::adoptMembers( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>> Node::SourceInfo::Builder::disownMembers() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo::Member, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( + ::capnp::bounded<1>() * ::capnp::POINTERS)); +} + +inline bool Node::SourceInfo::Member::Reader::hasDocComment() const { + return !_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline bool Node::SourceInfo::Member::Builder::hasDocComment() { + return !_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::Text::Reader Node::SourceInfo::Member::Reader::getDocComment() const { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_reader.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline ::capnp::Text::Builder Node::SourceInfo::Member::Builder::getDocComment() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::get(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} +inline void Node::SourceInfo::Member::Builder::setDocComment( ::capnp::Text::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::set(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), value); +} +inline ::capnp::Text::Builder Node::SourceInfo::Member::Builder::initDocComment(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::Text>::init(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), size); +} +inline void Node::SourceInfo::Member::Builder::adoptDocComment( + ::capnp::Orphan< ::capnp::Text>&& value) { + ::capnp::_::PointerHelpers< ::capnp::Text>::adopt(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::Text> Node::SourceInfo::Member::Builder::disownDocComment() { + return ::capnp::_::PointerHelpers< ::capnp::Text>::disown(_builder.getPointerField( + ::capnp::bounded<0>() * ::capnp::POINTERS)); +} + inline ::uint16_t Node::Struct::Reader::getDataWordCount() const { return _reader.getDataField< ::uint16_t>( ::capnp::bounded<7>() * ::capnp::ELEMENTS); @@ -4630,29 +4971,29 @@ inline bool Node::Struct::Builder::hasFields() { return !_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Field>::Reader Node::Struct::Reader::getFields() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Reader Node::Struct::Reader::getFields() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Field>::Builder Node::Struct::Builder::getFields() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Builder Node::Struct::Builder::getFields() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline void Node::Struct::Builder::setFields( ::capnp::List< ::capnp::schema::Field>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field>>::set(_builder.getPointerField( +inline void Node::Struct::Builder::setFields( ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Field>::Builder Node::Struct::Builder::initFields(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>::Builder Node::Struct::Builder::initFields(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), size); } inline void Node::Struct::Builder::adoptFields( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Field>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Field>> Node::Struct::Builder::disownFields() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>> Node::Struct::Builder::disownFields() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Field, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } @@ -4664,29 +5005,29 @@ inline bool Node::Enum::Builder::hasEnumerants() { return !_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Enumerant>::Reader Node::Enum::Reader::getEnumerants() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Reader Node::Enum::Reader::getEnumerants() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Enumerant>::Builder Node::Enum::Builder::getEnumerants() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Builder Node::Enum::Builder::getEnumerants() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline void Node::Enum::Builder::setEnumerants( ::capnp::List< ::capnp::schema::Enumerant>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant>>::set(_builder.getPointerField( +inline void Node::Enum::Builder::setEnumerants( ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Enumerant>::Builder Node::Enum::Builder::initEnumerants(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>::Builder Node::Enum::Builder::initEnumerants(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), size); } inline void Node::Enum::Builder::adoptEnumerants( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant>> Node::Enum::Builder::disownEnumerants() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>> Node::Enum::Builder::disownEnumerants() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Enumerant, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } @@ -4698,29 +5039,29 @@ inline bool Node::Interface::Builder::hasMethods() { return !_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Method>::Reader Node::Interface::Reader::getMethods() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Reader Node::Interface::Reader::getMethods() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Method>::Builder Node::Interface::Builder::getMethods() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Builder Node::Interface::Builder::getMethods() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } -inline void Node::Interface::Builder::setMethods( ::capnp::List< ::capnp::schema::Method>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method>>::set(_builder.getPointerField( +inline void Node::Interface::Builder::setMethods( ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Method>::Builder Node::Interface::Builder::initMethods(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>::Builder Node::Interface::Builder::initMethods(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), size); } inline void Node::Interface::Builder::adoptMethods( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Method>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Method>> Node::Interface::Builder::disownMethods() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>> Node::Interface::Builder::disownMethods() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Method, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<3>() * ::capnp::POINTERS)); } @@ -4732,29 +5073,29 @@ inline bool Node::Interface::Builder::hasSuperclasses() { return !_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Superclass>::Reader Node::Interface::Reader::getSuperclasses() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Reader Node::Interface::Reader::getSuperclasses() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Superclass>::Builder Node::Interface::Builder::getSuperclasses() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Builder Node::Interface::Builder::getSuperclasses() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS)); } -inline void Node::Interface::Builder::setSuperclasses( ::capnp::List< ::capnp::schema::Superclass>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass>>::set(_builder.getPointerField( +inline void Node::Interface::Builder::setSuperclasses( ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Superclass>::Builder Node::Interface::Builder::initSuperclasses(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>::Builder Node::Interface::Builder::initSuperclasses(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS), size); } inline void Node::Interface::Builder::adoptSuperclasses( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass>> Node::Interface::Builder::disownSuperclasses() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>> Node::Interface::Builder::disownSuperclasses() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Superclass, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS)); } @@ -5108,29 +5449,29 @@ inline bool Field::Builder::hasAnnotations() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Annotation>::Reader Field::Reader::getAnnotations() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader Field::Reader::getAnnotations() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Field::Builder::getAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Field::Builder::getAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Field::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::set(_builder.getPointerField( +inline void Field::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Field::Builder::initAnnotations(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Field::Builder::initAnnotations(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Field::Builder::adoptAnnotations( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> Field::Builder::disownAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> Field::Builder::disownAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -5448,29 +5789,29 @@ inline bool Enumerant::Builder::hasAnnotations() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Annotation>::Reader Enumerant::Reader::getAnnotations() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader Enumerant::Reader::getAnnotations() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Enumerant::Builder::getAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Enumerant::Builder::getAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Enumerant::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::set(_builder.getPointerField( +inline void Enumerant::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Enumerant::Builder::initAnnotations(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Enumerant::Builder::initAnnotations(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Enumerant::Builder::adoptAnnotations( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> Enumerant::Builder::disownAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> Enumerant::Builder::disownAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -5611,29 +5952,29 @@ inline bool Method::Builder::hasAnnotations() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Annotation>::Reader Method::Reader::getAnnotations() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader Method::Reader::getAnnotations() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Method::Builder::getAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Method::Builder::getAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void Method::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::set(_builder.getPointerField( +inline void Method::Builder::setAnnotations( ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Annotation>::Builder Method::Builder::initAnnotations(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>::Builder Method::Builder::initAnnotations(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void Method::Builder::adoptAnnotations( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation>> Method::Builder::disownAnnotations() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>> Method::Builder::disownAnnotations() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Annotation, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -5723,29 +6064,29 @@ inline bool Method::Builder::hasImplicitParameters() { return !_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Node::Parameter>::Reader Method::Reader::getImplicitParameters() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader Method::Reader::getImplicitParameters() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder Method::Builder::getImplicitParameters() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder Method::Builder::getImplicitParameters() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS)); } -inline void Method::Builder::setImplicitParameters( ::capnp::List< ::capnp::schema::Node::Parameter>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::set(_builder.getPointerField( +inline void Method::Builder::setImplicitParameters( ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Node::Parameter>::Builder Method::Builder::initImplicitParameters(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>::Builder Method::Builder::initImplicitParameters(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS), size); } inline void Method::Builder::adoptImplicitParameters( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter>> Method::Builder::disownImplicitParameters() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>> Method::Builder::disownImplicitParameters() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::Parameter, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<4>() * ::capnp::POINTERS)); } @@ -6674,29 +7015,29 @@ inline bool Brand::Builder::hasScopes() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Brand::Scope>::Reader Brand::Reader::getScopes() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Reader Brand::Reader::getScopes() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Brand::Scope>::Builder Brand::Builder::getScopes() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Builder Brand::Builder::getScopes() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Brand::Builder::setScopes( ::capnp::List< ::capnp::schema::Brand::Scope>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope>>::set(_builder.getPointerField( +inline void Brand::Builder::setScopes( ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Brand::Scope>::Builder Brand::Builder::initScopes(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>::Builder Brand::Builder::initScopes(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Brand::Builder::adoptScopes( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope>> Brand::Builder::disownScopes() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>> Brand::Builder::disownScopes() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Scope, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -6739,41 +7080,41 @@ inline bool Brand::Scope::Builder::hasBind() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Brand::Binding>::Reader Brand::Scope::Reader::getBind() const { +inline ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Reader Brand::Scope::Reader::getBind() const { KJ_IREQUIRE((which() == Brand::Scope::BIND), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding>>::get(_reader.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Brand::Binding>::Builder Brand::Scope::Builder::getBind() { +inline ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Builder Brand::Scope::Builder::getBind() { KJ_IREQUIRE((which() == Brand::Scope::BIND), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding>>::get(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void Brand::Scope::Builder::setBind( ::capnp::List< ::capnp::schema::Brand::Binding>::Reader value) { +inline void Brand::Scope::Builder::setBind( ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Reader value) { _builder.setDataField( ::capnp::bounded<4>() * ::capnp::ELEMENTS, Brand::Scope::BIND); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding>>::set(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Brand::Binding>::Builder Brand::Scope::Builder::initBind(unsigned int size) { +inline ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>::Builder Brand::Scope::Builder::initBind(unsigned int size) { _builder.setDataField( ::capnp::bounded<4>() * ::capnp::ELEMENTS, Brand::Scope::BIND); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding>>::init(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void Brand::Scope::Builder::adoptBind( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding>>&& value) { + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>&& value) { _builder.setDataField( ::capnp::bounded<4>() * ::capnp::ELEMENTS, Brand::Scope::BIND); - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding>>::adopt(_builder.getPointerField( + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding>> Brand::Scope::Builder::disownBind() { +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>> Brand::Scope::Builder::disownBind() { KJ_IREQUIRE((which() == Brand::Scope::BIND), "Must check which() before get()ing a union member."); - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding>>::disown(_builder.getPointerField( + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Brand::Binding, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -7626,29 +7967,29 @@ inline bool CodeGeneratorRequest::Builder::hasNodes() { return !_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::Node>::Reader CodeGeneratorRequest::Reader::getNodes() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Reader CodeGeneratorRequest::Reader::getNodes() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::Node>::Builder CodeGeneratorRequest::Builder::getNodes() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::Builder::getNodes() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } -inline void CodeGeneratorRequest::Builder::setNodes( ::capnp::List< ::capnp::schema::Node>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node>>::set(_builder.getPointerField( +inline void CodeGeneratorRequest::Builder::setNodes( ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::Node>::Builder CodeGeneratorRequest::Builder::initNodes(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::Builder::initNodes(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), size); } inline void CodeGeneratorRequest::Builder::adoptNodes( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node>> CodeGeneratorRequest::Builder::disownNodes() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>> CodeGeneratorRequest::Builder::disownNodes() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<0>() * ::capnp::POINTERS)); } @@ -7660,29 +8001,29 @@ inline bool CodeGeneratorRequest::Builder::hasRequestedFiles() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Reader CodeGeneratorRequest::Reader::getRequestedFiles() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Reader CodeGeneratorRequest::Reader::getRequestedFiles() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Builder CodeGeneratorRequest::Builder::getRequestedFiles() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::Builder::getRequestedFiles() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void CodeGeneratorRequest::Builder::setRequestedFiles( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>::set(_builder.getPointerField( +inline void CodeGeneratorRequest::Builder::setRequestedFiles( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>::Builder CodeGeneratorRequest::Builder::initRequestedFiles(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::Builder::initRequestedFiles(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void CodeGeneratorRequest::Builder::adoptRequestedFiles( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>> CodeGeneratorRequest::Builder::disownRequestedFiles() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>> CodeGeneratorRequest::Builder::disownRequestedFiles() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -7725,6 +8066,40 @@ inline ::capnp::Orphan< ::capnp::schema::CapnpVersion> CodeGeneratorRequest::Bui ::capnp::bounded<2>() * ::capnp::POINTERS)); } +inline bool CodeGeneratorRequest::Reader::hasSourceInfo() const { + return !_reader.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS).isNull(); +} +inline bool CodeGeneratorRequest::Builder::hasSourceInfo() { + return !_builder.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS).isNull(); +} +inline ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Reader CodeGeneratorRequest::Reader::getSourceInfo() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS)); +} +inline ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::Builder::getSourceInfo() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS)); +} +inline void CodeGeneratorRequest::Builder::setSourceInfo( ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS), value); +} +inline ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::Builder::initSourceInfo(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS), size); +} +inline void CodeGeneratorRequest::Builder::adoptSourceInfo( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS), kj::mv(value)); +} +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>> CodeGeneratorRequest::Builder::disownSourceInfo() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::Node::SourceInfo, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( + ::capnp::bounded<3>() * ::capnp::POINTERS)); +} + inline ::uint64_t CodeGeneratorRequest::RequestedFile::Reader::getId() const { return _reader.getDataField< ::uint64_t>( ::capnp::bounded<0>() * ::capnp::ELEMENTS); @@ -7781,29 +8156,29 @@ inline bool CodeGeneratorRequest::RequestedFile::Builder::hasImports() { return !_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS).isNull(); } -inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Reader CodeGeneratorRequest::RequestedFile::Reader::getImports() const { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>::get(_reader.getPointerField( +inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Reader CodeGeneratorRequest::RequestedFile::Reader::getImports() const { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>::get(_reader.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Builder CodeGeneratorRequest::RequestedFile::Builder::getImports() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>::get(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::RequestedFile::Builder::getImports() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>::get(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } -inline void CodeGeneratorRequest::RequestedFile::Builder::setImports( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Reader value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>::set(_builder.getPointerField( +inline void CodeGeneratorRequest::RequestedFile::Builder::setImports( ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Reader value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>::set(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), value); } -inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>::Builder CodeGeneratorRequest::RequestedFile::Builder::initImports(unsigned int size) { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>::init(_builder.getPointerField( +inline ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>::Builder CodeGeneratorRequest::RequestedFile::Builder::initImports(unsigned int size) { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>::init(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), size); } inline void CodeGeneratorRequest::RequestedFile::Builder::adoptImports( - ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>&& value) { - ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>::adopt(_builder.getPointerField( + ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>&& value) { + ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>::adopt(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS), kj::mv(value)); } -inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>> CodeGeneratorRequest::RequestedFile::Builder::disownImports() { - return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import>>::disown(_builder.getPointerField( +inline ::capnp::Orphan< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>> CodeGeneratorRequest::RequestedFile::Builder::disownImports() { + return ::capnp::_::PointerHelpers< ::capnp::List< ::capnp::schema::CodeGeneratorRequest::RequestedFile::Import, ::capnp::Kind::STRUCT>>::disown(_builder.getPointerField( ::capnp::bounded<1>() * ::capnp::POINTERS)); } @@ -7858,4 +8233,5 @@ inline ::capnp::Orphan< ::capnp::Text> CodeGeneratorRequest::RequestedFile::Impo } // namespace } // namespace -#endif // CAPNP_INCLUDED_a93fc509624c72d9_ +CAPNP_END_HEADER + diff --git a/c++/src/capnp/schema.h b/c++/src/capnp/schema.h index d59fa75236..5eebacba28 100644 --- a/c++/src/capnp/schema.h +++ b/c++/src/capnp/schema.h @@ -19,18 +19,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SCHEMA_H_ -#define CAPNP_SCHEMA_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #if CAPNP_LITE #error "Reflection APIs, including this header, are not available in lite mode." #endif +#undef CONST +// For some ridiculous reason, Windows defines CONST to const. We have an enum value called CONST +// in schema.capnp.h, so if this is defined, compilation is gonna fail. So we undef it because +// that seems strictly better than failing entirely. But this could cause trouble for people later +// on if they, say, include windows.h, then include schema.h, then include another windows API +// header that uses CONST. I suppose they may have to re-#define CONST in between, or change the +// header ordering. Sorry. +// +// Please don't file a bug report telling us to change our enum naming style. You are at least +// seven years too late. + #include +#include +#include // work-around macro conflict with `VOID` + +CAPNP_BEGIN_HEADER namespace capnp { @@ -82,7 +92,7 @@ class Schema { // Get the encoded schema node content as a single message segment. It is safe to read as an // unchecked message. - Schema getDependency(uint64_t id) const KJ_DEPRECATED("Does not handle generics correctly."); + Schema getDependency(uint64_t id) const CAPNP_DEPRECATED("Does not handle generics correctly."); // DEPRECATED: This method cannot correctly account for generic type parameter bindings that // may apply to the dependency. Instead of using this method, use a method of the Schema API // that corresponds to the exact kind of dependency. For example, to get a field type, use @@ -116,6 +126,10 @@ class Schema { BrandArgumentList getBrandArgumentsAtScope(uint64_t scopeId) const; // Gets the values bound to the brand parameters at the given scope. + kj::Array getGenericScopeIds() const; + // Returns the type IDs of all parent scopes that have generic parameters, to which this type is + // subject. + StructSchema asStruct() const; EnumSchema asEnum() const; InterfaceSchema asInterface() const; @@ -129,6 +143,8 @@ class Schema { // you want to check if two Schemas represent the same type (but possibly different versions of // it), compare their IDs instead. + inline uint hashCode() const { return kj::hashCode(raw); } + template void requireUsableAs() const; // Throws an exception if a value with this Schema cannot safely be cast to a native value of @@ -140,6 +156,9 @@ class Schema { kj::StringPtr getShortDisplayName() const; // Get the short version of the node's display name. + const kj::StringPtr getUnqualifiedName() const; + // Get the display name "nickname" of this node minus the prefix + private: const _::RawBrandedSchema* raw; @@ -257,6 +276,28 @@ class StructSchema: public Schema { // there is no such field. (If the schema does not represent a union or a struct containing // an unnamed union, then this always returns null.) + bool isStreamResult() const; + // Convenience method to check if this is the result type of a streaming RPC method. + + bool mayContainCapabilities() const { return raw->generic->mayContainCapabilities; } + // Returns true if a struct of this type may transitively contain any capabilities. I.e., are + // any of the fields an interface type, or a struct type that may in turn contain capabilities? + // + // This is meant for optimizations where various bookkeeping can possibly be skipped if it is + // known in advance that there are no capabilities. Note that this may conservatively return true + // spuriously, e.g. if it would be inconvenient to compute the correct answer. A false positive + // should never cause incorrect behavior, just potentially hurt performance. + // + // It's important to keep in mind that even if a schema has no capability-typed fields today, + // they could always be added in future versions of the schema. So, just because the schema + // doesn't contain capabilities does NOT necessarily mean that an instance of the struct can't + // contain capabilities. However, it is a pretty good hint that the application won't plan to + // use such capabilities -- for example, if there are no caps in an RPC call's response type + // according to the client's version of the schema, then the client clearly isn't going to try + // to make any pipelined calls. The server could be operating with a new version of the schema + // and could actually return capabilities, but for the client to make a pipelined call, the + // client would have to know in advance that capabilities could be returned. + private: StructSchema(Schema base): Schema(base) {} template static inline StructSchema fromImpl() { @@ -302,6 +343,7 @@ class StructSchema::Field { inline bool operator==(const Field& other) const; inline bool operator!=(const Field& other) const { return !(*this == other); } + inline uint hashCode() const; private: StructSchema parent; @@ -400,6 +442,7 @@ class EnumSchema::Enumerant { inline bool operator==(const Enumerant& other) const; inline bool operator!=(const Enumerant& other) const { return !(*this == other); } + inline uint hashCode() const; private: EnumSchema parent; @@ -486,12 +529,16 @@ class InterfaceSchema::Method { inline uint16_t getOrdinal() const { return ordinal; } inline uint getIndex() const { return ordinal; } + bool isStreaming() const { return getResultType().isStreamResult(); } + // Check if this is a streaming method. + StructSchema getParamType() const; StructSchema getResultType() const; // Get the parameter and result types, including substituting generic parameters. inline bool operator==(const Method& other) const; inline bool operator!=(const Method& other) const { return !(*this == other); } + inline uint hashCode() const; private: InterfaceSchema parent; @@ -599,6 +646,8 @@ class Type { template inline static Type from(); + template + inline static Type from(T&& value); inline schema::Type::Which which() const; @@ -642,7 +691,7 @@ class Type { bool operator==(const Type& other) const; inline bool operator!=(const Type& other) const { return !(*this == other); } - size_t hashCode() const; + uint hashCode() const; inline Type wrapInList(uint depth = 1) const; // Return the Type formed by wrapping this type in List() `depth` times. @@ -680,6 +729,9 @@ class Type { void requireUsableAs(Type expected) const; + template + struct FromValueImpl; + friend class ListSchema; // only for requireUsableAs() }; @@ -701,7 +753,7 @@ class ListSchema { // Construct the schema for a list of the given type. static ListSchema of(schema::Type::Reader elementType, Schema context) - KJ_DEPRECATED("Does not handle generics correctly."); + CAPNP_DEPRECATED("Does not handle generics correctly."); // DEPRECATED: This method cannot correctly account for generic type parameter bindings that // may apply to the input type. Instead of using this method, use a method of the Schema API // that corresponds to the exact kind of dependency. For example, to get a field type, use @@ -791,6 +843,16 @@ inline bool InterfaceSchema::Method::operator==(const Method& other) const { return parent == other.parent && ordinal == other.ordinal; } +inline uint StructSchema::Field::hashCode() const { + return kj::hashCode(parent, index); +} +inline uint EnumSchema::Enumerant::hashCode() const { + return kj::hashCode(parent, ordinal); +} +inline uint InterfaceSchema::Method::hashCode() const { + return kj::hashCode(parent, ordinal); +} + inline ListSchema ListSchema::of(StructSchema elementType) { return ListSchema(Type(elementType)); } @@ -899,6 +961,29 @@ inline schema::Type::AnyPointer::Unconstrained::Which Type::whichAnyPointerKind( template inline Type Type::from() { return Type(Schema::from()); } +template +struct Type::FromValueImpl { + template + static inline Type type(U&& value) { + return Type::from(); + } +}; + +template +struct Type::FromValueImpl { + template + static inline Type type(U&& value) { + // All dynamic types have getSchema(). + return value.getSchema(); + } +}; + +template +inline Type Type::from(T&& value) { + typedef FromAny> Base; + return Type::FromValueImpl()>::type(kj::fwd(value)); +} + inline bool Type::isVoid () const { return baseType == schema::Type::VOID && listDepth == 0; } inline bool Type::isBool () const { return baseType == schema::Type::BOOL && listDepth == 0; } inline bool Type::isInt8 () const { return baseType == schema::Type::INT8 && listDepth == 0; } @@ -931,4 +1016,4 @@ inline Type Type::wrapInList(uint depth) const { } // namespace capnp -#endif // CAPNP_SCHEMA_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/serialize-async-test.c++ b/c++/src/capnp/serialize-async-test.c++ index d153246635..dcae6b060f 100644 --- a/c++/src/capnp/serialize-async-test.c++ +++ b/c++/src/capnp/serialize-async-test.c++ @@ -19,6 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#if _WIN32 +#include +#endif + #include "serialize-async.h" #include "serialize.h" #include @@ -27,9 +35,9 @@ #include #include "test-util.h" #include +#include #if _WIN32 -#define WIN32_LEAN_AND_MEAN #include #include namespace kj { @@ -339,6 +347,219 @@ TEST(SerializeAsyncTest, WriteAsyncEvenSegmentCount) { writeMessage(*output, message).wait(ioContext.waitScope); } +TEST(SerializeAsyncTest, WriteMultipleMessagesAsync) { + PipeWithSmallBuffer fds; + auto ioContext = kj::setupAsyncIo(); + auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]); + + const int numMessages = 5; + const int baseListSize = 16; + auto messages = kj::heapArrayBuilder(numMessages); + for (int i = 0; i < numMessages; ++i) { + messages.add(i+1); + auto root = messages[i].getRoot(); + auto list = root.initStructList(baseListSize+i); + for (auto element: list) { + initTestMessage(element); + } + } + + kj::Thread thread([&]() { + SocketInputStream input(fds[0]); + for (int i = 0; i < numMessages; ++i) { + InputStreamMessageReader reader(input); + auto listReader = reader.getRoot().getStructList(); + EXPECT_EQ(baseListSize+i, listReader.size()); + for (auto element: listReader) { + checkTestMessage(element); + } + } + }); + + auto msgs = kj::heapArray(numMessages); + for (int i = 0; i < numMessages; ++i) { + msgs[i] = &messages[i]; + } + writeMessages(*output, msgs).wait(ioContext.waitScope); +} + +void writeSmallMessage(kj::OutputStream& output, kj::StringPtr text) { + capnp::MallocMessageBuilder message; + message.getRoot().getAnyPointerField().setAs(text); + writeMessage(output, message); +} + +void expectSmallMessage(MessageStream& stream, kj::StringPtr text, kj::WaitScope& waitScope) { + auto msg = stream.readMessage().wait(waitScope); + KJ_EXPECT(msg->getRoot().getAnyPointerField().getAs() == text); +} + +void writeBigMessage(kj::OutputStream& output) { + capnp::MallocMessageBuilder message(4); // first segment is small + initTestMessage(message.getRoot()); + writeMessage(output, message); +} + +void expectBigMessage(MessageStream& stream, kj::WaitScope& waitScope) { + auto msg = stream.readMessage().wait(waitScope); + checkTestMessage(msg->getRoot()); +} + +KJ_TEST("BufferedMessageStream basics") { + // Encode input data. + kj::VectorOutputStream data; + + writeSmallMessage(data, "foo"); + + KJ_EXPECT(data.getArray().size() / sizeof(word) == 4); + + // A big message (more than half a buffer) + writeBigMessage(data); + + KJ_EXPECT(data.getArray().size() / sizeof(word) > 16); + + writeSmallMessage(data, "bar"); + writeSmallMessage(data, "baz"); + writeSmallMessage(data, "qux"); + + // Run the test. + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + auto writePromise = pipe.ends[1]->write(data.getArray().begin(), data.getArray().size()); + + uint callbackCallCount = 0; + auto callback = [&](MessageReader& reader) { + ++callbackCallCount; + return false; + }; + + BufferedMessageStream stream(*pipe.ends[0], callback, 16); + expectSmallMessage(stream, "foo", waitScope); + KJ_EXPECT(callbackCallCount == 1); + + KJ_EXPECT(!writePromise.poll(waitScope)); + + expectBigMessage(stream, waitScope); + KJ_EXPECT(callbackCallCount == 1); // no callback on big message + + KJ_EXPECT(!writePromise.poll(waitScope)); + + expectSmallMessage(stream, "bar", waitScope); + KJ_EXPECT(callbackCallCount == 2); + + // All data is now in the buffer, so this part is done. + KJ_EXPECT(writePromise.poll(waitScope)); + + expectSmallMessage(stream, "baz", waitScope); + expectSmallMessage(stream, "qux", waitScope); + KJ_EXPECT(callbackCallCount == 4); + + auto eofPromise = stream.MessageStream::tryReadMessage(); + KJ_EXPECT(!eofPromise.poll(waitScope)); + + pipe.ends[1]->shutdownWrite(); + KJ_EXPECT(eofPromise.wait(waitScope) == nullptr); +} + +KJ_TEST("BufferedMessageStream fragmented reads") { + // Encode input data. + kj::VectorOutputStream data; + writeBigMessage(data); + + // Run the test. + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + auto callback = [&](MessageReader& reader) { + return false; + }; + BufferedMessageStream stream(*pipe.ends[0], callback, 16); + + // Arrange to read a big message. + auto readPromise = stream.MessageStream::tryReadMessage(); + KJ_EXPECT(!readPromise.poll(waitScope)); + + auto remainingData = data.getArray(); + + // Write 5 bytes. This won't even fulfill the first read's minBytes. + pipe.ends[1]->write(remainingData.begin(), 5).wait(waitScope); + remainingData = remainingData.slice(5, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Write 4 more. Now the MessageStream will only see the first word which contains the first + // segment size. This size is small so the MessageStream won't yet fall back to + // readEntireMessage(). + pipe.ends[1]->write(remainingData.begin(), 4).wait(waitScope); + remainingData = remainingData.slice(4, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Drip 10 more bytes. Now the MessageStream will realize that it needs to try + // readEntireMessage(). + pipe.ends[1]->write(remainingData.begin(), 10).wait(waitScope); + remainingData = remainingData.slice(10, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Give it all except the last byte. + pipe.ends[1]->write(remainingData.begin(), remainingData.size() - 1).wait(waitScope); + remainingData = remainingData.slice(remainingData.size() - 1, remainingData.size()); + KJ_EXPECT(!readPromise.poll(waitScope)); + + // Finish it off. + pipe.ends[1]->write(remainingData.begin(), 1).wait(waitScope); + KJ_ASSERT(readPromise.poll(waitScope)); + + auto msg = readPromise.wait(waitScope); + checkTestMessage(KJ_ASSERT_NONNULL(msg)->getRoot()); +} + +KJ_TEST("BufferedMessageStream many small messages") { + // Encode input data. + kj::VectorOutputStream data; + + for (auto i: kj::zeroTo(16)) { + // Intentionally make these 5 words each so they cross buffer boundaries. + writeSmallMessage(data, kj::str("12345678-", i)); + KJ_EXPECT(data.getArray().size() / sizeof(word) == (i+1) * 5); + } + + // Run the test. + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto pipe = kj::newTwoWayPipe(); + auto writePromise = pipe.ends[1]->write(data.getArray().begin(), data.getArray().size()) + .then([&]() { + // Write some garbage at the end. + return pipe.ends[1]->write("bogus", 5); + }).then([&]() { + // EOF. + return pipe.ends[1]->shutdownWrite(); + }).eagerlyEvaluate(nullptr); + + uint callbackCallCount = 0; + auto callback = [&](MessageReader& reader) { + ++callbackCallCount; + return false; + }; + + BufferedMessageStream stream(*pipe.ends[0], callback, 16); + + for (auto i: kj::zeroTo(16)) { + // Intentionally make these 5 words each so they cross buffer boundaries. + expectSmallMessage(stream, kj::str("12345678-", i), waitScope); + KJ_EXPECT(callbackCallCount == i + 1); + } + + KJ_EXPECT_THROW(DISCONNECTED, stream.MessageStream::tryReadMessage().wait(waitScope)); + KJ_EXPECT(callbackCallCount == 16); +} + +// TODO(test): We should probably test BufferedMessageStream's FD handling here... but really it +// gets tested well enough by rpc-twoparty-test. + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/c++/src/capnp/serialize-async.c++ b/c++/src/capnp/serialize-async.c++ index c23f562df8..45eb1846ae 100644 --- a/c++/src/capnp/serialize-async.c++ +++ b/c++/src/capnp/serialize-async.c++ @@ -19,8 +19,21 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +// Includes just for need SOL_SOCKET and SO_SNDBUF +#if _WIN32 +#include + +#include +#include +#include +#else +#include +#endif + #include "serialize-async.h" +#include "serialize.h" #include +#include namespace capnp { @@ -35,6 +48,10 @@ public: kj::Promise read(kj::AsyncInputStream& inputStream, kj::ArrayPtr scratchSpace); + kj::Promise> readWithFds( + kj::AsyncCapabilityStream& inputStream, + kj::ArrayPtr fds, kj::ArrayPtr scratchSpace); + // implements MessageReader ---------------------------------------- kj::ArrayPtr getSegment(uint id) override { @@ -71,15 +88,35 @@ kj::Promise AsyncMessageReader::read(kj::AsyncInputStream& inputStream, return false; } else if (n < sizeof(firstWord)) { // EOF in first word. - KJ_FAIL_REQUIRE("Premature EOF.") { - return false; - } + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); + return false; } return readAfterFirstWord(inputStream, scratchSpace).then([]() { return true; }); }); } +kj::Promise> AsyncMessageReader::readWithFds( + kj::AsyncCapabilityStream& inputStream, kj::ArrayPtr fds, + kj::ArrayPtr scratchSpace) { + return inputStream.tryReadWithFds(firstWord, sizeof(firstWord), sizeof(firstWord), + fds.begin(), fds.size()) + .then([this,&inputStream,KJ_CPCAP(scratchSpace)] + (kj::AsyncCapabilityStream::ReadResult result) mutable + -> kj::Promise> { + if (result.byteCount == 0) { + return kj::Maybe(nullptr); + } else if (result.byteCount < sizeof(firstWord)) { + // EOF in first word. + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); + return kj::Maybe(nullptr); + } + + return readAfterFirstWord(inputStream, scratchSpace) + .then([result]() -> kj::Maybe { return result.capCount; }); + }); +} + kj::Promise AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream, kj::ArrayPtr scratchSpace) { if (segmentCount() == 0) { @@ -152,24 +189,57 @@ kj::Promise> readMessage( kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr scratchSpace) { auto reader = kj::heap(options); auto promise = reader->read(input, scratchSpace); - return promise.then(kj::mvCapture(reader, [](kj::Own&& reader, bool success) { - KJ_REQUIRE(success, "Premature EOF.") { break; } + return promise.then([reader = kj::mv(reader)](bool success) mutable -> kj::Own { + if (!success) { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); + } return kj::mv(reader); - })); + }); } kj::Promise>> tryReadMessage( kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr scratchSpace) { auto reader = kj::heap(options); auto promise = reader->read(input, scratchSpace); - return promise.then(kj::mvCapture(reader, - [](kj::Own&& reader, bool success) -> kj::Maybe> { + return promise.then([reader = kj::mv(reader)](bool success) mutable + -> kj::Maybe> { if (success) { return kj::mv(reader); } else { return nullptr; } - })); + }); +} + +kj::Promise readMessage( + kj::AsyncCapabilityStream& input, kj::ArrayPtr fdSpace, + ReaderOptions options, kj::ArrayPtr scratchSpace) { + auto reader = kj::heap(options); + auto promise = reader->readWithFds(input, fdSpace, scratchSpace); + return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe nfds) mutable + -> MessageReaderAndFds { + KJ_IF_MAYBE(n, nfds) { + return { kj::mv(reader), fdSpace.slice(0, *n) }; + } else { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); + return { kj::mv(reader), nullptr }; + } + }); +} + +kj::Promise> tryReadMessage( + kj::AsyncCapabilityStream& input, kj::ArrayPtr fdSpace, + ReaderOptions options, kj::ArrayPtr scratchSpace) { + auto reader = kj::heap(options); + auto promise = reader->readWithFds(input, fdSpace, scratchSpace); + return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe nfds) mutable + -> kj::Maybe { + KJ_IF_MAYBE(n, nfds) { + return MessageReaderAndFds { kj::mv(reader), fdSpace.slice(0, *n) }; + } else { + return nullptr; + } + }); } // ======================================================================================= @@ -183,38 +253,590 @@ struct WriteArrays { kj::Array> pieces; }; -} // namespace +inline size_t tableSizeForSegments(size_t segmentsSize) { + return (segmentsSize + 2) & ~size_t(1); +} -kj::Promise writeMessage(kj::AsyncOutputStream& output, - kj::ArrayPtr> segments) { +// Helper function that allocates and fills the pointed-to table with info about the segments and +// populates the pieces array with pointers to the segments. +void fillWriteArraysWithMessage(kj::ArrayPtr> segments, + kj::ArrayPtr<_::WireValue> table, + kj::ArrayPtr> pieces) { KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); - WriteArrays arrays; - arrays.table = kj::heapArray<_::WireValue>((segments.size() + 2) & ~size_t(1)); - // We write the segment count - 1 because this makes the first word zero for single-segment // messages, improving compression. We don't bother doing this with segment sizes because // one-word segments are rare anyway. - arrays.table[0].set(segments.size() - 1); + table[0].set(segments.size() - 1); for (uint i = 0; i < segments.size(); i++) { - arrays.table[i + 1].set(segments[i].size()); + table[i + 1].set(segments[i].size()); } if (segments.size() % 2 == 0) { // Set padding byte. - arrays.table[segments.size() + 1].set(0); + table[segments.size() + 1].set(0); } - arrays.pieces = kj::heapArray>(segments.size() + 1); - arrays.pieces[0] = arrays.table.asBytes(); - + KJ_ASSERT(pieces.size() == segments.size() + 1, "incorrectly sized pieces array during write"); + pieces[0] = table.asBytes(); for (uint i = 0; i < segments.size(); i++) { - arrays.pieces[i + 1] = segments[i].asBytes(); + pieces[i + 1] = segments[i].asBytes(); } +} + +template +kj::Promise writeMessageImpl(kj::ArrayPtr> segments, + WriteFunc&& writeFunc) { + KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); + + WriteArrays arrays; + arrays.table = kj::heapArray<_::WireValue>(tableSizeForSegments(segments.size())); + arrays.pieces = kj::heapArray>(segments.size() + 1); + fillWriteArraysWithMessage(segments, arrays.table, arrays.pieces); - auto promise = output.write(arrays.pieces); + auto promise = writeFunc(arrays.pieces); // Make sure the arrays aren't freed until the write completes. - return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {})); + return promise.then([arrays=kj::mv(arrays)]() {}); +} + +template +kj::Promise writeMessagesImpl( + kj::ArrayPtr>> messages, WriteFunc&& writeFunc) { + KJ_REQUIRE(messages.size() > 0, "Tried to serialize zero messages."); + + // Determine how large the shared table and pieces arrays needs to be. + size_t tableSize = 0; + size_t piecesSize = 0; + for (auto& segments : messages) { + tableSize += tableSizeForSegments(segments.size()); + piecesSize += segments.size() + 1; + } + auto table = kj::heapArray<_::WireValue>(tableSize); + auto pieces = kj::heapArray>(piecesSize); + + size_t tableValsWritten = 0; + size_t piecesWritten = 0; + for (auto i : kj::indices(messages)) { + const size_t tableValsToWrite = tableSizeForSegments(messages[i].size()); + const size_t piecesToWrite = messages[i].size() + 1; + fillWriteArraysWithMessage( + messages[i], + table.slice(tableValsWritten, tableValsWritten + tableValsToWrite), + pieces.slice(piecesWritten, piecesWritten + piecesToWrite)); + tableValsWritten += tableValsToWrite; + piecesWritten += piecesToWrite; + } + + auto promise = writeFunc(pieces); + return promise.attach(kj::mv(table), kj::mv(pieces)); +} + +} // namespace + +kj::Promise writeMessage(kj::AsyncOutputStream& output, + kj::ArrayPtr> segments) { + return writeMessageImpl(segments, + [&](kj::ArrayPtr> pieces) { + return output.write(pieces); + }); +} + +kj::Promise writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr fds, + kj::ArrayPtr> segments) { + return writeMessageImpl(segments, + [&](kj::ArrayPtr> pieces) { + return output.writeWithFds(pieces[0], pieces.slice(1, pieces.size()), fds); + }); +} + +kj::Promise writeMessages( + kj::AsyncOutputStream& output, + kj::ArrayPtr>> messages) { + return writeMessagesImpl(messages, + [&](kj::ArrayPtr> pieces) { + return output.write(pieces); + }); +} + +kj::Promise writeMessages( + kj::AsyncOutputStream& output, kj::ArrayPtr builders) { + auto messages = kj::heapArray>>(builders.size()); + for (auto i : kj::indices(builders)) { + messages[i] = builders[i]->getSegmentsForOutput(); + } + return writeMessages(output, messages); +} + +kj::Promise MessageStream::writeMessages(kj::ArrayPtr messages) { + if (messages.size() == 0) return kj::READY_NOW; + kj::ArrayPtr remainingMessages; + + auto writeProm = [&]() { + if (messages[0].fds.size() > 0) { + // We have a message with FDs attached. We need to write any bare messages we've accumulated, + // if any, then write the message with FDs, then continue on with any remaining messages. + + if (messages.size() > 1) { + remainingMessages = messages.slice(1, messages.size()); + } + + return writeMessage(messages[0].fds, messages[0].segments); + } else { + kj::Vector>> bareMessages(messages.size()); + for(auto i : kj::zeroTo(messages.size())) { + if (messages[i].fds.size() > 0) { + break; + } + bareMessages.add(messages[i].segments); + } + + if (messages.size() > bareMessages.size()) { + remainingMessages = messages.slice(bareMessages.size(), messages.size()); + } + return writeMessages(bareMessages.asPtr()).attach(kj::mv(bareMessages)); + } + }(); + + if (remainingMessages.size() > 0) { + return writeProm.then([this, remainingMessages]() mutable { + return writeMessages(remainingMessages); + }); + } else { + return writeProm; + } +} + +kj::Promise MessageStream::writeMessages(kj::ArrayPtr builders) { + auto messages = kj::heapArray>>(builders.size()); + for (auto i : kj::indices(builders)) { + messages[i] = builders[i]->getSegmentsForOutput(); + } + return writeMessages(messages); +} + +AsyncIoMessageStream::AsyncIoMessageStream(kj::AsyncIoStream& stream) + : stream(stream) {}; + +kj::Promise> AsyncIoMessageStream::tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options, + kj::ArrayPtr scratchSpace) { + return capnp::tryReadMessage(stream, options, scratchSpace) + .then([](kj::Maybe> maybeReader) -> kj::Maybe { + KJ_IF_MAYBE(reader, maybeReader) { + return MessageReaderAndFds { kj::mv(*reader), nullptr }; + } else { + return nullptr; + } + }); +} + +kj::Promise AsyncIoMessageStream::writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) { + return capnp::writeMessage(stream, segments); +} + +kj::Promise AsyncIoMessageStream::writeMessages( + kj::ArrayPtr>> messages) { + return capnp::writeMessages(stream, messages); +} + +kj::Maybe getSendBufferSize(kj::AsyncIoStream& stream) { + // TODO(perf): It might be nice to have a tryGetsockopt() that doesn't require catching + // exceptions? + int bufSize = 0; + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + uint len = sizeof(int); + stream.getsockopt(SOL_SOCKET, SO_SNDBUF, &bufSize, &len); + KJ_ASSERT(len == sizeof(bufSize)) { break; } + })) { + if (exception->getType() != kj::Exception::Type::UNIMPLEMENTED) { + // TODO(someday): Figure out why getting SO_SNDBUF sometimes throws EINVAL. I suspect it + // happens when the remote side has closed their read end, meaning we no longer have + // a send buffer, but I don't know what is the best way to verify that that was actually + // the reason. I'd prefer not to ignore EINVAL errors in general. + + // kj::throwRecoverableException(kj::mv(*exception)); + } + return nullptr; + } + return bufSize; +} + +kj::Promise AsyncIoMessageStream::end() { + stream.shutdownWrite(); + return kj::READY_NOW; +} + +kj::Maybe AsyncIoMessageStream::getSendBufferSize() { + return capnp::getSendBufferSize(stream); +} + +AsyncCapabilityMessageStream::AsyncCapabilityMessageStream(kj::AsyncCapabilityStream& stream) + : stream(stream) {}; + +kj::Promise> AsyncCapabilityMessageStream::tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options, + kj::ArrayPtr scratchSpace) { + return capnp::tryReadMessage(stream, fdSpace, options, scratchSpace); +} + +kj::Promise AsyncCapabilityMessageStream::writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) { + return capnp::writeMessage(stream, fds, segments); +} + +kj::Promise AsyncCapabilityMessageStream::writeMessages( + kj::ArrayPtr>> messages) { + return capnp::writeMessages(stream, messages); +} + +kj::Maybe AsyncCapabilityMessageStream::getSendBufferSize() { + return capnp::getSendBufferSize(stream); +} + +kj::Promise AsyncCapabilityMessageStream::end() { + stream.shutdownWrite(); + return kj::READY_NOW; +} + +kj::Promise> MessageStream::readMessage( + ReaderOptions options, + kj::ArrayPtr scratchSpace) { + return tryReadMessage(options, scratchSpace).then([](kj::Maybe> maybeResult) { + KJ_IF_MAYBE(result, maybeResult) { + return kj::mv(*result); + } else { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); + KJ_UNREACHABLE; + } + }); +} + +kj::Promise>> MessageStream::tryReadMessage( + ReaderOptions options, + kj::ArrayPtr scratchSpace) { + return tryReadMessage(nullptr, options, scratchSpace) + .then([](auto maybeReaderAndFds) -> kj::Maybe> { + KJ_IF_MAYBE(readerAndFds, maybeReaderAndFds) { + return kj::mv(readerAndFds->reader); + } else { + return nullptr; + } + }); +} + +kj::Promise MessageStream::readMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options, kj::ArrayPtr scratchSpace) { + return tryReadMessage(fdSpace, options, scratchSpace).then([](auto maybeResult) { + KJ_IF_MAYBE(result, maybeResult) { + return kj::mv(*result); + } else { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); + KJ_UNREACHABLE; + } + }); +} + +// ======================================================================================= + +class BufferedMessageStream::MessageReaderImpl: public FlatArrayMessageReader { +public: + MessageReaderImpl(BufferedMessageStream& parent, kj::ArrayPtr data, + ReaderOptions options) + : FlatArrayMessageReader(data, options), state(&parent) { + KJ_DASSERT(!parent.hasOutstandingShortLivedMessage); + parent.hasOutstandingShortLivedMessage = true; + } + MessageReaderImpl(kj::Array&& ownBuffer, ReaderOptions options) + : FlatArrayMessageReader(ownBuffer, options), state(kj::mv(ownBuffer)) {} + MessageReaderImpl(kj::ArrayPtr scratchBuffer, ReaderOptions options) + : FlatArrayMessageReader(scratchBuffer, options) {} + + ~MessageReaderImpl() noexcept(false) { + KJ_IF_MAYBE(parent, state.tryGet()) { + (*parent)->hasOutstandingShortLivedMessage = false; + } + } + +private: + kj::OneOf> state; + // * BufferedMessageStream* if this reader aliases the original buffer. + // * kj::Array if this reader owns its own backing buffer. +}; + +BufferedMessageStream::BufferedMessageStream( + kj::AsyncIoStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords) + : stream(stream), isShortLivedCallback(kj::mv(isShortLivedCallback)), + buffer(kj::heapArray(bufferSizeInWords)), + beginData(buffer.begin()), beginAvailable(buffer.asBytes().begin()) {} + +BufferedMessageStream::BufferedMessageStream( + kj::AsyncCapabilityStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords) + : stream(stream), capStream(stream), isShortLivedCallback(kj::mv(isShortLivedCallback)), + buffer(kj::heapArray(bufferSizeInWords)), + beginData(buffer.begin()), beginAvailable(buffer.asBytes().begin()) {} + +kj::Promise> BufferedMessageStream::tryReadMessage( + kj::ArrayPtr fdSpace, ReaderOptions options, kj::ArrayPtr scratchSpace) { + return tryReadMessageImpl(fdSpace, 0, options, scratchSpace); +} + +kj::Promise BufferedMessageStream::writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) { + KJ_IF_MAYBE(cs, capStream) { + return capnp::writeMessage(*cs, fds, segments); + } else { + return capnp::writeMessage(stream, segments); + } +} + +kj::Promise BufferedMessageStream::writeMessages( + kj::ArrayPtr>> messages) { + return capnp::writeMessages(stream, messages); +} + +kj::Maybe BufferedMessageStream::getSendBufferSize() { + return capnp::getSendBufferSize(stream); +} + +kj::Promise BufferedMessageStream::end() { + stream.shutdownWrite(); + return kj::READY_NOW; +} + +kj::Promise> BufferedMessageStream::tryReadMessageImpl( + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options, kj::ArrayPtr scratchSpace) { + KJ_REQUIRE(!hasOutstandingShortLivedMessage, + "can't read another message while the previous short-lived message still exists"); + + kj::byte* beginDataBytes = reinterpret_cast(beginData); + size_t dataByteSize = beginAvailable - beginDataBytes; + kj::ArrayPtr data = kj::arrayPtr(beginData, dataByteSize / sizeof(word)); + + size_t expected = expectedSizeInWordsFromPrefix(data); + + if (!leftoverFds.empty() && expected * sizeof(word) == dataByteSize) { + // We're about to return a message that consumes the rest of the data in the buffer, and + // `leftoverFds` is non-empty. Those FDs are considered attached to whatever message contains + // the last byte in the buffer. That's us! Let's consume them. + + // `fdsSoFar` must be empty here because we shouldn't have performed any reads while + // `leftoverFds` was non-empty, so there shouldn't have been any other chance to add FDs to + // `fdSpace`. + KJ_ASSERT(fdsSoFar == 0); + + fdsSoFar = kj::min(leftoverFds.size(), fdSpace.size()); + for (auto i: kj::zeroTo(fdsSoFar)) { + fdSpace[i] = kj::mv(leftoverFds[i]); + } + leftoverFds.clear(); + } + + if (expected <= data.size()) { + // The buffer contains at least one whole message, which we can just return without reading + // any more data. + + auto msgData = kj::arrayPtr(beginData, expected); + auto reader = kj::heap(*this, msgData, options); + if (!isShortLivedCallback(*reader)) { + // This message is long-lived, so we must make a copy to get it out of our buffer. + if (msgData.size() <= scratchSpace.size()) { + // Oh hey, we can use the provided scratch space. + memcpy(scratchSpace.begin(), msgData.begin(), msgData.asBytes().size()); + reader = kj::heap(scratchSpace, options); + } else { + auto ownMsgData = kj::heapArray(msgData.size()); + memcpy(ownMsgData.begin(), msgData.begin(), msgData.asBytes().size()); + reader = kj::heap(kj::mv(ownMsgData), options); + } + } + + beginData += expected; + if (reinterpret_cast(beginData) == beginAvailable) { + // The buffer is empty. Let's opportunistically reset the pointers. + beginData = buffer.begin(); + beginAvailable = buffer.asBytes().begin(); + } else if (fdsSoFar > 0) { + // The buffer is NOT empty, and we received FDs when we were filling it. These FDs must + // actually belong to the last message in the buffer, because when the OS returns FDs + // attached to a read, it will make sure the read does not extend past the last byte to + // which those FDs were attached. + // + // So, we must set these FDs aside for the moment. + for (auto i: kj::zeroTo(fdsSoFar)) { + leftoverFds.add(kj::mv(fdSpace[i])); + } + fdsSoFar = 0; + } + + return kj::Maybe(MessageReaderAndFds { + kj::mv(reader), + fdSpace.slice(0, fdsSoFar) + }); + } + + // At this point, the buffer doesn't contain a complete message. We are going to need to perform + // a read. + + if (expected > buffer.size() / 2 || fdsSoFar > 0) { + // Read this message into its own separately-allocated buffer. We do this for: + // - Big messages, because they might not fit in the buffer and because big messages are + // almost certainly going to be long-lived and so would require a copy later anyway. + // - Messages where we've already received some FDs, because these are also almost certainly + // long-lived, and we want to avoid accidentally reading into the next message since we + // could end up receiving FDs that were intended for that one. + // + // Optimization note: You might argue that if the expected size is more than half the buffer, + // but still less than the *whole* buffer, then we should still try to read into the buffer + // first. However, keep in mind that in the RPC system, all short-lived messages are + // relatively small, and hence we can assume that since this is a large message, it will + // end up being long-lived. Long-lived messages need to be copied out into their own buffer + // at some point anyway. So we might as well go ahead and allocate that separate buffer + // now, and read directly into it, rather than try to use the shared buffer. We choose to + // use buffer.size() / 2 as the cutoff because that ensures that we won't try to move the + // bytes of a known-large message to the beginning of the buffer (see next if() after this + // one). + + auto prefix = kj::arrayPtr(beginDataBytes, dataByteSize); + + // We are consuming everything in the buffer here, so we can reset the pointers so the + // buffer appears empty on the next message read after this. + beginData = buffer.begin(); + beginAvailable = buffer.asBytes().begin(); + + return readEntireMessage(prefix, expected, fdSpace, fdsSoFar, options); + } + + // Set minBytes to at least complete the current message. + size_t minBytes = expected * sizeof(word) - dataByteSize; + + // minBytes must be less than half the buffer otherwise we would have taken the + // readEntireMessage() branch above. + KJ_DASSERT(minBytes <= buffer.asBytes().size() / 2); + + // Set maxBytes to the space we have available in the buffer. + size_t maxBytes = buffer.asBytes().end() - beginAvailable; + + if (maxBytes < buffer.asBytes().size() / 2) { + // We have less than half the buffer remaining to read into. Move the buffered data to the + // beginning of the buffer to make more space. + memmove(buffer.begin(), beginData, dataByteSize); + beginData = buffer.begin(); + beginDataBytes = buffer.asBytes().begin(); + beginAvailable = beginDataBytes + dataByteSize; + + maxBytes = buffer.asBytes().end() - beginAvailable; + } + + // maxBytes must now be more than half the buffer, because if it weren't we would have moved + // the existing data above, and the existing data cannot be more than half the buffer because + // if it were we would have taken the readEntireMesage() path earlier. + KJ_DASSERT(maxBytes >= buffer.asBytes().size() / 2); + + // Since minBytes is less that half the buffer and maxBytes is more then half, minBytes is + // definitely less than maxBytes. + KJ_DASSERT(minBytes <= maxBytes); + + // Read from underlying stream. + return tryReadWithFds(beginAvailable, minBytes, maxBytes, + fdSpace.begin() + fdsSoFar, fdSpace.size() - fdsSoFar) + .then([this,minBytes,fdSpace,fdsSoFar,options,scratchSpace] + (kj::AsyncCapabilityStream::ReadResult result) mutable + -> kj::Promise> { + // Account for new data received in the buffer. + beginAvailable += result.byteCount; + + if (result.byteCount < minBytes) { + // Didn't reach minBytes, so we must have hit EOF. That's legal as long as it happened on + // a clean message boundray. + if (beginAvailable > reinterpret_cast(beginData)) { + // We had received a partial message before EOF, so this should be considered an error. + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, + "stream disconnected prematurely")); + } + return kj::Maybe(nullptr); + } + + // Loop! + return tryReadMessageImpl(fdSpace, fdsSoFar + result.capCount, options, scratchSpace); + }); +} + +kj::Promise> BufferedMessageStream::readEntireMessage( + kj::ArrayPtr prefix, size_t expectedSizeInWords, + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options) { + KJ_REQUIRE(expectedSizeInWords <= options.traversalLimitInWords, + "incoming RPC message exceeds size limit"); + + auto msgBuffer = kj::heapArray(expectedSizeInWords); + + memcpy(msgBuffer.asBytes().begin(), prefix.begin(), prefix.size()); + + size_t bytesRemaining = msgBuffer.asBytes().size() - prefix.size(); + + // TODO(perf): If we had scatter-read API support, we could optimistically try to read additional + // bytes into the shared buffer, to save syscalls when a big message is immediately followed + // by small messages. + auto promise = tryReadWithFds( + msgBuffer.asBytes().begin() + prefix.size(), bytesRemaining, bytesRemaining, + fdSpace.begin() + fdsSoFar, fdSpace.size() - fdsSoFar); + return promise + .then([this, msgBuffer = kj::mv(msgBuffer), fdSpace, fdsSoFar, options, bytesRemaining] + (kj::AsyncCapabilityStream::ReadResult result) mutable + -> kj::Promise> { + fdsSoFar += result.capCount; + + if (result.byteCount < bytesRemaining) { + // Received EOF during message. + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "stream disconnected prematurely")); + return kj::Maybe(nullptr); + } + + size_t newExpectedSize = expectedSizeInWordsFromPrefix(msgBuffer); + if (newExpectedSize > msgBuffer.size()) { + // Unfortunately, the predicted size increased. This can happen if the segment table had + // not been fully received when we generated the first prediction. This should be rare + // (most segment tables are small and should be received all at once), but in this case we + // will need to make a whole new copy of the message. + // + // We recurse here, but this should never recurse more than once, since we should always + // have the entire segment table by this point and therefore the expected size is now final. + // + // TODO(perf): Technically it's guaranteed that the original expectation should have stopped + // at the boundary between two segments, so with a clever MesnsageReader implementation + // we could actually read the rest of the message into a second buffer, avoiding the copy. + // Unclear if it's worth the effort to implement this. + return readEntireMessage(msgBuffer.asBytes(), newExpectedSize, fdSpace, fdsSoFar, options); + } + + return kj::Maybe(MessageReaderAndFds { + kj::heap(kj::mv(msgBuffer), options), + fdSpace.slice(0, fdsSoFar) + }); + }); +} + +kj::Promise BufferedMessageStream::tryReadWithFds( + void* buffer, size_t minBytes, size_t maxBytes, kj::AutoCloseFd* fdBuffer, size_t maxFds) { + KJ_IF_MAYBE(cs, capStream) { + return cs->tryReadWithFds(buffer, minBytes, maxBytes, fdBuffer, maxFds); + } else { + // Regular byte stream, no FDs. + return stream.tryRead(buffer, minBytes, maxBytes) + .then([](size_t amount) mutable -> kj::AsyncCapabilityStream::ReadResult { + return { amount, 0 }; + }); + } } } // namespace capnp diff --git a/c++/src/capnp/serialize-async.h b/c++/src/capnp/serialize-async.h index a16bfd8975..cd661d7809 100644 --- a/c++/src/capnp/serialize-async.h +++ b/c++/src/capnp/serialize-async.h @@ -19,38 +19,277 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SERIALIZE_ASYNC_H_ -#define CAPNP_SERIALIZE_ASYNC_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include +#include #include "message.h" +CAPNP_BEGIN_HEADER + namespace capnp { +struct MessageReaderAndFds { + kj::Own reader; + kj::ArrayPtr fds; +}; + +struct MessageAndFds { + kj::ArrayPtr> segments; + kj::ArrayPtr fds; +}; + +class MessageStream { + // Interface over which messages can be sent and received; virtualizes + // the functionality above. +public: + virtual kj::Promise> tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr) = 0; + // Read a message that may also have file descriptors attached, e.g. from a Unix socket with + // SCM_RIGHTS. Returns null on EOF. + // + // `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed. + + kj::Promise>> tryReadMessage( + ReaderOptions options = ReaderOptions(), + kj::ArrayPtr scratchSpace = nullptr); + // Equivalent to the above with fdSpace = nullptr. + + kj::Promise readMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); + kj::Promise> readMessage( + ReaderOptions options = ReaderOptions(), + kj::ArrayPtr scratchSpace = nullptr); + // Like tryReadMessage, but throws an exception on EOF. + + virtual kj::Promise writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) + KJ_WARN_UNUSED_RESULT = 0; + kj::Promise writeMessage( + kj::ArrayPtr fds, + MessageBuilder& builder) + KJ_WARN_UNUSED_RESULT; + // Write a message with FDs attached, e.g. to a Unix socket with SCM_RIGHTS. + // The parameters must remain valid until the returned promise resolves. + + kj::Promise writeMessage( + kj::ArrayPtr> segments) + KJ_WARN_UNUSED_RESULT; + kj::Promise writeMessage(MessageBuilder& builder) + KJ_WARN_UNUSED_RESULT; + // Equivalent to the above with fds = nullptr. + + kj::Promise writeMessages( + kj::ArrayPtr messages) + KJ_WARN_UNUSED_RESULT; + virtual kj::Promise writeMessages( + kj::ArrayPtr>> messages) + KJ_WARN_UNUSED_RESULT = 0; + kj::Promise writeMessages(kj::ArrayPtr builders) + KJ_WARN_UNUSED_RESULT; + // Similar to the above, but for writing multiple messages at a time in a batch. + + virtual kj::Maybe getSendBufferSize() = 0; + // Get the size of the underlying send buffer, if applicable. The RPC + // system uses this as a hint for flow control purposes; see: + // + // https://capnproto.org/news/2020-04-23-capnproto-0.8.html#multi-stream-flow-control + // + // ...for a more thorough explanation of how this is used. Implementations + // may return nullptr if they do not have access to this information, or if + // the underlying transport does not use a congestion window. + + virtual kj::Promise end() = 0; + // Cleanly shut down just the write end of the transport, while keeping the read end open. + +}; + +class AsyncIoMessageStream final: public MessageStream { + // A MessageStream that wraps an AsyncIoStream. +public: + explicit AsyncIoMessageStream(kj::AsyncIoStream& stream); + + // Implements MessageStream + kj::Promise> tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr) override; + kj::Promise writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) override; + kj::Promise writeMessages( + kj::ArrayPtr>> messages) override; + kj::Maybe getSendBufferSize() override; + + kj::Promise end() override; + + // Make sure the overridden virtual methods don't hide the non-virtual methods. + using MessageStream::tryReadMessage; + using MessageStream::writeMessage; +private: + kj::AsyncIoStream& stream; +}; + +class AsyncCapabilityMessageStream final: public MessageStream { + // A MessageStream that wraps an AsyncCapabilityStream. +public: + explicit AsyncCapabilityMessageStream(kj::AsyncCapabilityStream& stream); + + // Implements MessageStream + kj::Promise> tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr) override; + kj::Promise writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) override; + kj::Promise writeMessages( + kj::ArrayPtr>> messages) override; + kj::Maybe getSendBufferSize() override; + kj::Promise end() override; + + // Make sure the overridden virtual methods don't hide the non-virtual methods. + using MessageStream::tryReadMessage; + using MessageStream::writeMessage; +private: + kj::AsyncCapabilityStream& stream; +}; + +class BufferedMessageStream final: public MessageStream { + // A MessageStream that reads into a buffer in the hopes of receiving multiple messages in a + // single system call. Compared to the other implementations, this implementation is expected + // to be faster when reading from an OS stream (but probably not when reading from an in-memory + // async pipe). It has the down sides of using more memory (for the buffer) and requiring extra + // copies. + +public: + using IsShortLivedCallback = kj::Function; + // Callback function which decides whether a message will be "short-lived", meaning that it is + // guaranteed to be dropped before the next message is read. The stream uses this as an + // optimization to decide whether it can return a MessageReader pointing into the buffer, which + // will be reused for future reads. For long-lived messages, the stream must copy the content + // into a separate buffer. + + explicit BufferedMessageStream( + kj::AsyncIoStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords = 8192); + explicit BufferedMessageStream( + kj::AsyncCapabilityStream& stream, IsShortLivedCallback isShortLivedCallback, + size_t bufferSizeInWords = 8192); + + // Implements MessageStream + kj::Promise> tryReadMessage( + kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr) override; + kj::Promise writeMessage( + kj::ArrayPtr fds, + kj::ArrayPtr> segments) override; + kj::Promise writeMessages( + kj::ArrayPtr>> messages) override; + kj::Maybe getSendBufferSize() override; + kj::Promise end() override; + + // Make sure the overridden virtual methods don't hide the non-virtual methods. + using MessageStream::tryReadMessage; + using MessageStream::writeMessage; + +private: + kj::AsyncIoStream& stream; + kj::Maybe capStream; + IsShortLivedCallback isShortLivedCallback; + + kj::Array buffer; + + word* beginData; + // Pointer to location in `buffer` where the next message starts. This is always on a word + // boundray since messages are always a whole number of words. + + kj::byte* beginAvailable; + // Pointer to the location in `buffer` where unused buffer space begins, i.e. immediately after + // the last byte read. + + kj::Vector leftoverFds; + // FDs which were accidentally read too early. These are always connected to the last message + // in the buffer, since the OS would not have allowed us to read past that point. + + bool hasOutstandingShortLivedMessage = false; + + kj::Promise> tryReadMessageImpl( + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options, kj::ArrayPtr scratchSpace); + + kj::Promise> readEntireMessage( + kj::ArrayPtr prefix, size_t expectedSizeInWords, + kj::ArrayPtr fdSpace, size_t fdsSoFar, + ReaderOptions options); + // Given a message prefix and expected size of the whole message, read the entire message into + // a single array and return it. + + kj::Promise tryReadWithFds( + void* buffer, size_t minBytes, size_t maxBytes, kj::AutoCloseFd* fdBuffer, size_t maxFds); + // Executes AsyncCapabilityStream::tryReadWithFds() on the underlying stream, or falls back to + // AsyncIoStream::tryRead() if it's not a capability stream. + + class MessageReaderImpl; +}; + +// ----------------------------------------------------------------------------- +// Stand-alone functions for reading & writing messages on AsyncInput/AsyncOutputStreams. +// +// In general, foo(stream, ...) is equivalent to +// AsyncIoMessageStream(stream).foo(...), whenever the latter would type check. +// +// The first argument must remain valid until the returned promise resolves +// (or is canceled). + kj::Promise> readMessage( kj::AsyncInputStream& input, ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); -// Read a message asynchronously. -// -// `input` must remain valid until the returned promise resolves (or is canceled). -// -// `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed. kj::Promise>> tryReadMessage( kj::AsyncInputStream& input, ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); -// Like `readMessage` but returns null on EOF. kj::Promise writeMessage(kj::AsyncOutputStream& output, kj::ArrayPtr> segments) KJ_WARN_UNUSED_RESULT; + kj::Promise writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) KJ_WARN_UNUSED_RESULT; -// Write asynchronously. The parameters must remain valid until the returned promise resolves. + +// ----------------------------------------------------------------------------- +// Stand-alone versions that support FD passing. +// +// For each of these, `foo(stream, ...)` is equivalent to +// `AsyncCapabilityMessageStream(stream).foo(...)`. + +kj::Promise readMessage( + kj::AsyncCapabilityStream& input, kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); + +kj::Promise> tryReadMessage( + kj::AsyncCapabilityStream& input, kj::ArrayPtr fdSpace, + ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); + +kj::Promise writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr fds, + kj::ArrayPtr> segments) + KJ_WARN_UNUSED_RESULT; +kj::Promise writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr fds, + MessageBuilder& builder) + KJ_WARN_UNUSED_RESULT; + + +// ----------------------------------------------------------------------------- +// Stand-alone functions for writing multiple messages at once on AsyncOutputStreams. + +kj::Promise writeMessages(kj::AsyncOutputStream& output, + kj::ArrayPtr>> messages) + KJ_WARN_UNUSED_RESULT; + +kj::Promise writeMessages( + kj::AsyncOutputStream& output, kj::ArrayPtr builders) + KJ_WARN_UNUSED_RESULT; // ======================================================================================= // inline implementation details @@ -58,7 +297,24 @@ kj::Promise writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu inline kj::Promise writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) { return writeMessage(output, builder.getSegmentsForOutput()); } +inline kj::Promise writeMessage( + kj::AsyncCapabilityStream& output, kj::ArrayPtr fds, MessageBuilder& builder) { + return writeMessage(output, fds, builder.getSegmentsForOutput()); +} + +inline kj::Promise MessageStream::writeMessage(kj::ArrayPtr> segments) { + return writeMessage(nullptr, segments); +} + +inline kj::Promise MessageStream::writeMessage(MessageBuilder& builder) { + return writeMessage(builder.getSegmentsForOutput()); +} + +inline kj::Promise MessageStream::writeMessage( + kj::ArrayPtr fds, MessageBuilder& builder) { + return writeMessage(fds, builder.getSegmentsForOutput()); +} } // namespace capnp -#endif // CAPNP_SERIALIZE_ASYNC_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/serialize-packed-test.c++ b/c++/src/capnp/serialize-packed-test.c++ index 78a5942809..b5a53f8cb8 100644 --- a/c++/src/capnp/serialize-packed-test.c++ +++ b/c++/src/capnp/serialize-packed-test.c++ @@ -86,10 +86,18 @@ private: std::string::size_type readPos; }; -void expectPacksTo(kj::ArrayPtr unpacked, kj::ArrayPtr packed) { +void expectPacksTo(kj::ArrayPtr unpackedUnaligned, kj::ArrayPtr packed) { TestPipe pipe; - EXPECT_EQ(unpacked.size(), computeUnpackedSizeInWords(packed) * sizeof(word)); + auto unpackedSizeInWords = computeUnpackedSizeInWords(packed); + EXPECT_EQ(unpackedUnaligned.size(), unpackedSizeInWords * sizeof(word)); + + // Make a guaranteed-to-be-aligned copy of the unpacked buffer. + kj::Array unpackedWords = kj::heapArray(unpackedSizeInWords); + if (unpackedUnaligned.size() != 0u) { + memcpy(unpackedWords.begin(), unpackedUnaligned.begin(), unpackedUnaligned.size()); + } + kj::ArrayPtr unpacked = unpackedWords.asBytes(); // ----------------------------------------------------------------- // write diff --git a/c++/src/capnp/serialize-packed.c++ b/c++/src/capnp/serialize-packed.c++ index 8416b65465..04fd42d91c 100644 --- a/c++/src/capnp/serialize-packed.c++ +++ b/c++/src/capnp/serialize-packed.c++ @@ -140,7 +140,7 @@ size_t PackedInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) { return out - reinterpret_cast(dst); } - uint inRemaining = BUFFER_REMAINING; + size_t inRemaining = BUFFER_REMAINING; if (inRemaining >= runLength) { // Fast path. memcpy(out, in, runLength); @@ -266,7 +266,7 @@ void PackedInputStream::skip(size_t bytes) { bytes -= runLength; - uint inRemaining = BUFFER_REMAINING; + size_t inRemaining = BUFFER_REMAINING; if (inRemaining > runLength) { // Fast path. in += runLength; @@ -351,7 +351,8 @@ void PackedOutputStream::write(const void* src, size_t size) { // An all-zero word is followed by a count of consecutive zero words (not including the // first one). - // We can check a whole word at a time. + // We can check a whole word at a time. (Here is where we use the assumption that + // `src` is word-aligned.) const uint64_t* inWord = reinterpret_cast(in); // The count must fit it 1 byte, so limit to 255 words. diff --git a/c++/src/capnp/serialize-packed.h b/c++/src/capnp/serialize-packed.h index a71260ce1d..a0329b1300 100644 --- a/c++/src/capnp/serialize-packed.h +++ b/c++/src/capnp/serialize-packed.h @@ -19,15 +19,12 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SERIALIZE_PACKED_H_ -#define CAPNP_SERIALIZE_PACKED_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "serialize.h" +CAPNP_BEGIN_HEADER + namespace capnp { namespace _ { // private @@ -38,7 +35,7 @@ class PackedInputStream: public kj::InputStream { public: explicit PackedInputStream(kj::BufferedInputStream& inner); - KJ_DISALLOW_COPY(PackedInputStream); + KJ_DISALLOW_COPY_AND_MOVE(PackedInputStream); ~PackedInputStream() noexcept(false); // implements InputStream ------------------------------------------ @@ -50,9 +47,10 @@ class PackedInputStream: public kj::InputStream { }; class PackedOutputStream: public kj::OutputStream { + // An output stream that packs data. Buffers passed to `write()` must be word-aligned. public: explicit PackedOutputStream(kj::BufferedOutputStream& inner); - KJ_DISALLOW_COPY(PackedOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(PackedOutputStream); ~PackedOutputStream() noexcept(false); // implements OutputStream ----------------------------------------- @@ -68,7 +66,7 @@ class PackedMessageReader: private _::PackedInputStream, public InputStreamMessa public: PackedMessageReader(kj::BufferedInputStream& inputStream, ReaderOptions options = ReaderOptions(), kj::ArrayPtr scratchSpace = nullptr); - KJ_DISALLOW_COPY(PackedMessageReader); + KJ_DISALLOW_COPY_AND_MOVE(PackedMessageReader); ~PackedMessageReader() noexcept(false); }; @@ -85,7 +83,7 @@ class PackedFdMessageReader: private kj::FdInputStream, private kj::BufferedInpu kj::ArrayPtr scratchSpace = nullptr); // Read a message from a file descriptor, taking ownership of the descriptor. - KJ_DISALLOW_COPY(PackedFdMessageReader); + KJ_DISALLOW_COPY_AND_MOVE(PackedFdMessageReader); ~PackedFdMessageReader() noexcept(false); }; @@ -127,4 +125,4 @@ inline void writePackedMessageToFd(int fd, MessageBuilder& builder) { } // namespace capnp -#endif // CAPNP_SERIALIZE_PACKED_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/serialize-test.c++ b/c++/src/capnp/serialize-test.c++ index 7643319435..d114358abd 100644 --- a/c++/src/capnp/serialize-test.c++ +++ b/c++/src/capnp/serialize-test.c++ @@ -19,6 +19,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "serialize.h" #include #include diff --git a/c++/src/capnp/serialize-text-test.c++ b/c++/src/capnp/serialize-text-test.c++ index 5ee12a3cbb..8ac4285844 100644 --- a/c++/src/capnp/serialize-text-test.c++ +++ b/c++/src/capnp/serialize-text-test.c++ @@ -32,7 +32,7 @@ namespace capnp { namespace _ { // private namespace { -KJ_TEST("TestAllTypes") { +KJ_TEST("TextCodec TestAllTypes") { MallocMessageBuilder builder; initTestMessage(builder.initRoot()); @@ -66,7 +66,7 @@ KJ_TEST("TestAllTypes") { } } -KJ_TEST("TestDefaults") { +KJ_TEST("TextCodec TestDefaults") { MallocMessageBuilder builder; initTestMessage(builder.initRoot()); @@ -79,7 +79,7 @@ KJ_TEST("TestDefaults") { checkTestMessage(structReader); } -KJ_TEST("TestListDefaults") { +KJ_TEST("TextCodec TestListDefaults") { MallocMessageBuilder builder; initTestMessage(builder.initRoot()); @@ -92,7 +92,7 @@ KJ_TEST("TestListDefaults") { checkTestMessage(structReader); } -KJ_TEST("raw text") { +KJ_TEST("TextCodec raw text") { using TestType = capnproto_test::capnp::test::TestLateUnion; kj::String message = @@ -126,6 +126,66 @@ KJ_TEST("raw text") { KJ_EXPECT(reader.getAnotherUnion().getCorge()[2] == 9); } +KJ_TEST("TextCodec parse error") { + auto message = "\n (,)"_kj; + + MallocMessageBuilder builder; + auto root = builder.initRoot(); + + TextCodec codec; + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions( + [&]() { codec.decode(message, root); })); + + KJ_EXPECT(exception.getFile() == "(capnp text input)"_kj); + KJ_EXPECT(exception.getLine() == 2); + KJ_EXPECT(exception.getDescription() == "3-6: Parse error: Empty list item.", + exception.getDescription()); +} + +KJ_TEST("text format implicitly coerces struct value from first field type") { + // We don't actually use TextCodec here, but rather check how the compiler handled some constants + // defined in test.capnp. It's the same parser code either way but this is easier. + + { + auto s = test::TestImpliedFirstField::Reader().getTextStruct(); + KJ_EXPECT(s.getText() == "foo"); + KJ_EXPECT(s.getI() == 321); + } + + { + auto s = test::TEST_IMPLIED_FIRST_FIELD->getTextStruct(); + KJ_EXPECT(s.getText() == "bar"); + KJ_EXPECT(s.getI() == 321); + } + +#if __GNUC__ && !__clang__ +// GCC generates a spurious warning here... +#pragma GCC diagnostic ignored "-Wmisleading-indentation" +#endif + + { + auto l = test::TEST_IMPLIED_FIRST_FIELD->getTextStructList(); + KJ_ASSERT(l.size() == 2); + + { + auto s = l[0]; + KJ_EXPECT(s.getText() == "baz"); + KJ_EXPECT(s.getI() == 321); + } + { + auto s = l[1]; + KJ_EXPECT(s.getText() == "qux"); + KJ_EXPECT(s.getI() == 123); + } + } + + { + auto s = test::TEST_IMPLIED_FIRST_FIELD->getIntGroup(); + KJ_EXPECT(s.getI() == 123); + KJ_EXPECT(s.getStr() == "corge"); + } +} + } // namespace } // namespace _ (private) } // namespace capnp diff --git a/c++/src/capnp/serialize-text.c++ b/c++/src/capnp/serialize-text.c++ index 738005f258..4583e5dfe1 100644 --- a/c++/src/capnp/serialize-text.c++ +++ b/c++/src/capnp/serialize-text.c++ @@ -29,16 +29,36 @@ #include "compiler/node-translator.h" #include "compiler/parser.h" +namespace capnp { + namespace { class ThrowingErrorReporter final: public capnp::compiler::ErrorReporter { // Throws all errors as assertion failures. public: + ThrowingErrorReporter(kj::StringPtr input): input(input) {} + void addError(uint32_t startByte, uint32_t endByte, kj::StringPtr message) override { - KJ_FAIL_REQUIRE(kj::str(message, " (", startByte, ":", endByte, ").")); + // Note: Line and column numbers are usually 1-based. + uint line = 1; + uint32_t lineStart = 0; + for (auto i: kj::zeroTo(startByte)) { + if (input[i] == '\n') { + ++line; + lineStart = i; // Omit +1 so that column is 1-based. + } + } + + kj::throwRecoverableException(kj::Exception( + kj::Exception::Type::FAILED, "(capnp text input)", line, + kj::str(startByte - lineStart, "-", endByte - lineStart, ": ", message) + )); } bool hadErrors() override { return false; } + +private: + kj::StringPtr input; }; class ExternalResolver final: public capnp::compiler::ValueTranslator::Resolver { @@ -59,7 +79,7 @@ template void lexAndParseExpression(kj::StringPtr input, Function f) { // Parses a single expression from the input and calls `f(expression)`. - ThrowingErrorReporter errorReporter; + ThrowingErrorReporter errorReporter(input); capnp::MallocMessageBuilder tokenArena; auto lexedTokens = tokenArena.initRoot(); @@ -90,8 +110,6 @@ void lexAndParseExpression(kj::StringPtr input, Function f) { } // namespace -namespace capnp { - TextCodec::TextCodec() : prettyPrint(false) {} TextCodec::~TextCodec() noexcept(true) {} @@ -112,10 +130,10 @@ kj::String TextCodec::encode(DynamicValue::Reader value) const { } void TextCodec::decode(kj::StringPtr input, DynamicStruct::Builder output) const { - lexAndParseExpression(input, [&output](compiler::Expression::Reader expression) { - KJ_REQUIRE(expression.isTuple(), "Input does not contain a struct."); + lexAndParseExpression(input, [&](compiler::Expression::Reader expression) { + KJ_REQUIRE(expression.isTuple(), "Input does not contain a struct.") { return; } - ThrowingErrorReporter errorReporter; + ThrowingErrorReporter errorReporter(input); ExternalResolver nullResolver; Orphanage orphanage = Orphanage::getForMessageContaining(output); @@ -126,9 +144,9 @@ void TextCodec::decode(kj::StringPtr input, DynamicStruct::Builder output) const Orphan TextCodec::decode(kj::StringPtr input, Type type, Orphanage orphanage) const { Orphan output; - - lexAndParseExpression(input, [&type, &orphanage, &output](compiler::Expression::Reader expression) { - ThrowingErrorReporter errorReporter; + + lexAndParseExpression(input, [&](compiler::Expression::Reader expression) { + ThrowingErrorReporter errorReporter(input); ExternalResolver nullResolver; compiler::ValueTranslator translator(nullResolver, errorReporter, orphanage); @@ -138,7 +156,7 @@ Orphan TextCodec::decode(kj::StringPtr input, Type type, Orphanage // An error should have already been given to the errorReporter. } }); - + return output; } diff --git a/c++/src/capnp/serialize-text.h b/c++/src/capnp/serialize-text.h index d86fc2c00e..8acd9be844 100644 --- a/c++/src/capnp/serialize-text.h +++ b/c++/src/capnp/serialize-text.h @@ -19,18 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_SERIALIZE_TEXT_H_ -#define CAPNP_SERIALIZE_TEXT_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include "dynamic.h" #include "orphan.h" #include "schema.h" +CAPNP_BEGIN_HEADER + namespace capnp { class TextCodec { @@ -93,4 +90,4 @@ inline Orphan TextCodec::decode(kj::StringPtr input, Orphanage orphanage) con } // namespace capnp -#endif // CAPNP_SERIALIZE_TEXT_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/serialize.c++ b/c++/src/capnp/serialize.c++ index df7e45e030..abb34f7998 100644 --- a/c++/src/capnp/serialize.c++ +++ b/c++/src/capnp/serialize.c++ @@ -23,6 +23,10 @@ #include "layout.h" #include #include +#ifdef _WIN32 +#include +#include +#endif namespace capnp { @@ -301,6 +305,15 @@ void writeMessage(kj::OutputStream& output, kj::ArrayPtr> segments) { +#ifdef _WIN32 + auto oldMode = _setmode(fd, _O_BINARY); + if (oldMode != _O_BINARY) { + _setmode(fd, oldMode); + KJ_FAIL_REQUIRE("Tried to write a message to a file descriptor that is in text mode. Set the " + "file descriptor to binary mode by calling the _setmode Windows CRT function, or passing " + "_O_BINARY to _open()."); + } +#endif kj::FdOutputStream stream(fd); writeMessage(stream, segments); } diff --git a/c++/src/capnp/serialize.h b/c++/src/capnp/serialize.h index 797db51766..b79dae935f 100644 --- a/c++/src/capnp/serialize.h +++ b/c++/src/capnp/serialize.h @@ -38,16 +38,13 @@ // - A multi-segment message can be read entirely in three system calls with no buffering. // - The format is appropriate for mmap()ing since all data is aligned. -#ifndef CAPNP_SERIALIZE_H_ -#define CAPNP_SERIALIZE_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include "message.h" #include +CAPNP_BEGIN_HEADER + namespace capnp { class FlatArrayMessageReader: public MessageReader { @@ -172,7 +169,7 @@ void writeMessage(kj::OutputStream& output, kj::ArrayPtr stream; + # + # Is equivalent to: + # + # write @0 (bytes :Data) -> import "/capnp/stream.capnp".StreamResult; + # + # However, implementations that recognize streaming will elide the reference to StreamResult + # and instead give write() a different signature appropriate for streaming. + # + # Streaming methods do not return a result -- that is, they return Promise. This promise + # resolves not to indicate that the call was actually delivered, but instead to provide + # backpressure. When the previous call's promise resolves, it is time to make another call. On + # the client side, the RPC system will resolve promises immediately until an appropriate number + # of requests are in-flight, and then will delay promise resolution to apply back-pressure. + # On the server side, the RPC system will deliver one call at a time. +} diff --git a/c++/src/capnp/stream.capnp.c++ b/c++/src/capnp/stream.capnp.c++ new file mode 100644 index 0000000000..098f26a5b8 --- /dev/null +++ b/c++/src/capnp/stream.capnp.c++ @@ -0,0 +1,55 @@ +// Generated by Cap'n Proto compiler, DO NOT EDIT +// source: stream.capnp + +#include "stream.capnp.h" + +namespace capnp { +namespace schemas { +static const ::capnp::_::AlignedData<17> b_995f9a3377c0b16e = { + { 0, 0, 0, 0, 5, 0, 6, 0, + 110, 177, 192, 119, 51, 154, 95, 153, + 19, 0, 0, 0, 1, 0, 0, 0, + 248, 243, 147, 19, 169, 102, 195, 134, + 0, 0, 7, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 21, 0, 0, 0, 2, 1, 0, 0, + 33, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 99, 97, 112, 110, 112, 47, 115, 116, + 114, 101, 97, 109, 46, 99, 97, 112, + 110, 112, 58, 83, 116, 114, 101, 97, + 109, 82, 101, 115, 117, 108, 116, 0, + 0, 0, 0, 0, 1, 0, 1, 0, } +}; +::capnp::word const* const bp_995f9a3377c0b16e = b_995f9a3377c0b16e.words; +#if !CAPNP_LITE +const ::capnp::_::RawSchema s_995f9a3377c0b16e = { + 0x995f9a3377c0b16e, b_995f9a3377c0b16e.words, 17, nullptr, nullptr, + 0, 0, nullptr, nullptr, nullptr, { &s_995f9a3377c0b16e, nullptr, nullptr, 0, 0, nullptr }, false +}; +#endif // !CAPNP_LITE +} // namespace schemas +} // namespace capnp + +// ======================================================================================= + +namespace capnp { + +// StreamResult +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr uint16_t StreamResult::_capnpPrivate::dataWordSize; +constexpr uint16_t StreamResult::_capnpPrivate::pointerCount; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#if !CAPNP_LITE +#if CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +constexpr ::capnp::Kind StreamResult::_capnpPrivate::kind; +constexpr ::capnp::_::RawSchema const* StreamResult::_capnpPrivate::schema; +#endif // !CAPNP_NEED_REDUNDANT_CONSTEXPR_DECL +#endif // !CAPNP_LITE + + +} // namespace + diff --git a/c++/src/capnp/stream.capnp.h b/c++/src/capnp/stream.capnp.h new file mode 100644 index 0000000000..b3ddaf61cc --- /dev/null +++ b/c++/src/capnp/stream.capnp.h @@ -0,0 +1,121 @@ +// Generated by Cap'n Proto compiler, DO NOT EDIT +// source: stream.capnp + +#pragma once + +#include +#include + +#ifndef CAPNP_VERSION +#error "CAPNP_VERSION is not defined, is capnp/generated-header-support.h missing?" +#elif CAPNP_VERSION != 1001000 +#error "Version mismatch between generated code and library headers. You must use the same version of the Cap'n Proto compiler and library." +#endif + + +CAPNP_BEGIN_HEADER + +namespace capnp { +namespace schemas { + +CAPNP_DECLARE_SCHEMA(995f9a3377c0b16e); + +} // namespace schemas +} // namespace capnp + +namespace capnp { + +struct StreamResult { + StreamResult() = delete; + + class Reader; + class Builder; + class Pipeline; + + struct _capnpPrivate { + CAPNP_DECLARE_STRUCT_HEADER(995f9a3377c0b16e, 0, 0) + #if !CAPNP_LITE + static constexpr ::capnp::_::RawBrandedSchema const* brand() { return &schema->defaultBrand; } + #endif // !CAPNP_LITE + }; +}; + +// ======================================================================================= + +class StreamResult::Reader { +public: + typedef StreamResult Reads; + + Reader() = default; + inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} + + inline ::capnp::MessageSize totalSize() const { + return _reader.totalSize().asPublic(); + } + +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { + return ::capnp::_::structString(_reader, *_capnpPrivate::brand()); + } +#endif // !CAPNP_LITE + +private: + ::capnp::_::StructReader _reader; + template + friend struct ::capnp::ToDynamic_; + template + friend struct ::capnp::_::PointerHelpers; + template + friend struct ::capnp::List; + friend class ::capnp::MessageBuilder; + friend class ::capnp::Orphanage; +}; + +class StreamResult::Builder { +public: + typedef StreamResult Builds; + + Builder() = delete; // Deleted to discourage incorrect usage. + // You can explicitly initialize to nullptr instead. + inline Builder(decltype(nullptr)) {} + inline explicit Builder(::capnp::_::StructBuilder base): _builder(base) {} + inline operator Reader() const { return Reader(_builder.asReader()); } + inline Reader asReader() const { return *this; } + + inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); } +#if !CAPNP_LITE + inline ::kj::StringTree toString() const { return asReader().toString(); } +#endif // !CAPNP_LITE + +private: + ::capnp::_::StructBuilder _builder; + template + friend struct ::capnp::ToDynamic_; + friend class ::capnp::Orphanage; + template + friend struct ::capnp::_::PointerHelpers; +}; + +#if !CAPNP_LITE +class StreamResult::Pipeline { +public: + typedef StreamResult Pipelines; + + inline Pipeline(decltype(nullptr)): _typeless(nullptr) {} + inline explicit Pipeline(::capnp::AnyPointer::Pipeline&& typeless) + : _typeless(kj::mv(typeless)) {} + +private: + ::capnp::AnyPointer::Pipeline _typeless; + friend class ::capnp::PipelineHook; + template + friend struct ::capnp::ToDynamic_; +}; +#endif // !CAPNP_LITE + +// ======================================================================================= + +} // namespace + +CAPNP_END_HEADER + diff --git a/c++/src/capnp/stringify.c++ b/c++/src/capnp/stringify.c++ index 322810734e..e17d47faf5 100644 --- a/c++/src/capnp/stringify.c++ +++ b/c++/src/capnp/stringify.c++ @@ -22,13 +22,12 @@ #include "dynamic.h" #include #include +#include namespace capnp { namespace { -static const char HEXDIGITS[] = "0123456789abcdef"; - enum PrintMode { BARE, // The value is planned to be printed on its own line, unless it is very short and contains @@ -140,44 +139,14 @@ static kj::StringTree print(const DynamicValue::Reader& value, } else { return kj::strTree(value.as()); } - case DynamicValue::TEXT: + case DynamicValue::TEXT: { + kj::ArrayPtr chars = value.as(); + return kj::strTree('"', kj::encodeCEscape(chars), '"'); + } case DynamicValue::DATA: { // TODO(someday): Maybe data should be printed as binary literal. - kj::ArrayPtr chars; - if (value.getType() == DynamicValue::DATA) { - chars = value.as().asChars(); - } else { - chars = value.as(); - } - - kj::Vector escaped(chars.size()); - - for (char c: chars) { - switch (c) { - case '\a': escaped.addAll(kj::StringPtr("\\a")); break; - case '\b': escaped.addAll(kj::StringPtr("\\b")); break; - case '\f': escaped.addAll(kj::StringPtr("\\f")); break; - case '\n': escaped.addAll(kj::StringPtr("\\n")); break; - case '\r': escaped.addAll(kj::StringPtr("\\r")); break; - case '\t': escaped.addAll(kj::StringPtr("\\t")); break; - case '\v': escaped.addAll(kj::StringPtr("\\v")); break; - case '\'': escaped.addAll(kj::StringPtr("\\\'")); break; - case '\"': escaped.addAll(kj::StringPtr("\\\"")); break; - case '\\': escaped.addAll(kj::StringPtr("\\\\")); break; - default: - if (c < 0x20) { - escaped.add('\\'); - escaped.add('x'); - uint8_t c2 = c; - escaped.add(HEXDIGITS[c2 / 16]); - escaped.add(HEXDIGITS[c2 % 16]); - } else { - escaped.add(c); - } - break; - } - } - return kj::strTree('"', escaped, '"'); + kj::ArrayPtr bytes = value.as().asBytes(); + return kj::strTree('"', kj::encodeCEscape(bytes), '"'); } case DynamicValue::LIST: { auto listValue = value.as(); diff --git a/c++/src/capnp/test-util.c++ b/c++/src/capnp/test-util.c++ index c4a2f8e4f7..41fc08e072 100644 --- a/c++/src/capnp/test-util.c++ +++ b/c++/src/capnp/test-util.c++ @@ -19,9 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "test-util.h" #include #include +#include +#include namespace capnp { namespace _ { // private @@ -975,6 +981,14 @@ kj::Promise TestPipelineImpl::getAnyCap(GetAnyCapContext context) { }); } +kj::Promise TestPipelineImpl::getCapPipelineOnly(GetCapPipelineOnlyContext context) { + ++callCount; + PipelineBuilder pb; + pb.initOutBox().setCap(kj::heap(callCount)); + context.setPipeline(pb.build()); + return kj::NEVER_DONE; +} + kj::Promise TestCallOrderImpl::getCallSequence(GetCallSequenceContext context) { auto result = context.getResults(); result.setN(count++); @@ -1062,7 +1076,6 @@ kj::Promise TestMoreStuffImpl::neverReturn(NeverReturnContext context) { // Also attach `cap` to the result struct to make sure that is released. context.getResults().setCapCopy(context.getParams().getCap()); - context.allowCancellation(); return kj::mv(promise); } @@ -1105,7 +1118,6 @@ kj::Promise TestMoreStuffImpl::echo(EchoContext context) { kj::Promise TestMoreStuffImpl::expectCancel(ExpectCancelContext context) { auto cap = context.getParams().getCap(); - context.allowCancellation(); return loop(0, cap, context); } @@ -1115,7 +1127,7 @@ kj::Promise TestMoreStuffImpl::loop(uint depth, test::TestInterface::Clien ADD_FAILURE() << "Looped too long, giving up."; return kj::READY_NOW; } else { - return kj::evalLater([this,depth,KJ_CPCAP(cap),KJ_CPCAP(context)]() mutable { + return kj::evalLast([this,depth,KJ_CPCAP(cap),KJ_CPCAP(context)]() mutable { return loop(depth + 1, cap, context); }); } @@ -1144,6 +1156,42 @@ kj::Promise TestMoreStuffImpl::getEnormousString(GetEnormousStringContext return kj::READY_NOW; } +kj::Promise TestMoreStuffImpl::writeToFd(WriteToFdContext context) { + auto params = context.getParams(); + + auto promises = kj::heapArrayBuilder>(2); + + promises.add(params.getFdCap1().getFd() + .then([](kj::Maybe fd) { + kj::FdOutputStream(KJ_ASSERT_NONNULL(fd)).write("foo", 3); + })); + promises.add(params.getFdCap2().getFd() + .then([context](kj::Maybe fd) mutable { + context.getResults().setSecondFdPresent(fd != nullptr); + KJ_IF_MAYBE(f, fd) { + kj::FdOutputStream(*f).write("bar", 3); + } + })); + + int pair[2]; + KJ_SYSCALL(kj::miniposix::pipe(pair)); + kj::AutoCloseFd in(pair[0]); + kj::AutoCloseFd out(pair[1]); + + kj::FdOutputStream(kj::mv(out)).write("baz", 3); + context.getResults().setFdCap3(kj::heap(kj::mv(in))); + + return kj::joinPromises(promises.finish()); +} + +kj::Promise TestMoreStuffImpl::throwException(ThrowExceptionContext context) { + return KJ_EXCEPTION(FAILED, "test exception"); +} + +kj::Promise TestMoreStuffImpl::throwRemoteException(ThrowRemoteExceptionContext context) { + return KJ_EXCEPTION(FAILED, "remote exception: test exception"); +} + #endif // !CAPNP_LITE } // namespace _ (private) diff --git a/c++/src/capnp/test-util.h b/c++/src/capnp/test-util.h index 64fcaa5036..38b445718a 100644 --- a/c++/src/capnp/test-util.h +++ b/c++/src/capnp/test-util.h @@ -19,12 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef CAPNP_TEST_UTIL_H_ -#define CAPNP_TEST_UTIL_H_ - -#if defined(__GNUC__) && !defined(CAPNP_HEADER_WARNINGS) -#pragma GCC system_header -#endif +#pragma once #include #include @@ -33,8 +28,11 @@ #if !CAPNP_LITE #include "dynamic.h" +#include #endif // !CAPNP_LITE +CAPNP_BEGIN_HEADER + // TODO(cleanup): Auto-generate stringification functions for union discriminants. namespace capnproto_test { namespace capnp { @@ -214,6 +212,7 @@ class TestPipelineImpl final: public test::TestPipeline::Server { kj::Promise getCap(GetCapContext context) override; kj::Promise getAnyCap(GetAnyCapContext context) override; + kj::Promise getCapPipelineOnly(GetCapPipelineOnlyContext context) override; private: int& callCount; @@ -223,6 +222,8 @@ class TestCallOrderImpl final: public test::TestCallOrder::Server { public: kj::Promise getCallSequence(GetCallSequenceContext context) override; + uint getCount() { return count; } + private: uint count = 0; }; @@ -275,6 +276,12 @@ class TestMoreStuffImpl final: public test::TestMoreStuff::Server { kj::Promise getEnormousString(GetEnormousStringContext context) override; + kj::Promise writeToFd(WriteToFdContext context) override; + + kj::Promise throwException(ThrowExceptionContext context) override; + + kj::Promise throwRemoteException(ThrowRemoteExceptionContext context) override; + private: int& callCount; int& handleCount; @@ -304,9 +311,56 @@ class TestCapDestructor final: public test::TestInterface::Server { TestInterfaceImpl impl; }; +class TestFdCap final: public test::TestInterface::Server { + // Implementation of TestInterface that wraps a file descriptor. + +public: + TestFdCap(kj::AutoCloseFd fd): fd(kj::mv(fd)) {} + + kj::Maybe getFd() override { return fd.get(); } + +private: + kj::AutoCloseFd fd; +}; + +class TestStreamingImpl final: public test::TestStreaming::Server { +public: + uint iSum = 0; + uint jSum = 0; + kj::Maybe>> fulfiller; + bool jShouldThrow = false; + + kj::Promise doStreamI(DoStreamIContext context) override { + iSum += context.getParams().getI(); + auto paf = kj::newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + kj::Promise doStreamJ(DoStreamJContext context) override { + jSum += context.getParams().getJ(); + + if (jShouldThrow) { + KJ_FAIL_ASSERT("throw requested") { break; } + return kj::READY_NOW; + } + + auto paf = kj::newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + kj::Promise finishStream(FinishStreamContext context) override { + auto results = context.getResults(); + results.setTotalI(iSum); + results.setTotalJ(jSum); + return kj::READY_NOW; + } +}; + #endif // !CAPNP_LITE } // namespace _ (private) } // namespace capnp -#endif // TEST_UTIL_H_ +CAPNP_END_HEADER diff --git a/c++/src/capnp/test.capnp b/c++/src/capnp/test.capnp index a077c6c1de..11b52f6c31 100644 --- a/c++/src/capnp/test.capnp +++ b/c++/src/capnp/test.capnp @@ -127,7 +127,8 @@ struct TestDefaults { textList = ["quux", "corge", "grault"], dataList = ["garply", "waldo", "fred"], structList = [ - (textField = "x structlist 1"), + (textField = "x " "structlist" + " 1"), (textField = "x structlist 2"), (textField = "x structlist 3")], enumList = [qux, bar, grault] @@ -537,6 +538,9 @@ struct TestGenerics(Foo, Bar) { } } + list @4 :List(Inner); + # At one time this failed to compile with MSVC due to poor expression SFINAE support. + struct Inner { foo @0 :Foo; bar @1 :Bar; @@ -583,6 +587,9 @@ struct TestGenerics(Foo, Bar) { } } +struct BoxedText { text @0 :Text; } +using BrandedAlias = TestGenerics(BoxedText, Text); + struct TestGenericsWrapper(Foo, Bar) { value @0 :TestGenerics(Foo, Bar); } @@ -646,6 +653,8 @@ struct TestUseGenerics $TestGenerics(Text, Data).ann("foo") { inner2Bind = (baz = "text", innerBound = (foo = (int16Field = 123))), inner2Text = (baz = "text", innerBound = (foo = (int16Field = 123))), revFoo = [12, 34, 56]); + + bindEnumList @20 :TestGenerics(List(TestEnum), Text); } struct TestEmptyStruct {} @@ -701,7 +710,8 @@ struct TestConstants { textList = ["quux", "corge", "grault"], dataList = ["garply", "waldo", "fred"], structList = [ - (textField = "x structlist 1"), + (textField = "x " "structlist" + " 1"), (textField = "x structlist 2"), (textField = "x structlist 3")], enumList = [qux, bar, grault] @@ -750,12 +760,18 @@ const embeddedStruct :TestAllTypes = embed "testdata/binary"; const nonAsciiText :Text = "♫ é ✓"; +const blockText :Text = + `foo bar baz + `"qux" `corge` 'grault' + "regular\"quoted\"line" + `garply\nwaldo\tfred\"plugh\"xyzzy\'thud + ; + struct TestAnyPointerConstants { anyKindAsStruct @0 :AnyPointer; anyStructAsStruct @1 :AnyStruct; anyKindAsList @2 :AnyPointer; anyListAsList @3 :AnyList; - } const anyPointerConstants :TestAnyPointerConstants = ( @@ -765,6 +781,11 @@ const anyPointerConstants :TestAnyPointerConstants = ( anyListAsList = TestConstants.int32ListConst, ); +struct TestListOfAny { + capList @0 :List(Capability); + #listList @1 :List(AnyList); # TODO(someday): Make List(AnyList) work correctly in C++ generated code. +} + interface TestInterface { foo @0 (i :UInt32, j :Bool) -> (x :Text); bar @1 () -> (); @@ -784,6 +805,9 @@ interface TestPipeline { testPointers @1 (cap :TestInterface, obj :AnyPointer, list :List(TestInterface)) -> (); getAnyCap @2 (n: UInt32, inCap :Capability) -> (s: Text, outBox :AnyBox); + getCapPipelineOnly @3 () -> (outBox :Box); + # Never returns, but uses setPipeline() to make the pipeline work. + struct Box { cap @0 :TestInterface; } @@ -799,7 +823,7 @@ interface TestCallOrder { # The input `expected` is ignored but useful for disambiguating debug logs. } -interface TestTailCallee { +interface TestTailCallee $Cxx.allowCancellation { struct TailResult { i @0 :UInt32; t @1 :Text; @@ -813,6 +837,13 @@ interface TestTailCaller { foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult; } +interface TestStreaming $Cxx.allowCancellation { + doStreamI @0 (i :UInt32) -> stream; + doStreamJ @1 (j :UInt32) -> stream; + finishStream @2 () -> (totalI :UInt32, totalJ :UInt32); + # Test streaming. finishStream() returns the totals of the values streamed to the other calls. +} + interface TestHandle {} interface TestMoreStuff extends(TestCallOrder) { @@ -824,7 +855,7 @@ interface TestMoreStuff extends(TestCallOrder) { callFooWhenResolved @1 (cap :TestInterface) -> (s: Text); # Like callFoo but waits for `cap` to resolve first. - neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface); + neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface) $Cxx.allowCancellation; # Doesn't return. You should cancel it. hold @3 (cap :TestInterface) -> (); @@ -839,7 +870,7 @@ interface TestMoreStuff extends(TestCallOrder) { echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder); # Just returns the input cap. - expectCancel @7 (cap :TestInterface) -> (); + expectCancel @7 (cap :TestInterface) -> () $Cxx.allowCancellation; # evalLater()-loops forever, holding `cap`. Must be canceled. methodWithDefaults @8 (a :Text, b :UInt32 = 123, c :Text = "foo") -> (d :Text, e :Text = "bar"); @@ -855,6 +886,14 @@ interface TestMoreStuff extends(TestCallOrder) { getEnormousString @11 () -> (str :Text); # Attempts to return an 100MB string. Should always fail. + + writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface) + -> (fdCap3 :TestInterface, secondFdPresent :Bool); + # Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to + # the second. Also creates a socketpair, writes "baz" to one end, and returns the other end. + + throwException @14 (); + throwRemoteException @15 (); } interface TestMembrane { @@ -863,6 +902,8 @@ interface TestMembrane { callIntercept @2 (thing :Thing, tailCall :Bool) -> Result; loopback @3 (thing :Thing) -> (thing :Thing); + waitForever @4 () $Cxx.allowCancellation; + interface Thing { passThrough @0 () -> Result; intercept @1 () -> Result; @@ -960,3 +1001,42 @@ struct TestNameAnnotation $Cxx.name("RenamedStruct") { interface TestNameAnnotationInterface $Cxx.name("RenamedInterface") { badlyNamedMethod @0 (badlyNamedParam :UInt8 $Cxx.name("renamedParam")) $Cxx.name("renamedMethod"); } + +struct TestImpliedFirstField { + struct TextStruct { + text @0 :Text; + i @1 :UInt32 = 321; + } + + textStruct @0 :TextStruct = "foo"; + textStructList @1 :List(TextStruct); + + intGroup :group { + i @2 :UInt32; + str @3 :Text = "corge"; + } +} + +const testImpliedFirstField :TestImpliedFirstField = ( + textStruct = "bar", + textStructList = ["baz", (text = "qux", i = 123)], + intGroup = 123 +); + +struct TestCycleANoCaps { + foo @0 :TestCycleBNoCaps; +} + +struct TestCycleBNoCaps { + foo @0 :List(TestCycleANoCaps); + bar @1 :TestAllTypes; +} + +struct TestCycleAWithCaps { + foo @0 :TestCycleBWithCaps; +} + +struct TestCycleBWithCaps { + foo @0 :List(TestCycleAWithCaps); + bar @1 :TestInterface; +} diff --git a/c++/src/capnp/testdata/annotated-json.binary b/c++/src/capnp/testdata/annotated-json.binary new file mode 100644 index 0000000000..6c54755184 Binary files /dev/null and b/c++/src/capnp/testdata/annotated-json.binary differ diff --git a/c++/src/capnp/testdata/annotated.json b/c++/src/capnp/testdata/annotated.json new file mode 100644 index 0000000000..bb514001ec --- /dev/null +++ b/c++/src/capnp/testdata/annotated.json @@ -0,0 +1,22 @@ +{ "names-can_contain!anything Really": "foo", + "flatFoo": 123, + "flatBar": "abc", + "renamed-flatBaz": {"hello": true}, + "flatQux": "cba", + "pfx.foo": "this is a long string in order to force multi-line pretty printing", + "pfx.renamed-bar": 321, + "pfx.baz": {"hello": true}, + "pfx.xfp.qux": "fed", + "union-type": "renamed-bar", + "barMember": 789, + "multiMember": "ghi", + "dependency": {"renamed-foo": "corge"}, + "simpleGroup": {"renamed-grault": "garply"}, + "enums": ["qux", "renamed-bar", "foo", "renamed-baz"], + "innerJson": [123, "hello", {"object": true}], + "testBase64": "ZnJlZA==", + "testHex": "706c756768", + "bUnion": "renamed-bar", + "bValue": 678, + "externalUnion": {"type": "bar", "value": "cba"}, + "unionWithVoid": {"type": "voidValue"} } diff --git a/c++/src/capnp/testdata/errors.capnp.nobuild b/c++/src/capnp/testdata/errors.capnp.nobuild index a909e970a0..9cd3beb541 100644 --- a/c++/src/capnp/testdata/errors.capnp.nobuild +++ b/c++/src/capnp/testdata/errors.capnp.nobuild @@ -97,6 +97,7 @@ struct Foo { listWithoutParam @31 :List; listWithTooManyParams @32 :List(Int32, Int64); listAnyPointer @33 :List(AnyPointer); + listAnyStruct @48 :List(AnyStruct); notAType @34 :notType; noParams @35 :Foo(Int32); @@ -141,6 +142,7 @@ enum DupEnumerants { const recursive: UInt32 = .recursive; struct Generic(T, U) { + foo @0 :UInt32 $T; } struct UseGeneric { @@ -158,4 +160,5 @@ using Baz = import "nosuchfile-unused.capnp".Baz; interface TestInterface { foo @0 (a :UInt32 = null); + bar @1 stream -> (); } diff --git a/c++/src/capnp/testdata/errors.txt b/c++/src/capnp/testdata/errors.txt index ed238e482c..a455cacab7 100644 --- a/c++/src/capnp/testdata/errors.txt +++ b/c++/src/capnp/testdata/errors.txt @@ -2,7 +2,7 @@ file:74:30-32: error: As of Cap'n Proto v0.3, it is no longer necessary to assig file:74:30-32: error: As of Cap'n Proto v0.3, the 'union' keyword should be prefixed with a colon for named unions, e.g. `foo :union {`. file:79:23-25: error: As of Cap'n Proto v0.3, it is no longer necessary to assign numbers to unions. However, removing the number will break binary compatibility. If this is an old protocol and you need to retain compatibility, please add an exclamation point after the number to indicate that it is really needed, e.g. `foo @1! :union {`. If this is a new protocol or compatibility doesn't matter, just remove the @n entirely. Sorry for the inconvenience, and thanks for being an early adopter! :) file:84:17-19: error: As of Cap'n Proto v0.3, the 'union' keyword should be prefixed with a colon for named unions, e.g. `foo :union {`. -file:132:7-10: error: 'using' declaration without '=' must specify a named declaration from a different scope. +file:133:7-10: error: 'using' declaration without '=' must specify a named declaration from a different scope. file:37:3-10: error: 'dupName' is already defined in this scope. file:36:3-10: error: 'dupName' previously defined here. file:52:5-12: error: 'dupName' is already defined in this scope. @@ -21,40 +21,44 @@ file:39:15-16: error: Duplicate ordinal number. file:38:15-16: error: Ordinal @2 originally used here. file:41:18-19: error: Skipped ordinal @3. Ordinals must be sequential with no holes. file:69:15-17: error: Union ordinal, if specified, must be greater than no more than one of its member ordinals (i.e. there can only be one field retroactively unionized). -file:116:31-50: error: Import failed: noshuchfile.capnp -file:118:26-32: error: Not defined: NoSuch -file:119:28-34: error: 'Foo' has no member named 'NoSuch' +file:117:31-50: error: Import failed: noshuchfile.capnp +file:119:26-32: error: Not defined: NoSuch +file:120:28-34: error: 'Foo' has no member named 'NoSuch' file:97:25-29: error: 'List' requires exactly one parameter. file:98:30-48: error: Too many generic parameters. file:98:30-34: error: 'List' requires exactly one parameter. file:99:23-39: error: 'List(AnyPointer)' is not supported. -file:100:17-24: error: 'notType' is not a type. -file:101:17-27: error: Declaration does not accept generic parameters. -file:103:34-41: error: Integer value out of range. -file:104:37-38: error: Integer value out of range. -file:105:32-35: error: Type mismatch; expected Text. -file:106:33-38: error: Type mismatch; expected Text. -file:107:33-55: error: Type mismatch; expected Text. -file:108:43-61: error: Integer is too big to be negative. -file:109:35-39: error: '.Foo' does not refer to a constant. -file:110:44-51: error: Constant names must be qualified to avoid confusion. Please replace 'notType' with '.notType', if that's what you intended. -file:117:28-34: error: Not defined: NoSuch -file:112:29-32: error: 'Foo' is not an annotation. -file:113:29-47: error: 'notFieldAnnotation' cannot be applied to this kind of declaration. -file:114:33-48: error: 'fieldAnnotation' requires a value. -file:126:35-46: error: Struct has no field named 'nosuchfield'. -file:127:49-52: error: Type mismatch; expected group. -file:125:52-55: error: Missing field name. -file:136:3-10: error: 'dupName' is already defined in this scope. -file:135:3-10: error: 'dupName' previously defined here. -file:138:15-16: error: Duplicate ordinal number. -file:137:15-16: error: Ordinal @2 originally used here. -file:141:7-16: error: Declaration recursively depends on itself. -file:147:14-27: error: Not enough generic parameters. -file:148:15-47: error: Too many generic parameters. -file:149:18-49: error: Double-application of generic parameters. -file:150:38-43: error: Sorry, only pointer types can be used as generic parameters. -file:153:30-44: error: Embeds can only be used when Text, Data, or a struct is expected. -file:154:37-51: error: Couldn't read file for embed: no-such-file -file:160:23-27: error: Only pointer parameters can declare their default as 'null'. -file:156:20-45: error: Import failed: nosuchfile-unused.capnp +file:101:17-24: error: 'notType' is not a type. +file:102:17-27: error: Declaration does not accept generic parameters. +file:104:34-41: error: Integer value out of range. +file:105:37-38: error: Integer value out of range. +file:106:32-35: error: Type mismatch; expected Text. +file:107:33-38: error: Type mismatch; expected Text. +file:108:33-55: error: Type mismatch; expected Text. +file:109:43-61: error: Integer is too big to be negative. +file:110:35-39: error: '.Foo' does not refer to a constant. +file:111:44-51: error: Constant names must be qualified to avoid confusion. Please replace 'notType' with '.notType', if that's what you intended. +file:118:28-34: error: Not defined: NoSuch +file:100:22-37: error: 'List(AnyStruct)' is not supported. +file:113:29-32: error: 'Foo' is not an annotation. +file:114:29-47: error: 'notFieldAnnotation' cannot be applied to this kind of declaration. +file:115:33-48: error: 'fieldAnnotation' requires a value. +file:127:35-46: error: Struct has no field named 'nosuchfield'. +file:128:49-52: error: Type mismatch; expected group. +file:126:52-55: error: Missing field name. +file:137:3-10: error: 'dupName' is already defined in this scope. +file:136:3-10: error: 'dupName' previously defined here. +file:139:15-16: error: Duplicate ordinal number. +file:138:15-16: error: Ordinal @2 originally used here. +file:142:7-16: error: Declaration recursively depends on itself. +file:145:19-20: error: 'T' is not an annotation. +file:149:14-27: error: Not enough generic parameters. +file:150:15-47: error: Too many generic parameters. +file:151:18-49: error: Double-application of generic parameters. +file:152:38-43: error: Sorry, only pointer types can be used as generic parameters. +file:155:30-44: error: Embeds can only be used when Text, Data, or a struct is expected. +file:156:37-51: error: Couldn't read file for embed: no-such-file +file:162:23-27: error: Only pointer parameters can declare their default as 'null'. +file:163:10-16: error: 'stream' can only appear after '->', not before. +file:163:10-16: error: A method declaration uses streaming, but '/capnp/stream.capnp' is not found in the import path. This is a standard file that should always be installed with the Cap'n Proto compiler. +file:158:20-45: error: Import failed: nosuchfile-unused.capnp diff --git a/c++/src/capnp/testdata/errors2.capnp.nobuild b/c++/src/capnp/testdata/errors2.capnp.nobuild new file mode 100644 index 0000000000..e2fbf016b4 --- /dev/null +++ b/c++/src/capnp/testdata/errors2.capnp.nobuild @@ -0,0 +1,37 @@ +# Copyright (c) 2020 Cloudflare, Inc. and contributors +# Licensed under the MIT License: +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +@0xea7dcf0ca9acfa97; +# This is much like errors.capnp.nobuild but expresses errors that occur in a later phase of +# compilation, which is never reached when building errors.capnp.nobuild because the compiler +# bails out after other errors. + +struct DummyType {} +const dummyValue :DummyType = (); + +struct TestDefaultValueForGeneric(A) { + single @0 :A = .dummyValue; + nested @1 :Box(A) = (val = .dummyValue); +} + +struct Box(B) { + val @0 :B; +} diff --git a/c++/src/capnp/testdata/errors2.txt b/c++/src/capnp/testdata/errors2.txt new file mode 100644 index 0000000000..ed65d05e88 --- /dev/null +++ b/c++/src/capnp/testdata/errors2.txt @@ -0,0 +1,2 @@ +file:31:18-29: error: Cannot interpret value because the type is a generic type parameter which is not yet bound. We don't know what type to expect here. +file:32:30-41: error: Cannot interpret value because the type is a generic type parameter which is not yet bound. We don't know what type to expect here. diff --git a/c++/src/capnp/testdata/no-file-id.capnp.nobuild b/c++/src/capnp/testdata/no-file-id.capnp.nobuild new file mode 100644 index 0000000000..98c92e2924 --- /dev/null +++ b/c++/src/capnp/testdata/no-file-id.capnp.nobuild @@ -0,0 +1 @@ +const foo :Text = "bar"; diff --git a/c++/src/capnp/testdata/pretty.json b/c++/src/capnp/testdata/pretty.json new file mode 100644 index 0000000000..abf82d6137 --- /dev/null +++ b/c++/src/capnp/testdata/pretty.json @@ -0,0 +1,88 @@ +{ "voidField": null, + "boolField": true, + "int8Field": -123, + "int16Field": -12345, + "int32Field": -12345678, + "int64Field": "-123456789012345", + "uInt8Field": 234, + "uInt16Field": 45678, + "uInt32Field": 3456789012, + "uInt64Field": "12345678901234567890", + "float32Field": 1234.5, + "float64Field": -1.23e47, + "textField": "foo", + "dataField": [98, 97, 114], + "structField": { + "voidField": null, + "boolField": true, + "int8Field": -12, + "int16Field": 3456, + "int32Field": -78901234, + "int64Field": "56789012345678", + "uInt8Field": 90, + "uInt16Field": 1234, + "uInt32Field": 56789012, + "uInt64Field": "345678901234567890", + "float32Field": -1.2499999646475857e-10, + "float64Field": 345, + "textField": "baz", + "dataField": [113, 117, 120], + "structField": { + "voidField": null, + "boolField": false, + "int8Field": 0, + "int16Field": 0, + "int32Field": 0, + "int64Field": "0", + "uInt8Field": 0, + "uInt16Field": 0, + "uInt32Field": 0, + "uInt64Field": "0", + "float32Field": 0, + "float64Field": 0, + "textField": "nested", + "structField": {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "really nested", "enumField": "foo", "interfaceField": null}, + "enumField": "foo", + "interfaceField": null }, + "enumField": "baz", + "interfaceField": null, + "voidList": [null, null, null], + "boolList": [false, true, false, true, true], + "int8List": [12, -34, -128, 127], + "int16List": [1234, -5678, -32768, 32767], + "int32List": [12345678, -90123456, -2147483648, 2147483647], + "int64List": ["123456789012345", "-678901234567890", "-9223372036854775808", "9223372036854775807"], + "uInt8List": [12, 34, 0, 255], + "uInt16List": [1234, 5678, 0, 65535], + "uInt32List": [12345678, 90123456, 0, 4294967295], + "uInt64List": ["123456789012345", "678901234567890", "0", "18446744073709551615"], + "float32List": [0, 1234567, 9.9999999338158125e36, -9.9999999338158125e36, 9.99999991097579e-38, -9.99999991097579e-38], + "float64List": [0, 123456789012345, 1e306, -1e306, 1e-306, -1e-306], + "textList": ["quux", "corge", "grault"], + "dataList": [[103, 97, 114, 112, 108, 121], [119, 97, 108, 100, 111], [102, 114, 101, 100]], + "structList": [ + {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "x structlist 1", "enumField": "foo", "interfaceField": null}, + {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "x structlist 2", "enumField": "foo", "interfaceField": null}, + {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "x structlist 3", "enumField": "foo", "interfaceField": null} ], + "enumList": ["qux", "bar", "grault"] }, + "enumField": "corge", + "interfaceField": null, + "voidList": [null, null, null, null, null, null], + "boolList": [true, false, false, true], + "int8List": [111, -111], + "int16List": [11111, -11111], + "int32List": [111111111, -111111111], + "int64List": ["1111111111111111111", "-1111111111111111111"], + "uInt8List": [111, 222], + "uInt16List": [33333, 44444], + "uInt32List": [3333333333], + "uInt64List": ["11111111111111111111"], + "float32List": [5555.5, "Infinity", "-Infinity", "NaN"], + "float64List": [7777.75, "Infinity", "-Infinity", "NaN"], + "textList": ["plugh", "xyzzy", "thud"], + "dataList": [[111, 111, 112, 115], [101, 120, 104, 97, 117, 115, 116, 101, 100], [114, 102, 99, 51, 48, 57, 50]], + "structList": [ + {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "structlist 1", "enumField": "foo", "interfaceField": null}, + {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "structlist 2", "enumField": "foo", "interfaceField": null}, + {"voidField": null, "boolField": false, "int8Field": 0, "int16Field": 0, "int32Field": 0, "int64Field": "0", "uInt8Field": 0, "uInt16Field": 0, "uInt32Field": 0, "uInt64Field": "0", "float32Field": 0, "float64Field": 0, "textField": "structlist 3", "enumField": "foo", "interfaceField": null} ], + "enumList": ["foo", "garply"] } diff --git a/c++/src/capnp/testdata/short.json b/c++/src/capnp/testdata/short.json new file mode 100644 index 0000000000..26cbfd0e87 --- /dev/null +++ b/c++/src/capnp/testdata/short.json @@ -0,0 +1 @@ +{"voidField":null,"boolField":true,"int8Field":-123,"int16Field":-12345,"int32Field":-12345678,"int64Field":"-123456789012345","uInt8Field":234,"uInt16Field":45678,"uInt32Field":3456789012,"uInt64Field":"12345678901234567890","float32Field":1234.5,"float64Field":-1.23e47,"textField":"foo","dataField":[98,97,114],"structField":{"voidField":null,"boolField":true,"int8Field":-12,"int16Field":3456,"int32Field":-78901234,"int64Field":"56789012345678","uInt8Field":90,"uInt16Field":1234,"uInt32Field":56789012,"uInt64Field":"345678901234567890","float32Field":-1.2499999646475857e-10,"float64Field":345,"textField":"baz","dataField":[113,117,120],"structField":{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"nested","structField":{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"really nested","enumField":"foo","interfaceField":null},"enumField":"foo","interfaceField":null},"enumField":"baz","interfaceField":null,"voidList":[null,null,null],"boolList":[false,true,false,true,true],"int8List":[12,-34,-128,127],"int16List":[1234,-5678,-32768,32767],"int32List":[12345678,-90123456,-2147483648,2147483647],"int64List":["123456789012345","-678901234567890","-9223372036854775808","9223372036854775807"],"uInt8List":[12,34,0,255],"uInt16List":[1234,5678,0,65535],"uInt32List":[12345678,90123456,0,4294967295],"uInt64List":["123456789012345","678901234567890","0","18446744073709551615"],"float32List":[0,1234567,9.9999999338158125e36,-9.9999999338158125e36,9.99999991097579e-38,-9.99999991097579e-38],"float64List":[0,123456789012345,1e306,-1e306,1e-306,-1e-306],"textList":["quux","corge","grault"],"dataList":[[103,97,114,112,108,121],[119,97,108,100,111],[102,114,101,100]],"structList":[{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"x structlist 1","enumField":"foo","interfaceField":null},{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"x structlist 2","enumField":"foo","interfaceField":null},{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"x structlist 3","enumField":"foo","interfaceField":null}],"enumList":["qux","bar","grault"]},"enumField":"corge","interfaceField":null,"voidList":[null,null,null,null,null,null],"boolList":[true,false,false,true],"int8List":[111,-111],"int16List":[11111,-11111],"int32List":[111111111,-111111111],"int64List":["1111111111111111111","-1111111111111111111"],"uInt8List":[111,222],"uInt16List":[33333,44444],"uInt32List":[3333333333],"uInt64List":["11111111111111111111"],"float32List":[5555.5,"Infinity","-Infinity","NaN"],"float64List":[7777.75,"Infinity","-Infinity","NaN"],"textList":["plugh","xyzzy","thud"],"dataList":[[111,111,112,115],[101,120,104,97,117,115,116,101,100],[114,102,99,51,48,57,50]],"structList":[{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"structlist 1","enumField":"foo","interfaceField":null},{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"structlist 2","enumField":"foo","interfaceField":null},{"voidField":null,"boolField":false,"int8Field":0,"int16Field":0,"int32Field":0,"int64Field":"0","uInt8Field":0,"uInt16Field":0,"uInt32Field":0,"uInt64Field":"0","float32Field":0,"float64Field":0,"textField":"structlist 3","enumField":"foo","interfaceField":null}],"enumList":["foo","garply"]} diff --git a/c++/src/ekam-rules b/c++/src/ekam-rules new file mode 120000 index 0000000000..ff5b3b4f98 --- /dev/null +++ b/c++/src/ekam-rules @@ -0,0 +1 @@ +../deps/ekam/src/ekam/rules \ No newline at end of file diff --git a/c++/src/kj/BUILD.bazel b/c++/src/kj/BUILD.bazel new file mode 100644 index 0000000000..f5527bea95 --- /dev/null +++ b/c++/src/kj/BUILD.bazel @@ -0,0 +1,261 @@ +load("//:build/configure.bzl", "kj_configure") + +kj_configure() + +cc_library( + name = "kj", + srcs = [ + "arena.c++", + "array.c++", + "cidr.c++", + "common.c++", + "debug.c++", + "encoding.c++", + "exception.c++", + "filesystem.c++", + "filesystem-disk-unix.c++", + "filesystem-disk-win32.c++", + "hash.c++", + "io.c++", + "list.c++", + "main.c++", + "memory.c++", + "mutex.c++", + "parse/char.c++", + "refcount.c++", + "source-location.c++", + "string.c++", + "string-tree.c++", + "table.c++", + "test-helpers.c++", + "thread.c++", + "time.c++", + "units.c++", + ], + hdrs = [ + "arena.h", + "array.h", + "cidr.h", + "common.h", + "debug.h", + "encoding.h", + "exception.h", + "filesystem.h", + "function.h", + "hash.h", + "io.h", + "list.h", + "main.h", + "map.h", + "memory.h", + "miniposix.h", + "mutex.h", + "one-of.h", + "parse/char.h", + "parse/common.h", + "refcount.h", + "source-location.h", + "std/iostream.h", + "string.h", + "string-tree.h", + "table.h", + "test.h", + "thread.h", + "threadlocal.h", + "time.h", + "tuple.h", + "units.h", + "vector.h", + "win32-api-version.h", + "windows-sanity.h", + ], + include_prefix = "kj", + linkopts = select({ + "@platforms//os:windows": [], + ":use_libdl": [ + "-lpthread", + "-ldl", + ], + "//conditions:default": ["-lpthread"], + }), + visibility = ["//visibility:public"], + deps = [":kj-defines"], +) + +cc_library( + name = "kj-async", + srcs = [ + "async.c++", + "async-io.c++", + "async-io-unix.c++", + "async-io-win32.c++", + "async-unix.c++", + "async-win32.c++", + "timer.c++", + ], + hdrs = [ + "async.h", + "async-inl.h", + "async-io.h", + "async-io-internal.h", + "async-prelude.h", + "async-queue.h", + "async-unix.h", + "async-win32.h", + "timer.h", + ], + include_prefix = "kj", + linkopts = select({ + "@platforms//os:windows": [ + "Ws2_32.lib", + "Advapi32.lib", + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [":kj"], +) + +cc_library( + name = "kj-test", + srcs = [ + "test.c++", + ], + include_prefix = "kj", + visibility = ["//visibility:public"], + deps = [ + ":kj", + "//src/kj/compat:gtest", + ], +) + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [ + ":kj", + ":kj-async", + ":kj-test", + ], +) for f in [ + "arena-test.c++", + "array-test.c++", + "async-io-test.c++", + "async-queue-test.c++", + "async-test.c++", + "async-xthread-test.c++", + "common-test.c++", + "debug-test.c++", + "encoding-test.c++", + "exception-test.c++", + "filesystem-disk-test.c++", + "filesystem-test.c++", + "function-test.c++", + "io-test.c++", + "list-test.c++", + "map-test.c++", + "memory-test.c++", + "mutex-test.c++", + "one-of-test.c++", + "parse/char-test.c++", + "refcount-test.c++", + "std/iostream-test.c++", + "string-test.c++", + "string-tree-test.c++", + "table-test.c++", + "test-test.c++", + "threadlocal-test.c++", + "thread-test.c++", + "time-test.c++", + "tuple-test.c++", + "units-test.c++", +]] + +cc_test( + name = "async-coroutine-test", + srcs = ["async-coroutine-test.c++"], + target_compatible_with = select({ + ":use_coroutines": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-test", + "//src/kj/compat:kj-http", + ], +) + +cc_library( + name = "filesystem-disk-test-base", + hdrs = [ + "filesystem-disk-test.c++", + "filesystem-disk-unix.c++", + ], +) + +cc_test( + name = "filesystem-disk-generic-test", + srcs = ["filesystem-disk-generic-test.c++"], + deps = [ + ":filesystem-disk-test-base", + ":kj-test", + ], + target_compatible_with = [ + "@platforms//os:linux", + ], +) + +cc_test( + name = "filesystem-disk-old-kernel-test", + srcs = ["filesystem-disk-old-kernel-test.c++"], + deps = [ + ":filesystem-disk-test-base", + ":kj-test", + ], + target_compatible_with = [ + "@platforms//os:linux", + ], +) + +cc_test( + name = "async-os-test", + srcs = select({ + "@platforms//os:windows": ["async-win32-test.c++"], + "//conditions:default": ["async-unix-test.c++"], + }), + deps = [ + ":kj", + ":kj-async", + ":kj-test", + ], +) + +cc_library( + name = "async-os-xthread-test-base", + hdrs = ["async-xthread-test.c++"], +) + +cc_test( + name = "async-os-xthread-test", + srcs = select({ + "@platforms//os:windows": ["async-win32-xthread-test.c++"], + "//conditions:default": ["async-unix-xthread-test.c++"], + }), + deps = [ + ":async-os-xthread-test-base", + ":kj-async", + ":kj-test", + ], +) + +cc_test( + name = "exception-override-symbolizer-test", + srcs = ["exception-override-symbolizer-test.c++"], + deps = [ + ":kj", + ":kj-test", + ], + linkstatic = True, + target_compatible_with = [ + "@platforms//os:linux", + ], +) diff --git a/c++/src/kj/CMakeLists.txt b/c++/src/kj/CMakeLists.txt index 2b7da310ef..980c53e34c 100644 --- a/c++/src/kj/CMakeLists.txt +++ b/c++/src/kj/CMakeLists.txt @@ -3,6 +3,8 @@ set(kj_sources_lite array.c++ + cidr.c++ + list.c++ common.c++ debug.c++ exception.c++ @@ -10,15 +12,23 @@ set(kj_sources_lite memory.c++ mutex.c++ string.c++ + source-location.c++ + hash.c++ + table.c++ thread.c++ main.c++ arena.c++ test-helpers.c++ + units.c++ + encoding.c++ ) set(kj_sources_heavy - units.c++ refcount.c++ string-tree.c++ + time.c++ + filesystem.c++ + filesystem-disk-unix.c++ + filesystem-disk-win32.c++ parse/char.c++ ) if(NOT CAPNP_LITE) @@ -28,18 +38,24 @@ else() endif() set(kj_headers + cidr.h common.h units.h memory.h refcount.h array.h + list.h vector.h string.h string-tree.h + source-location.h + hash.h + table.h + map.h + encoding.h exception.h debug.h arena.h - miniposix.h io.h tuple.h one-of.h @@ -47,7 +63,10 @@ set(kj_headers mutex.h thread.h threadlocal.h + filesystem.h + time.h main.h + win32-api-version.h windows-sanity.h ) set(kj-parse_headers @@ -59,8 +78,8 @@ set(kj-std_headers ) add_library(kj ${kj_sources}) add_library(CapnProto::kj ALIAS kj) -target_compile_features(kj PUBLIC cxx_constexpr) -# Requiring the cxx_std_11 metafeature would be preferable, but that doesn't exist until CMake 3.8. +# TODO(cleanup): Use cxx_std_14 once it's safe to require cmake 3.8. +target_compile_features(kj PUBLIC cxx_generic_lambdas) if(UNIX AND NOT ANDROID) target_link_libraries(kj PUBLIC pthread) @@ -68,10 +87,13 @@ endif() #make sure the lite flag propagates to all users (internal + external) of this library target_compile_definitions(kj PUBLIC ${CAPNP_LITE_FLAG}) #make sure external consumers don't need to manually set the include dirs -target_include_directories(kj INTERFACE - $ +get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +target_include_directories(kj PUBLIC + $ $ ) +# Ensure the library has a version set to match autotools build +set_target_properties(kj PROPERTIES VERSION ${VERSION}) install(TARGETS kj ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${kj_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj") install(FILES ${kj-parse_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/parse") @@ -89,6 +111,8 @@ set(kj-test-compat_headers add_library(kj-test ${kj-test_sources}) add_library(CapnProto::kj-test ALIAS kj-test) target_link_libraries(kj-test PUBLIC kj) +# Ensure the library has a version set to match autotools build +set_target_properties(kj-test PROPERTIES VERSION ${VERSION}) install(TARGETS kj-test ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${kj-test_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj") install(FILES ${kj-test-compat_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") @@ -100,7 +124,7 @@ set(kj-async_sources async-io-win32.c++ async-io.c++ async-io-unix.c++ - time.c++ + timer.c++ ) set(kj-async_headers async-prelude.h @@ -109,18 +133,31 @@ set(kj-async_headers async-unix.h async-win32.h async-io.h - time.h + async-queue.h + cidr.h + timer.h ) if(NOT CAPNP_LITE) add_library(kj-async ${kj-async_sources}) add_library(CapnProto::kj-async ALIAS kj-async) target_link_libraries(kj-async PUBLIC kj) + if(WITH_FIBERS) + target_compile_definitions(kj-async PUBLIC KJ_USE_FIBERS) + if(_WITH_LIBUCONTEXT) + target_link_libraries(kj-async PUBLIC PkgConfig::libucontext) + endif() + else() + target_compile_definitions(kj-async PUBLIC KJ_USE_FIBERS=0) + endif() + if(UNIX) # external clients of this library need to link to pthreads target_compile_options(kj-async INTERFACE "-pthread") elseif(WIN32) target_link_libraries(kj-async PUBLIC ws2_32) endif() + # Ensure the library has a version set to match autotools build + set_target_properties(kj-async PROPERTIES VERSION ${VERSION}) install(TARGETS kj-async ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${kj-async_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj") endif() @@ -128,19 +165,76 @@ endif() # kj-http ====================================================================== set(kj-http_sources + compat/url.c++ compat/http.c++ ) set(kj-http_headers + compat/url.h compat/http.h ) if(NOT CAPNP_LITE) add_library(kj-http ${kj-http_sources}) add_library(CapnProto::kj-http ALIAS kj-http) - target_link_libraries(kj-http PUBLIC kj-async kj) + if(WITH_ZLIB) + target_compile_definitions(kj-http PRIVATE KJ_HAS_ZLIB) + target_link_libraries(kj-http PUBLIC kj-async kj ZLIB::ZLIB) + else() + target_link_libraries(kj-http PUBLIC kj-async kj) + endif() + # Ensure the library has a version set to match autotools build + set_target_properties(kj-http PROPERTIES VERSION ${VERSION}) install(TARGETS kj-http ${INSTALL_TARGETS_DEFAULT_ARGS}) install(FILES ${kj-http_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") endif() +# kj-tls ====================================================================== +if(WITH_OPENSSL) + set(kj-tls_sources + compat/readiness-io.c++ + compat/tls.c++ + ) + set(kj-tls_headers + compat/readiness-io.h + compat/tls.h + ) + if(NOT CAPNP_LITE) + add_library(kj-tls ${kj-tls_sources}) + add_library(CapnProto::kj-tls ALIAS kj-tls) + target_link_libraries(kj-tls PUBLIC kj-async) + + target_compile_definitions(kj-tls PRIVATE KJ_HAS_OPENSSL) + target_link_libraries(kj-tls PRIVATE OpenSSL::SSL OpenSSL::Crypto) + + # Ensure the library has a version set to match autotools build + set_target_properties(kj-tls PROPERTIES VERSION ${VERSION}) + install(TARGETS kj-tls ${INSTALL_TARGETS_DEFAULT_ARGS}) + install(FILES ${kj-tls_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") + endif() +endif() + +# kj-gzip ====================================================================== + +if(WITH_ZLIB) + set(kj-gzip_sources + compat/gzip.c++ + ) + set(kj-gzip_headers + compat/gzip.h + ) + if(NOT CAPNP_LITE) + add_library(kj-gzip ${kj-gzip_sources}) + add_library(CapnProto::kj-gzip ALIAS kj-gzip) + + target_compile_definitions(kj-gzip PRIVATE KJ_HAS_ZLIB) + target_link_libraries(kj-gzip PUBLIC kj-async kj ZLIB::ZLIB) + + # Ensure the library has a version set to match autotools build + set_target_properties(kj-gzip PROPERTIES VERSION ${VERSION}) + install(TARGETS kj-gzip ${INSTALL_TARGETS_DEFAULT_ARGS}) + install(FILES ${kj-gzip_headers} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kj/compat") + endif() +endif() + # Tests ======================================================================== if(BUILD_TESTING) @@ -148,11 +242,17 @@ if(BUILD_TESTING) common-test.c++ memory-test.c++ array-test.c++ + list-test.c++ string-test.c++ + table-test.c++ + map-test.c++ exception-test.c++ + # this test overrides symbolizer and has to be linked separately + # exception-override-symbolizer-test.c++ debug-test.c++ io-test.c++ mutex-test.c++ + time-test.c++ threadlocal-test.c++ test-test.c++ std/iostream-test.c++ @@ -165,22 +265,49 @@ if(BUILD_TESTING) if(NOT CAPNP_LITE) add_executable(kj-heavy-tests async-test.c++ + async-xthread-test.c++ + async-coroutine-test.c++ async-unix-test.c++ + async-unix-xthread-test.c++ async-win32-test.c++ + async-win32-xthread-test.c++ async-io-test.c++ + async-queue-test.c++ refcount-test.c++ string-tree-test.c++ + encoding-test.c++ arena-test.c++ units-test.c++ tuple-test.c++ one-of-test.c++ function-test.c++ - threadlocal-pthread-test.c++ + filesystem-test.c++ + filesystem-disk-test.c++ parse/common-test.c++ parse/char-test.c++ + compat/url-test.c++ compat/http-test.c++ + compat/gzip-test.c++ + compat/tls-test.c++ ) target_link_libraries(kj-heavy-tests kj-http kj-async kj-test kj) + if(WITH_OPENSSL) + target_link_libraries(kj-heavy-tests kj-tls) + # tls-test.c++ needs to use OpenSSL directly. + target_link_libraries(kj-heavy-tests OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(kj-heavy-tests PRIVATE KJ_HAS_OPENSSL) + set_property( + SOURCE compat/tls-test.c++ + APPEND PROPERTY COMPILE_DEFINITIONS KJ_HAS_OPENSSL + ) + endif() + if(WITH_ZLIB) + target_link_libraries(kj-heavy-tests kj-gzip) + set_property( + SOURCE compat/gzip-test.c++ + APPEND PROPERTY COMPILE_DEFINITIONS KJ_HAS_ZLIB + ) + endif() add_dependencies(check kj-heavy-tests) add_test(NAME kj-heavy-tests-run COMMAND kj-heavy-tests) endif() # NOT CAPNP_LITE diff --git a/c++/src/kj/arena.h b/c++/src/kj/arena.h index 32c1f61c51..a16b291121 100644 --- a/c++/src/kj/arena.h +++ b/c++/src/kj/arena.h @@ -19,17 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ARENA_H_ -#define KJ_ARENA_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "memory.h" #include "array.h" #include "string.h" +KJ_BEGIN_HEADER + namespace kj { class Arena { @@ -50,7 +47,7 @@ class Arena { explicit Arena(ArrayPtr scratch); // Allocates from the given scratch space first, only resorting to the heap when it runs out. - KJ_DISALLOW_COPY(Arena); + KJ_DISALLOW_COPY_AND_MOVE(Arena); ~Arena() noexcept(false); template @@ -137,11 +134,11 @@ class Arena { template T& Arena::allocate(Params&&... params) { T& result = *reinterpret_cast(allocateBytes( - sizeof(T), alignof(T), !__has_trivial_destructor(T))); - if (!__has_trivial_constructor(T) || sizeof...(Params) > 0) { + sizeof(T), alignof(T), !KJ_HAS_TRIVIAL_DESTRUCTOR(T))); + if (!KJ_HAS_TRIVIAL_CONSTRUCTOR(T) || sizeof...(Params) > 0) { ctor(result, kj::fwd(params)...); } - if (!__has_trivial_destructor(T)) { + if (!KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { setDestructor(&result, &destroyObject); } return result; @@ -149,11 +146,11 @@ T& Arena::allocate(Params&&... params) { template ArrayPtr Arena::allocateArray(size_t size) { - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { ArrayPtr result = arrayPtr(reinterpret_cast(allocateBytes( sizeof(T) * size, alignof(T), false)), size); - if (!__has_trivial_constructor(T)) { + if (!KJ_HAS_TRIVIAL_CONSTRUCTOR(T)) { for (size_t i = 0; i < size; i++) { ctor(result[i]); } @@ -168,7 +165,7 @@ ArrayPtr Arena::allocateArray(size_t size) { arrayPtr(reinterpret_cast(reinterpret_cast(base) + prefixSize), size); setDestructor(base, &destroyArray); - if (__has_trivial_constructor(T)) { + if (KJ_HAS_TRIVIAL_CONSTRUCTOR(T)) { tag = size; } else { // In case of constructor exceptions, we need the tag to end up storing the number of objects @@ -186,7 +183,7 @@ ArrayPtr Arena::allocateArray(size_t size) { template Own Arena::allocateOwn(Params&&... params) { T& result = *reinterpret_cast(allocateBytes(sizeof(T), alignof(T), false)); - if (!__has_trivial_constructor(T) || sizeof...(Params) > 0) { + if (!KJ_HAS_TRIVIAL_CONSTRUCTOR(T) || sizeof...(Params) > 0) { ctor(result, kj::fwd(params)...); } return Own(&result, DestructorOnlyDisposer::instance); @@ -210,4 +207,4 @@ ArrayBuilder Arena::allocateOwnArrayBuilder(size_t capacity) { } // namespace kj -#endif // KJ_ARENA_H_ +KJ_END_HEADER diff --git a/c++/src/kj/array-test.c++ b/c++/src/kj/array-test.c++ index 17ae15b8c0..07aadbd1b1 100644 --- a/c++/src/kj/array-test.c++ +++ b/c++/src/kj/array-test.c++ @@ -378,5 +378,142 @@ TEST(Array, ReleaseAsBytesOrChars) { } } +#if __cplusplus > 201402L +KJ_TEST("kj::arr()") { + kj::Array array = kj::arr(kj::str("foo"), kj::str(123)); + KJ_EXPECT(array == kj::ArrayPtr({"foo", "123"})); +} + +struct ImmovableInt { + ImmovableInt(int i): i(i) {} + KJ_DISALLOW_COPY_AND_MOVE(ImmovableInt); + int i; +}; + +KJ_TEST("kj::arrOf()") { + kj::Array array = kj::arrOf(123, 456, 789); + KJ_ASSERT(array.size() == 3); + KJ_EXPECT(array[0].i == 123); + KJ_EXPECT(array[1].i == 456); + KJ_EXPECT(array[2].i == 789); +} +#endif + +struct DestructionOrderRecorder { + DestructionOrderRecorder(uint& counter, uint& recordTo) + : counter(counter), recordTo(recordTo) {} + ~DestructionOrderRecorder() { + recordTo = ++counter; + } + + uint& counter; + uint& recordTo; +}; + +TEST(Array, Attach) { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + auto builder = kj::heapArrayBuilder>(1); + builder.add(kj::mv(obj1)); + auto arr = builder.finish(); + auto ptr = arr.begin(); + + Array> combined = arr.attach(kj::mv(obj2), kj::mv(obj3)); + + KJ_EXPECT(combined.begin() == ptr); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + +TEST(Array, AttachNested) { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + auto builder = kj::heapArrayBuilder>(1); + builder.add(kj::mv(obj1)); + auto arr = builder.finish(); + auto ptr = arr.begin(); + + Array> combined = arr.attach(kj::mv(obj2)).attach(kj::mv(obj3)); + + KJ_EXPECT(combined.begin() == ptr); + KJ_EXPECT(combined.size() == 1); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + +TEST(Array, AttachFromArrayPtr) { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + auto builder = kj::heapArrayBuilder>(1); + builder.add(kj::mv(obj1)); + auto arr = builder.finish(); + auto ptr = arr.begin(); + + Array> combined = + arr.asPtr().attach(kj::mv(obj2)).attach(kj::mv(obj3)); + KJ_EXPECT(arr != nullptr); + + KJ_EXPECT(combined.begin() == ptr); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed2 == 1, destroyed2); + KJ_EXPECT(destroyed3 == 2, destroyed3); + + arr = nullptr; + + KJ_EXPECT(destroyed1 == 3, destroyed1); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/array.h b/c++/src/kj/array.h index 51b5dcf319..55f7ace8e0 100644 --- a/c++/src/kj/array.h +++ b/c++/src/kj/array.h @@ -19,17 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ARRAY_H_ -#define KJ_ARRAY_H_ +#pragma once -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif - -#include "common.h" +#include "memory.h" #include #include +KJ_BEGIN_HEADER + namespace kj { // ======================================================================================= @@ -59,7 +56,7 @@ class ArrayDisposer { // an exception. private: - template + template struct Dispose_; }; @@ -77,7 +74,7 @@ class ExceptionSafeArrayUtil { : pos(reinterpret_cast(ptr) + elementSize * constructedElementCount), elementSize(elementSize), constructedElementCount(constructedElementCount), destroyElement(destroyElement) {} - KJ_DISALLOW_COPY(ExceptionSafeArrayUtil); + KJ_DISALLOW_COPY_AND_MOVE(ExceptionSafeArrayUtil); inline ~ExceptionSafeArrayUtil() noexcept(false) { if (constructedElementCount > 0) destroyAll(); @@ -143,53 +140,62 @@ class Array { other.ptr = nullptr; other.size_ = 0; } - inline Array(T* firstElement, size_t size, const ArrayDisposer& disposer) + inline Array(T* firstElement KJ_LIFETIMEBOUND, size_t size, const ArrayDisposer& disposer) : ptr(firstElement), size_(size), disposer(&disposer) {} KJ_DISALLOW_COPY(Array); inline ~Array() noexcept { dispose(); } - inline operator ArrayPtr() { + inline operator ArrayPtr() KJ_LIFETIMEBOUND { return ArrayPtr(ptr, size_); } - inline operator ArrayPtr() const { + inline operator ArrayPtr() const KJ_LIFETIMEBOUND { return ArrayPtr(ptr, size_); } - inline ArrayPtr asPtr() { + inline ArrayPtr asPtr() KJ_LIFETIMEBOUND { return ArrayPtr(ptr, size_); } - inline ArrayPtr asPtr() const { + inline ArrayPtr asPtr() const KJ_LIFETIMEBOUND { return ArrayPtr(ptr, size_); } inline size_t size() const { return size_; } - inline T& operator[](size_t index) const { + inline T& operator[](size_t index) KJ_LIFETIMEBOUND { + KJ_IREQUIRE(index < size_, "Out-of-bounds Array access."); + return ptr[index]; + } + inline const T& operator[](size_t index) const KJ_LIFETIMEBOUND { KJ_IREQUIRE(index < size_, "Out-of-bounds Array access."); return ptr[index]; } - inline const T* begin() const { return ptr; } - inline const T* end() const { return ptr + size_; } - inline const T& front() const { return *ptr; } - inline const T& back() const { return *(ptr + size_ - 1); } - inline T* begin() { return ptr; } - inline T* end() { return ptr + size_; } - inline T& front() { return *ptr; } - inline T& back() { return *(ptr + size_ - 1); } + inline const T* begin() const KJ_LIFETIMEBOUND { return ptr; } + inline const T* end() const KJ_LIFETIMEBOUND { return ptr + size_; } + inline const T& front() const KJ_LIFETIMEBOUND { return *ptr; } + inline const T& back() const KJ_LIFETIMEBOUND { return *(ptr + size_ - 1); } + inline T* begin() KJ_LIFETIMEBOUND { return ptr; } + inline T* end() KJ_LIFETIMEBOUND { return ptr + size_; } + inline T& front() KJ_LIFETIMEBOUND { return *ptr; } + inline T& back() KJ_LIFETIMEBOUND { return *(ptr + size_ - 1); } - inline ArrayPtr slice(size_t start, size_t end) { + template + inline bool operator==(const U& other) const { return asPtr() == other; } + template + inline bool operator!=(const U& other) const { return asPtr() != other; } + + inline ArrayPtr slice(size_t start, size_t end) KJ_LIFETIMEBOUND { KJ_IREQUIRE(start <= end && end <= size_, "Out-of-bounds Array::slice()."); return ArrayPtr(ptr + start, end - start); } - inline ArrayPtr slice(size_t start, size_t end) const { + inline ArrayPtr slice(size_t start, size_t end) const KJ_LIFETIMEBOUND { KJ_IREQUIRE(start <= end && end <= size_, "Out-of-bounds Array::slice()."); return ArrayPtr(ptr + start, end - start); } - inline ArrayPtr asBytes() const { return asPtr().asBytes(); } - inline ArrayPtr> asBytes() { return asPtr().asBytes(); } - inline ArrayPtr asChars() const { return asPtr().asChars(); } - inline ArrayPtr> asChars() { return asPtr().asChars(); } + inline ArrayPtr asBytes() const KJ_LIFETIMEBOUND { return asPtr().asBytes(); } + inline ArrayPtr> asBytes() KJ_LIFETIMEBOUND { return asPtr().asBytes(); } + inline ArrayPtr asChars() const KJ_LIFETIMEBOUND { return asPtr().asChars(); } + inline ArrayPtr> asChars() KJ_LIFETIMEBOUND { return asPtr().asChars(); } inline Array> releaseAsBytes() { // Like asBytes() but transfers ownership. @@ -230,6 +236,10 @@ class Array { return *this; } + template + Array attach(Attachments&&... attachments) KJ_WARN_UNUSED_RESULT; + // Like Own::attach(), but attaches to an Array. + private: T* ptr; size_t size_; @@ -249,6 +259,8 @@ class Array { template friend class Array; + template + friend class ArrayBuilder; }; static_assert(!canMemcpy>(), "canMemcpy<>() is broken"); @@ -273,8 +285,8 @@ class HeapArrayDisposer final: public ArrayDisposer { virtual void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, size_t capacity, void (*destroyElement)(void*)) const override; - template + template struct Allocate_; }; @@ -321,37 +333,48 @@ class ArrayBuilder { other.pos = nullptr; other.endPtr = nullptr; } + ArrayBuilder(Array&& other) + : ptr(other.ptr), pos(other.ptr + other.size_), endPtr(pos), disposer(other.disposer) { + // Create an already-full ArrayBuilder from an Array of the same type. This constructor + // primarily exists to enable Vector to be constructed from Array. + other.ptr = nullptr; + other.size_ = 0; + } KJ_DISALLOW_COPY(ArrayBuilder); inline ~ArrayBuilder() noexcept(false) { dispose(); } - inline operator ArrayPtr() { + inline operator ArrayPtr() KJ_LIFETIMEBOUND { return arrayPtr(ptr, pos); } - inline operator ArrayPtr() const { + inline operator ArrayPtr() const KJ_LIFETIMEBOUND { return arrayPtr(ptr, pos); } - inline ArrayPtr asPtr() { + inline ArrayPtr asPtr() KJ_LIFETIMEBOUND { return arrayPtr(ptr, pos); } - inline ArrayPtr asPtr() const { + inline ArrayPtr asPtr() const KJ_LIFETIMEBOUND { return arrayPtr(ptr, pos); } inline size_t size() const { return pos - ptr; } inline size_t capacity() const { return endPtr - ptr; } - inline T& operator[](size_t index) const { + inline T& operator[](size_t index) KJ_LIFETIMEBOUND { + KJ_IREQUIRE(index < implicitCast(pos - ptr), "Out-of-bounds Array access."); + return ptr[index]; + } + inline const T& operator[](size_t index) const KJ_LIFETIMEBOUND { KJ_IREQUIRE(index < implicitCast(pos - ptr), "Out-of-bounds Array access."); return ptr[index]; } - inline const T* begin() const { return ptr; } - inline const T* end() const { return pos; } - inline const T& front() const { return *ptr; } - inline const T& back() const { return *(pos - 1); } - inline T* begin() { return ptr; } - inline T* end() { return pos; } - inline T& front() { return *ptr; } - inline T& back() { return *(pos - 1); } + inline const T* begin() const KJ_LIFETIMEBOUND { return ptr; } + inline const T* end() const KJ_LIFETIMEBOUND { return pos; } + inline const T& front() const KJ_LIFETIMEBOUND { return *ptr; } + inline const T& back() const KJ_LIFETIMEBOUND { return *(pos - 1); } + inline T* begin() KJ_LIFETIMEBOUND { return ptr; } + inline T* end() KJ_LIFETIMEBOUND { return pos; } + inline T& front() KJ_LIFETIMEBOUND { return *ptr; } + inline T& back() KJ_LIFETIMEBOUND { return *(pos - 1); } ArrayBuilder& operator=(ArrayBuilder&& other) { dispose(); @@ -370,7 +393,7 @@ class ArrayBuilder { } template - T& add(Params&&... params) { + T& add(Params&&... params) KJ_LIFETIMEBOUND { KJ_IREQUIRE(pos < endPtr, "Added too many elements to ArrayBuilder."); ctor(*pos, kj::fwd(params)...); return *pos++; @@ -394,7 +417,7 @@ class ArrayBuilder { KJ_IREQUIRE(size <= this->size(), "can't use truncate() to expand"); T* target = ptr + size; - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { pos = target; } else { while (pos > target) { @@ -403,13 +426,23 @@ class ArrayBuilder { } } + void clear() { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { + pos = ptr; + } else { + while (pos > ptr) { + kj::dtor(*--pos); + } + } + } + void resize(size_t size) { KJ_IREQUIRE(size <= capacity(), "can't resize past capacity"); T* target = ptr + size; if (target > pos) { // expand - if (__has_trivial_constructor(T)) { + if (KJ_HAS_TRIVIAL_CONSTRUCTOR(T)) { pos = target; } else { while (pos < target) { @@ -418,7 +451,7 @@ class ArrayBuilder { } } else { // truncate - if (__has_trivial_destructor(T)) { + if (KJ_HAS_TRIVIAL_DESTRUCTOR(T)) { pos = target; } else { while (pos > target) { @@ -430,7 +463,7 @@ class ArrayBuilder { Array finish() { // We could safely remove this check if we assume that the disposer implementation doesn't - // need to know the original capacity, as is thes case with HeapArrayDisposer since it uses + // need to know the original capacity, as is the case with HeapArrayDisposer since it uses // operator new() or if we created a custom disposer for ArrayBuilder which stores the capacity // in a prefix. But that would make it hard to write cleverer heap allocators, and anyway this // check might catch bugs. Probably people should use Vector if they want to build arrays @@ -451,7 +484,7 @@ class ArrayBuilder { T* ptr; RemoveConst* pos; T* endPtr; - const ArrayDisposer* disposer; + const ArrayDisposer* disposer = &NullArrayDisposer::instance; inline void dispose() { // Make sure that if an exception is thrown, we are left with a null ptr, so we won't possibly @@ -485,21 +518,23 @@ class FixedArray { // A fixed-width array whose storage is allocated inline rather than on the heap. public: - inline size_t size() const { return fixedSize; } - inline T* begin() { return content; } - inline T* end() { return content + fixedSize; } - inline const T* begin() const { return content; } - inline const T* end() const { return content + fixedSize; } + inline constexpr size_t size() const { return fixedSize; } + inline constexpr T* begin() KJ_LIFETIMEBOUND { return content; } + inline constexpr T* end() KJ_LIFETIMEBOUND { return content + fixedSize; } + inline constexpr const T* begin() const KJ_LIFETIMEBOUND { return content; } + inline constexpr const T* end() const KJ_LIFETIMEBOUND { return content + fixedSize; } - inline operator ArrayPtr() { + inline constexpr operator ArrayPtr() KJ_LIFETIMEBOUND { return arrayPtr(content, fixedSize); } - inline operator ArrayPtr() const { + inline constexpr operator ArrayPtr() const KJ_LIFETIMEBOUND { return arrayPtr(content, fixedSize); } - inline T& operator[](size_t index) { return content[index]; } - inline const T& operator[](size_t index) const { return content[index]; } + inline constexpr T& operator[](size_t index) KJ_LIFETIMEBOUND { return content[index]; } + inline constexpr const T& operator[](size_t index) const KJ_LIFETIMEBOUND { + return content[index]; + } private: T content[fixedSize]; @@ -518,20 +553,20 @@ class CappedArray { inline size_t size() const { return currentSize; } inline void setSize(size_t s) { KJ_IREQUIRE(s <= fixedSize); currentSize = s; } - inline T* begin() { return content; } - inline T* end() { return content + currentSize; } - inline const T* begin() const { return content; } - inline const T* end() const { return content + currentSize; } + inline T* begin() KJ_LIFETIMEBOUND { return content; } + inline T* end() KJ_LIFETIMEBOUND { return content + currentSize; } + inline const T* begin() const KJ_LIFETIMEBOUND { return content; } + inline const T* end() const KJ_LIFETIMEBOUND { return content + currentSize; } - inline operator ArrayPtr() { + inline operator ArrayPtr() KJ_LIFETIMEBOUND { return arrayPtr(content, currentSize); } - inline operator ArrayPtr() const { + inline operator ArrayPtr() const KJ_LIFETIMEBOUND { return arrayPtr(content, currentSize); } - inline T& operator[](size_t index) { return content[index]; } - inline const T& operator[](size_t index) const { return content[index]; } + inline T& operator[](size_t index) KJ_LIFETIMEBOUND { return content[index]; } + inline const T& operator[](size_t index) const KJ_LIFETIMEBOUND { return content[index]; } private: size_t currentSize; @@ -604,7 +639,8 @@ struct ArrayDisposer::Dispose_ { static void dispose(T* firstElement, size_t elementCount, size_t capacity, const ArrayDisposer& disposer) { - disposer.disposeImpl(firstElement, sizeof(T), elementCount, capacity, &destruct); + disposer.disposeImpl(const_cast*>(firstElement), + sizeof(T), elementCount, capacity, &destruct); } }; @@ -662,7 +698,9 @@ struct CopyConstructArray_; template struct CopyConstructArray_ { static inline T* apply(T* __restrict__ pos, T* start, T* end) { - memcpy(pos, start, reinterpret_cast(end) - reinterpret_cast(start)); + if (end != start) { + memcpy(pos, start, reinterpret_cast(end) - reinterpret_cast(start)); + } return pos + (end - start); } }; @@ -670,7 +708,9 @@ struct CopyConstructArray_ { template struct CopyConstructArray_ { static inline T* apply(T* __restrict__ pos, const T* start, const T* end) { - memcpy(pos, start, reinterpret_cast(end) - reinterpret_cast(start)); + if (end != start) { + memcpy(pos, start, reinterpret_cast(end) - reinterpret_cast(start)); + } return pos + (end - start); } }; @@ -808,6 +848,66 @@ inline Array heapArray(std::initializer_list init) { return heapArray(init.begin(), init.end()); } +#if __cplusplus > 201402L +template +inline Array> arr(T&& param1, Params&&... params) { + ArrayBuilder> builder = heapArrayBuilder>(sizeof...(params) + 1); + (builder.add(kj::fwd(param1)), ... , builder.add(kj::fwd(params))); + return builder.finish(); +} +template +inline Array> arrOf(Params&&... params) { + ArrayBuilder> builder = heapArrayBuilder>(sizeof...(params)); + (... , builder.add(kj::fwd(params))); + return builder.finish(); +} +#endif + +namespace _ { // private + +template +struct ArrayDisposableOwnedBundle final: public ArrayDisposer, public OwnedBundle { + ArrayDisposableOwnedBundle(T&&... values): OwnedBundle(kj::fwd(values)...) {} + void disposeImpl(void*, size_t, size_t, size_t, void (*)(void*)) const override { delete this; } +}; + +} // namespace _ (private) + +template +template +Array Array::attach(Attachments&&... attachments) { + T* ptrCopy = ptr; + auto sizeCopy = size_; + + KJ_IREQUIRE(ptrCopy != nullptr, "cannot attach to null pointer"); + + // HACK: If someone accidentally calls .attach() on a null pointer in opt mode, try our best to + // accomplish reasonable behavior: We turn the pointer non-null but still invalid, so that the + // disposer will still be called when the pointer goes out of scope. + if (ptrCopy == nullptr) ptrCopy = reinterpret_cast(1); + + auto bundle = new _::ArrayDisposableOwnedBundle, Attachments...>( + kj::mv(*this), kj::fwd(attachments)...); + return Array(ptrCopy, sizeCopy, *bundle); +} + +template +template +Array ArrayPtr::attach(Attachments&&... attachments) const { + T* ptrCopy = ptr; + + KJ_IREQUIRE(ptrCopy != nullptr, "cannot attach to null pointer"); + + // HACK: If someone accidentally calls .attach() on a null pointer in opt mode, try our best to + // accomplish reasonable behavior: We turn the pointer non-null but still invalid, so that the + // disposer will still be called when the pointer goes out of scope. + if (ptrCopy == nullptr) ptrCopy = reinterpret_cast(1); + + auto bundle = new _::ArrayDisposableOwnedBundle( + kj::fwd(attachments)...); + return Array(ptrCopy, size_, *bundle); +} + } // namespace kj -#endif // KJ_ARRAY_H_ +KJ_END_HEADER diff --git a/c++/src/kj/async-coroutine-test.c++ b/c++/src/kj/async-coroutine-test.c++ new file mode 100644 index 0000000000..6d76b608bf --- /dev/null +++ b/c++/src/kj/async-coroutine-test.c++ @@ -0,0 +1,578 @@ +// Copyright (c) 2020 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include +#include +#include +#include +#include + +namespace kj { +namespace { + +#ifdef KJ_HAS_COROUTINE + +template +Promise> identity(T&& value) { + co_return kj::fwd(value); +} +// Work around a bonkers MSVC ICE with a separate overload. +Promise identity(const char* value) { + co_return value; +} + +KJ_TEST("Identity coroutine") { + EventLoop loop; + WaitScope waitScope(loop); + + KJ_EXPECT(identity(123).wait(waitScope) == 123); + KJ_EXPECT(*identity(kj::heap(456)).wait(waitScope) == 456); + + { + auto p = identity("we can cancel the coroutine"); + } +} + +template +Promise simpleCoroutine(kj::Promise result, kj::Promise dontThrow = true) { + KJ_ASSERT(co_await dontThrow); + co_return co_await result; +} + +KJ_TEST("Simple coroutine test") { + EventLoop loop; + WaitScope waitScope(loop); + + simpleCoroutine(kj::Promise(kj::READY_NOW)).wait(waitScope); + + KJ_EXPECT(simpleCoroutine(kj::Promise(123)).wait(waitScope) == 123); +} + +struct Counter { + size_t& wind; + size_t& unwind; + Counter(size_t& wind, size_t& unwind): wind(wind), unwind(unwind) { ++wind; } + ~Counter() { ++unwind; } + KJ_DISALLOW_COPY_AND_MOVE(Counter); +}; + +kj::Promise countAroundAwait(size_t& wind, size_t& unwind, kj::Promise promise) { + Counter counter1(wind, unwind); + co_await promise; + Counter counter2(wind, unwind); + co_return; +}; + +KJ_TEST("co_awaiting initial immediate promises suspends even if event loop is empty and running") { + // The coroutine PromiseNode implementation contains an optimization which allows us to avoid + // suspending the coroutine and instead immediately call PromiseNode::get() and proceed with + // execution, but only if the coroutine has suspended at least once. This test verifies that the + // optimization is disabled for this initial suspension. + + EventLoop loop; + WaitScope waitScope(loop); + + // The immediate-execution optimization is only enabled when the event loop is running, so use an + // eagerly-evaluated evalLater() to perform the test from within the event loop. (If we didn't + // eagerly-evaluate the promise, the result would be extracted after the loop finished, which + // would disable the optimization anyway.) + kj::evalLater([&]() { + size_t wind = 0, unwind = 0; + + auto promise = kj::Promise(kj::READY_NOW); + auto coroPromise = countAroundAwait(wind, unwind, kj::READY_NOW); + + // `coro` has not completed. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); + + kj::evalLater([&]() { + // If there are no background tasks in the queue, coroutines execute through an evalLater() + // without suspending. + + size_t wind = 0, unwind = 0; + bool evalLaterRan = false; + + auto promise = kj::evalLater([&]() { evalLaterRan = true; }); + auto coroPromise = countAroundAwait(wind, unwind, kj::mv(promise)); + + KJ_EXPECT(evalLaterRan == false); + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); +} + +KJ_TEST("co_awaiting an immediate promise suspends if the event loop is not running") { + // We only want to enable the immediate-execution optimization if the event loop is running, or + // else a whole bunch of RPC tests break, because some .then()s get evaluated on promise + // construction, before any .wait() call. + + EventLoop loop; + WaitScope waitScope(loop); + + size_t wind = 0, unwind = 0; + + auto promise = kj::Promise(kj::READY_NOW); + auto coroPromise = countAroundAwait(wind, unwind, kj::READY_NOW); + + // In the previous test, this exact same code executed immediately because the event loop was + // running. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); +} + +KJ_TEST("co_awaiting immediate promises suspends if the event loop is not empty") { + // We want to make sure that we can still return to the event loop when we need to. + + EventLoop loop; + WaitScope waitScope(loop); + + // The immediate-execution optimization is only enabled when the event loop is running, so use an + // eagerly-evaluated evalLater() to perform the test from within the event loop. (If we didn't + // eagerly-evaluate the promise, the result would be extracted after the loop finished.) + kj::evalLater([&]() { + size_t wind = 0, unwind = 0; + + // We need to enqueue an Event on the event loop to inhibit the immediate-execution + // optimization. Creating and then immediately fulfilling an EagerPromiseNode is a convenient + // way to do so. + auto paf = newPromiseAndFulfiller(); + paf.promise = paf.promise.eagerlyEvaluate(nullptr); + paf.fulfiller->fulfill(); + + auto promise = kj::Promise(kj::READY_NOW); + auto coroPromise = countAroundAwait(wind, unwind, kj::READY_NOW); + + // We didn't immediately extract the READY_NOW. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); + + kj::evalLater([&]() { + size_t wind = 0, unwind = 0; + bool evalLaterRan = false; + + // We need to enqueue an Event on the event loop to inhibit the immediate-execution + // optimization. Creating and then immediately fulfilling an EagerPromiseNode is a convenient + // way to do so. + auto paf = newPromiseAndFulfiller(); + paf.promise = paf.promise.eagerlyEvaluate(nullptr); + paf.fulfiller->fulfill(); + + auto promise = kj::evalLater([&]() { evalLaterRan = true; }); + auto coroPromise = countAroundAwait(wind, unwind, kj::mv(promise)); + + // We didn't continue through the evalLater() promise, because the background promise's + // continuation was next in the event loop's queue. + KJ_EXPECT(evalLaterRan == false); + // No Counter destructor has run. + KJ_EXPECT(wind == 1); + KJ_EXPECT(unwind == 0); + }).eagerlyEvaluate(nullptr).wait(waitScope); +} + +KJ_TEST("Exceptions propagate through layered coroutines") { + EventLoop loop; + WaitScope waitScope(loop); + + auto throwy = simpleCoroutine(kj::Promise(kj::NEVER_DONE), false); + + KJ_EXPECT_THROW_RECOVERABLE(FAILED, simpleCoroutine(kj::mv(throwy)).wait(waitScope)); +} + +KJ_TEST("Exceptions before the first co_await don't escape, but reject the promise") { + EventLoop loop; + WaitScope waitScope(loop); + + auto throwEarly = []() -> Promise { + KJ_FAIL_ASSERT("test exception"); +#ifdef __GNUC__ +// Yes, this `co_return` is unreachable. But without it, this function is no longer a coroutine. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunreachable-code" +#endif // __GNUC__ + co_return; +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif // __GNUC__ + }; + + auto throwy = throwEarly(); + + KJ_EXPECT_THROW_RECOVERABLE(FAILED, throwy.wait(waitScope)); +} + +KJ_TEST("Coroutines can catch exceptions from co_await") { + EventLoop loop; + WaitScope waitScope(loop); + + kj::String description; + + auto tryCatch = [&](kj::Promise promise) -> kj::Promise { + try { + co_await promise; + } catch (const kj::Exception& exception) { + co_return kj::str(exception.getDescription()); + } + KJ_FAIL_EXPECT("should have thrown"); + KJ_UNREACHABLE; + }; + + { + // Immediately ready case. + auto promise = kj::Promise(KJ_EXCEPTION(FAILED, "catch me")); + KJ_EXPECT(tryCatch(kj::mv(promise)).wait(waitScope) == "catch me"); + } + + { + // Ready later case. + auto promise = kj::evalLater([]() -> kj::Promise { + return KJ_EXCEPTION(FAILED, "catch me"); + }); + KJ_EXPECT(tryCatch(kj::mv(promise)).wait(waitScope) == "catch me"); + } +} + +KJ_TEST("Coroutines can be canceled while suspended") { + EventLoop loop; + WaitScope waitScope(loop); + + size_t wind = 0, unwind = 0; + + auto coro = [&](kj::Promise promise) -> kj::Promise { + Counter counter1(wind, unwind); + co_await kj::evalLater([](){}); + Counter counter2(wind, unwind); + co_await promise; + }; + + { + auto neverDone = kj::Promise(kj::NEVER_DONE); + neverDone = neverDone.attach(kj::heap(wind, unwind)); + auto promise = coro(kj::mv(neverDone)); + KJ_EXPECT(!promise.poll(waitScope)); + } + + // Stack variables on both sides of a co_await, plus coroutine arguments are destroyed. + KJ_EXPECT(wind == 3); + KJ_EXPECT(unwind == 3); +} + +kj::Promise deferredThrowCoroutine(kj::Promise awaitMe) { + KJ_DEFER(kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during unwind"))); + co_await awaitMe; + co_return; +}; + +KJ_TEST("Exceptions during suspended coroutine frame-unwind propagate via destructor") { + EventLoop loop; + WaitScope waitScope(loop); + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + deferredThrowCoroutine(kj::NEVER_DONE); + })); + + KJ_EXPECT(exception.getDescription() == "thrown during unwind"); +}; + +KJ_TEST("Exceptions during suspended coroutine frame-unwind do not cause a memory leak") { + EventLoop loop; + WaitScope waitScope(loop); + + // We can't easily test for memory leaks without hooking operator new and delete. However, we can + // arrange for the test to crash on failure, by having the coroutine suspend at a promise that we + // later fulfill, thus arming the Coroutine's Event. If we fail to destroy the coroutine in this + // state, EventLoop will throw on destruction because it can still see the Event in its list. + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + auto paf = kj::newPromiseAndFulfiller(); + + auto coroPromise = deferredThrowCoroutine(kj::mv(paf.promise)); + + // Arm the Coroutine's Event. + paf.fulfiller->fulfill(); + + // If destroying `coroPromise` does not run ~Event(), then ~EventLoop() will crash later. + })); + + KJ_EXPECT(exception.getDescription() == "thrown during unwind"); +}; + +KJ_TEST("Exceptions during completed coroutine frame-unwind propagate via returned Promise") { + EventLoop loop; + WaitScope waitScope(loop); + + { + // First, prove that exceptions don't escape the destructor of a completed coroutine. + auto promise = deferredThrowCoroutine(kj::READY_NOW); + KJ_EXPECT(promise.poll(waitScope)); + } + + { + // Next, prove that they show up via the returned Promise. + auto promise = deferredThrowCoroutine(kj::READY_NOW); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("thrown during unwind", promise.wait(waitScope)); + } +} + +KJ_TEST("Coroutine destruction exceptions are ignored if there is another exception in flight") { + EventLoop loop; + WaitScope waitScope(loop); + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + auto promise = deferredThrowCoroutine(kj::NEVER_DONE); + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown before destroying throwy promise")); + })); + + KJ_EXPECT(exception.getDescription() == "thrown before destroying throwy promise"); +} + +KJ_TEST("co_await only sees coroutine destruction exceptions if promise was not rejected") { + EventLoop loop; + WaitScope waitScope(loop); + + // throwyDtorPromise is an immediate void promise that will throw when it's destroyed, which + // we expect to be able to catch from a coroutine which co_awaits it. + auto throwyDtorPromise = kj::Promise(kj::READY_NOW) + .attach(kj::defer([]() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during unwind")); + })); + + // rejectedThrowyDtorPromise is a rejected promise. When co_awaited in a coroutine, + // Awaiter::await_resume() will throw that exception for us to catch, but before we can catch it, + // the temporary promise will be destroyed. The exception it throws during unwind will be ignored, + // and the caller of the coroutine will see only the "thrown during execution" exception. + auto rejectedThrowyDtorPromise = kj::evalNow([&]() -> kj::Promise { + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during execution")); + }).attach(kj::defer([]() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, "thrown during unwind")); + })); + + auto awaitPromise = [](kj::Promise promise) -> kj::Promise { + co_await promise; + }; + + KJ_EXPECT_THROW_MESSAGE("thrown during unwind", + awaitPromise(kj::mv(throwyDtorPromise)).wait(waitScope)); + + KJ_EXPECT_THROW_MESSAGE("thrown during execution", + awaitPromise(kj::mv(rejectedThrowyDtorPromise)).wait(waitScope)); +} + +#if (!_MSC_VER || defined(__clang__)) && !__aarch64__ +uint countLines(StringPtr s) { + uint lines = 0; + for (char c: s) { + lines += c == '\n'; + } + return lines; +} + +// TODO(msvc): This test relies on GetFunctorStartAddress, which is not supported on MSVC currently, +// so skip the test. +// TODO(someday): Test is flakey on arm64, depending on how it's compiled. I haven't had a chance to +// investigate much, but noticed that it failed in a debug build, but passed in a local opt build. +KJ_TEST("Can trace through coroutines") { + // This verifies that async traces, generated either from promises or from events, can see through + // coroutines. + // + // This test may be a bit brittle because it depends on specific trace counts. + + // Enable stack traces, even in release mode. + class EnableFullStackTrace: public ExceptionCallback { + public: + StackTraceMode stackTraceMode() override { return StackTraceMode::FULL; } + }; + EnableFullStackTrace exceptionCallback; + + EventLoop loop; + WaitScope waitScope(loop); + + auto paf = newPromiseAndFulfiller(); + + // Get an async trace when the promise is fulfilled. We eagerlyEvaluate() to make sure the + // continuation executes while the event loop is running. + paf.promise = paf.promise.then([]() { + auto trace = getAsyncTrace(); + // We expect one entry for waitImpl(), one for the coroutine, and one for this continuation. + // When building in debug mode with CMake, I observed this count can be 2. The missing frame is + // probably this continuation. Let's just expect a range. + auto count = countLines(trace); + KJ_EXPECT(0 < count && count <= 3); + }).eagerlyEvaluate(nullptr); + + auto coroPromise = [&]() -> kj::Promise { + co_await paf.promise; + }(); + + { + auto trace = coroPromise.trace(); + // One for the Coroutine PromiseNode, one for paf.promise. + KJ_EXPECT(countLines(trace) >= 2); + } + + paf.fulfiller->fulfill(); + + coroPromise.wait(waitScope); +} +#endif // !_MSC_VER || defined(__clang__) + +Promise sendData(Promise> addressPromise) { + auto address = co_await addressPromise; + auto client = co_await address->connect(); + co_await client->write("foo", 3); +} + +Promise receiveDataCoroutine(Own listener) { + auto server = co_await listener->accept(); + char buffer[4]; + auto n = co_await server->read(buffer, 3, 4); + KJ_EXPECT(3u == n); + co_return heapString(buffer, n); +} + +KJ_TEST("Simple network test with coroutine") { + auto io = setupAsyncIo(); + auto& network = io.provider->getNetwork(); + + Own serverAddress = network.parseAddress("*", 0).wait(io.waitScope); + Own listener = serverAddress->listen(); + + sendData(network.parseAddress("localhost", listener->getPort())) + .detach([](Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + String result = receiveDataCoroutine(kj::mv(listener)).wait(io.waitScope); + + KJ_EXPECT("foo" == result); +} + +Promise> httpClientConnect(AsyncIoContext& io) { + auto addr = co_await io.provider->getNetwork().parseAddress("capnproto.org", 80); + co_return co_await addr->connect(); +} + +Promise httpClient(Own connection) { + // Borrowed and rewritten from compat/http-test.c++. + + HttpHeaderTable table; + auto client = newHttpClient(table, *connection); + + HttpHeaders headers(table); + headers.set(HttpHeaderId::HOST, "capnproto.org"); + + auto response = co_await client->request(HttpMethod::GET, "/", headers).response; + KJ_EXPECT(response.statusCode / 100 == 3); + auto location = KJ_ASSERT_NONNULL(response.headers->get(HttpHeaderId::LOCATION)); + KJ_EXPECT(location == "https://capnproto.org/"); + + auto body = co_await response.body->readAllText(); +} + +KJ_TEST("HttpClient to capnproto.org with a coroutine") { + auto io = setupAsyncIo(); + + auto promise = httpClientConnect(io).then([](Own connection) { + return httpClient(kj::mv(connection)); + }, [](Exception&&) { + KJ_LOG(WARNING, "skipping test because couldn't connect to capnproto.org"); + }); + + promise.wait(io.waitScope); +} + +// ======================================================================================= +// coCapture() tests + +KJ_TEST("Verify coCapture() functors can only be run once") { + auto io = kj::setupAsyncIo(); + + auto functor = coCapture([](kj::Timer& timer) -> kj::Promise { + co_await timer.afterDelay(1 * kj::MILLISECONDS); + }); + + auto promise = functor(io.lowLevelProvider->getTimer()); + KJ_EXPECT_THROW(FAILED, functor(io.lowLevelProvider->getTimer())); + + promise.wait(io.waitScope); +} + +auto makeDelayedIntegerFunctor(size_t i) { + return [i](kj::Timer& timer) -> kj::Promise { + co_await timer.afterDelay(1 * kj::MILLISECONDS); + co_return i; + }; +} + +KJ_TEST("Verify coCapture() with local scoped functors") { + auto io = kj::setupAsyncIo(); + + constexpr size_t COUNT = 100; + kj::Vector> promises; + for (size_t i = 0; i < COUNT; ++i) { + auto functor = coCapture(makeDelayedIntegerFunctor(i)); + promises.add(functor(io.lowLevelProvider->getTimer())); + } + + for (size_t i = COUNT; i > 0 ; --i) { + auto j = i-1; + auto result = promises[j].wait(io.waitScope); + KJ_REQUIRE(result == j); + } +} + +auto makeCheckThenDelayedIntegerFunctor(kj::Timer& timer, size_t i) { + return [&timer, i](size_t val) -> kj::Promise { + KJ_REQUIRE(val == i); + co_await timer.afterDelay(1 * kj::MILLISECONDS); + co_return i; + }; +} + +KJ_TEST("Verify coCapture() with continuation functors") { + // This test usually works locally without `coCapture()()`. It does however, fail in + // ASAN. + auto io = kj::setupAsyncIo(); + + constexpr size_t COUNT = 100; + kj::Vector> promises; + for (size_t i = 0; i < COUNT; ++i) { + auto promise = io.lowLevelProvider->getTimer().afterDelay(1 * kj::MILLISECONDS).then([i]() { + return i; + }); + promise = promise.then(coCapture( + makeCheckThenDelayedIntegerFunctor(io.lowLevelProvider->getTimer(), i))); + promises.add(kj::mv(promise)); + } + + for (size_t i = COUNT; i > 0 ; --i) { + auto j = i-1; + auto result = promises[j].wait(io.waitScope); + KJ_REQUIRE(result == j); + } +} + +#endif // KJ_HAS_COROUTINE + +} // namespace +} // namespace kj diff --git a/c++/src/kj/async-inl.h b/c++/src/kj/async-inl.h index f11e4fcd5b..c7a7c59c51 100644 --- a/c++/src/kj/async-inl.h +++ b/c++/src/kj/async-inl.h @@ -24,18 +24,21 @@ // // Non-inline declarations here are defined in async.c++. -#ifndef KJ_ASYNC_H_ +#pragma once + +#ifndef KJ_ASYNC_H_INCLUDED #error "Do not include this directly; include kj/async.h." #include "async.h" // help IDE parse this file #endif -#ifndef KJ_ASYNC_INL_H_ -#define KJ_ASYNC_INL_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header +#if _MSC_VER && KJ_HAS_COROUTINE +#include #endif +#include + +KJ_BEGIN_HEADER + namespace kj { namespace _ { // private @@ -79,14 +82,76 @@ class ExceptionOr: public ExceptionOrValue { Maybe value; }; -class Event { +template +inline T convertToReturn(ExceptionOr&& result) { + KJ_IF_MAYBE(value, result.value) { + KJ_IF_MAYBE(exception, result.exception) { + throwRecoverableException(kj::mv(*exception)); + } + return _::returnMaybeVoid(kj::mv(*value)); + } else KJ_IF_MAYBE(exception, result.exception) { + throwFatalException(kj::mv(*exception)); + } else { + // Result contained neither a value nor an exception? + KJ_UNREACHABLE; + } +} + +inline void convertToReturn(ExceptionOr&& result) { + // Override case to use throwRecoverableException(). + + if (result.value != nullptr) { + KJ_IF_MAYBE(exception, result.exception) { + throwRecoverableException(kj::mv(*exception)); + } + } else KJ_IF_MAYBE(exception, result.exception) { + throwRecoverableException(kj::mv(*exception)); + } else { + // Result contained neither a value nor an exception? + KJ_UNREACHABLE; + } +} + +class TraceBuilder { + // Helper for methods that build a call trace. +public: + TraceBuilder(ArrayPtr space) + : start(space.begin()), current(space.begin()), limit(space.end()) {} + + inline void add(void* addr) { + if (current < limit) { + *current++ = addr; + } + } + + inline bool full() const { return current == limit; } + + ArrayPtr finish() { + return arrayPtr(start, current); + } + + String toString(); + +private: + void** start; + void** current; + void** limit; +}; + +struct alignas(void*) PromiseArena { + // Space in which a chain of promises may be allocated. See PromiseDisposer. + byte bytes[1024]; +}; + +class Event: private AsyncObject { // An event waiting to be executed. Not for direct use by applications -- promises use this // internally. public: - Event(); + Event(SourceLocation location); + Event(kj::EventLoop& loop, SourceLocation location); ~Event() noexcept(false); - KJ_DISALLOW_COPY(Event); + KJ_DISALLOW_COPY_AND_MOVE(Event); void armDepthFirst(); // Enqueue this event so that `fire()` will be called from the event loop soon. @@ -106,12 +171,37 @@ class Event { void armBreadthFirst(); // Like `armDepthFirst()` except that the event is placed at the end of the queue. - kj::String trace(); - // Dump debug info about this event. + void armLast(); + // Enqueues this event to happen after all other events have run to completion and there is + // really nothing left to do except wait for I/O. + + bool isNext(); + // True if the Event has been armed and is next in line to be fired. This can be used after + // calling PromiseNode::onReady(event) to determine if a promise being waited is immediately + // ready, in which case continuations may be optimistically run without returning to the event + // loop. Note that this optimization is only valid if we know that we would otherwise immediately + // return to the event loop without running more application code. So this turns out to be useful + // in fairly narrow circumstances, chiefly when a coroutine is about to suspend, but discovers it + // doesn't need to. + // + // Returns false if the event loop is not currently running. This ensures that promise + // continuations don't execute except under a call to .wait(). + + void disarm(); + // If the event is armed but hasn't fired, cancel it. (Destroying the event does this + // implicitly.) + + virtual void traceEvent(TraceBuilder& builder) = 0; + // Build a trace of the callers leading up to this event. `builder` will be populated with + // "return addresses" of the promise chain waiting on this event. The return addresses may + // actually the addresses of lambdas passed to .then(), but in any case, feeding them into + // addr2line should produce useful source code locations. + // + // `traceEvent()` may be called from an async signal handler while `fire()` is executing. It + // must not allocate nor take locks. - virtual _::PromiseNode* getInnerForTrace(); - // If this event wraps a PromiseNode, get that node. Used for debug tracing. - // Default implementation returns nullptr. + String traceEvent(); + // Helper that builds a trace and stringifies it. protected: virtual Maybe> fire() = 0; @@ -125,9 +215,46 @@ class Event { Event* next; Event** prev; bool firing = false; + + static constexpr uint MAGIC_LIVE_VALUE = 0x1e366381u; + uint live = MAGIC_LIVE_VALUE; + SourceLocation location; +}; + +class PromiseArenaMember { + // An object that is allocated in a PromiseArena. `PromiseNode` inherits this, and most + // arena-allocated objects are `PromiseNode` subclasses, but `TaskSet::Task`, ForkHub, and + // potentially other objects that commonly live on the end of a promise chain can also leverage + // this. + +public: + virtual void destroy() = 0; + // Destroys and frees the node. + // + // If the node was allocated using allocPromise(), then destroy() must call + // freePromise(this). If it was allocated some other way, then it is `destroy()`'s + // responsibility to complete any necessary cleanup of memory, e.g. call `delete this`. + // + // We use this instead of a virtual destructor for two reasons: + // 1. Coroutine nodes are not independent objects, they have to call destroy() on the coroutine + // handle to delete themselves. + // 2. XThreadEvents sometimes leave it up to a different thread to actually delete the object. + +private: + PromiseArena* arena = nullptr; + // If non-null, then this PromiseNode is the last node allocated within the given arena, and + // therefore owns the arena. After this node is destroyed, the arena should be deleted. + // + // PromiseNodes are allocated within the arena starting from the end, and `PromiseNode`s + // allocated this way are required to have `PromiseNode` itself as their leftmost inherited type, + // so that the pointers match. Thus, the space in `arena` from its start to the location of the + // `PromiseNode` is known to be available for subsequent allocations (which should then take + // ownership of the arena). + + friend class PromiseDisposer; }; -class PromiseNode { +class PromiseNode: public PromiseArenaMember, private AsyncObject { // A Promise contains a chain of PromiseNodes tracking the pending transformations. // // To reduce generated code bloat, PromiseNode is not a template. Instead, it makes very hacky @@ -136,10 +263,14 @@ class PromiseNode { // internal implementation details. public: - virtual void onReady(Event& event) noexcept = 0; + virtual void onReady(Event* event) noexcept = 0; // Arms the given event when ready. + // + // May be called multiple times. If called again before the event was armed, the old event will + // never be armed, only the new one. If called again after the event was armed, the new event + // will be armed immediately. Can be called with nullptr to un-register the existing event. - virtual void setSelfPointer(Own* selfPtr) noexcept; + virtual void setSelfPointer(OwnPromiseNode* selfPtr) noexcept; // Tells the node that `selfPtr` is the pointer that owns this node, and will continue to own // this node until it is destroyed or setSelfPointer() is called again. ChainPromiseNode uses // this to shorten redundant chains. The default implementation does nothing; only @@ -150,27 +281,190 @@ class PromiseNode { // Can only be called once, and only after the node is ready. Must be called directly from the // event loop, with no application code on the stack. - virtual PromiseNode* getInnerForTrace(); - // If this node wraps some other PromiseNode, get the wrapped node. Used for debug tracing. - // Default implementation returns nullptr. + virtual void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) = 0; + // Build a trace of this promise chain, showing what it is currently waiting on. + // + // Since traces are ordered callee-before-caller, PromiseNode::tracePromise() should typically + // recurse to its child first, then after the child returns, add itself to the trace. + // + // If `stopAtNextEvent` is true, then the trace should stop as soon as it hits a PromiseNode that + // also implements Event, and should not trace that node or its children. This is used in + // conjunction with Event::traceEvent(). The chain of Events is often more sparse than the chain + // of PromiseNodes, because a TransformPromiseNode (which implements .then()) is not itself an + // Event. TransformPromiseNode instead tells its child node to directly notify its *parent* node + // when it is ready, and then TransformPromiseNode applies the .then() transformation during the + // call to .get(). + // + // So, when we trace the chain of Events backwards, we end up hoping over segments of + // TransformPromiseNodes (and other similar types). In order to get those added to the trace, + // each Event must call back down the PromiseNode chain in the opposite direction, using this + // method. + // + // `tracePromise()` may be called from an async signal handler while `get()` is executing. It + // must not allocate nor take locks. + + template + static OwnPromiseNode from(T&& promise) { + // Given a Promise, extract the PromiseNode. + return kj::mv(promise.node); + } + template + static PromiseNode& from(T& promise) { + // Given a Promise, extract the PromiseNode. + return *promise.node; + } + template + static T to(OwnPromiseNode&& node) { + // Construct a Promise from a PromiseNode. (T should be a Promise type.) + return T(false, kj::mv(node)); + } protected: class OnReadyEvent { // Helper class for implementing onReady(). public: - void init(Event& newEvent); - // Returns true if arm() was already called. + void init(Event* newEvent); void arm(); - // Arms the event if init() has already been called and makes future calls to init() return - // true. + void armBreadthFirst(); + // Arms the event if init() has already been called and makes future calls to init() + // automatically arm the event. + + inline void traceEvent(TraceBuilder& builder) { + if (event != nullptr && !builder.full()) event->traceEvent(builder); + } private: Event* event = nullptr; }; }; +class PromiseDisposer { +public: + template + static constexpr bool canArenaAllocate() { + // We can only use arena allocation for types that fit in an arena and have pointer-size + // alignment. Anything else will need to be allocated as a separate heap object. + return sizeof(T) <= sizeof(PromiseArena) && alignof(T) <= alignof(void*); + } + + static void dispose(PromiseArenaMember* node) { + PromiseArena* arena = node->arena; + node->destroy(); + delete arena; // reminder: `delete` automatically ignores null pointers + } + + template + static kj::Own alloc(Params&&... params) noexcept { + // Implements allocPromise(). + T* ptr; + if (!canArenaAllocate()) { + // Node too big (or needs weird alignment), fall back to regular heap allocation. + ptr = new T(kj::fwd(params)...); + } else { + // Start a new arena. + // + // NOTE: As in append() (below), we don't implement exception-safety because it causes code + // bloat and these constructors probably don't throw. Instead this function is noexcept, so + // if a constructor does throw, it'll crash rather than leak memory. + auto* arena = new PromiseArena; + ptr = reinterpret_cast(arena + 1) - 1; + ctor(*ptr, kj::fwd(params)...); + ptr->arena = arena; + KJ_IREQUIRE(reinterpret_cast(ptr) == + reinterpret_cast(static_cast(ptr)), + "PromiseArenaMember must be the leftmost inherited type."); + } + return kj::Own(ptr); + } + + template + static kj::Own append( + OwnPromiseNode&& next, Params&&... params) noexcept { + // Implements appendPromise(). + + PromiseArena* arena = next->arena; + + if (!canArenaAllocate() || arena == nullptr || + reinterpret_cast(next.get()) - reinterpret_cast(arena) < sizeof(T)) { + // No arena available, or not enough space, or weird alignment needed. Start new arena. + return alloc(kj::mv(next), kj::fwd(params)...); + } else { + // Append to arena. + // + // NOTE: When we call ctor(), it takes ownership of `next`, so we shouldn't assume `next` + // still exists after it returns. So we have to remove ownership of the arena before that. + // In theory if we wanted this to be exception-safe, we'd also have to arrange to delete + // the arena if the constructor throws. However, in practice none of the PromiseNode + // constructors throw, so we just mark the whole method noexcept in order to avoid the + // code bloat to handle this case. + next->arena = nullptr; + T* ptr = reinterpret_cast(next.get()) - 1; + ctor(*ptr, kj::mv(next), kj::fwd(params)...); + ptr->arena = arena; + KJ_IREQUIRE(reinterpret_cast(ptr) == + reinterpret_cast(static_cast(ptr)), + "PromiseArenaMember must be the leftmost inherited type."); + return kj::Own(ptr); + } + } +}; + +template +static kj::Own allocPromise(Params&&... params) { + // Allocate a PromiseNode without appending it to any existing promise arena. Space for a new + // arena will be allocated. + return PromiseDisposer::alloc(kj::fwd(params)...); +} + +template ()> +struct FreePromiseNode; +template +struct FreePromiseNode { + static inline void free(T* ptr) { + // The object will have been allocated in an arena, so we only want to run the destructor. + // The arena's memory will be freed separately. + kj::dtor(*ptr); + } +}; +template +struct FreePromiseNode { + static inline void free(T* ptr) { + // The object will have been allocated separately on the heap. + return delete ptr; + } +}; + +template +static void freePromise(T* ptr) { + // Free a PromiseNode originally allocated using `allocPromise()`. The implementation of + // PromiseNode::destroy() must call this for any type that is allocated using allocPromise(). + FreePromiseNode::free(ptr); +} + +template +static kj::Own appendPromise(OwnPromiseNode&& next, Params&&... params) { + // Append a promise to the arena that currently ends with `next`. `next` is also still passed as + // the first parameter to the new object's constructor. + // + // This is semantically the same as `allocPromise()` except that it may avoid the underlying + // memory allocation. `next` must end up being destroyed before the new object (i.e. the new + // object must never transfer away ownership of `next`). + return PromiseDisposer::append(kj::mv(next), kj::fwd(params)...); +} + +// ------------------------------------------------------------------- + +inline ReadyNow::operator Promise() const { + return PromiseNode::to>(readyNow()); +} + +template +inline NeverDone::operator Promise() const { + return PromiseNode::to>(neverDone()); +} + // ------------------------------------------------------------------- class ImmediatePromiseNodeBase: public PromiseNode { @@ -178,7 +472,8 @@ class ImmediatePromiseNodeBase: public PromiseNode { ImmediatePromiseNodeBase(); ~ImmediatePromiseNodeBase() noexcept(false); - void onReady(Event& event) noexcept override; + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; }; template @@ -187,6 +482,7 @@ class ImmediatePromiseNode final: public ImmediatePromiseNodeBase { public: ImmediatePromiseNode(ExceptionOr&& result): result(kj::mv(result)) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { output.as() = kj::mv(result); @@ -199,6 +495,7 @@ class ImmediatePromiseNode final: public ImmediatePromiseNodeBase { class ImmediateBrokenPromiseNode final: public ImmediatePromiseNodeBase { public: ImmediateBrokenPromiseNode(Exception&& exception); + void destroy() override; void get(ExceptionOrValue& output) noexcept override; @@ -206,18 +503,27 @@ class ImmediateBrokenPromiseNode final: public ImmediatePromiseNodeBase { Exception exception; }; +template +class ConstPromiseNode: public ImmediatePromiseNodeBase { +public: + void destroy() override {} + void get(ExceptionOrValue& output) noexcept override { + output.as() = value; + } +}; + // ------------------------------------------------------------------- class AttachmentPromiseNodeBase: public PromiseNode { public: - AttachmentPromiseNodeBase(Own&& dependency); + AttachmentPromiseNodeBase(OwnPromiseNode&& dependency); - void onReady(Event& event) noexcept override; + void onReady(Event* event) noexcept override; void get(ExceptionOrValue& output) noexcept override; - PromiseNode* getInnerForTrace() override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: - Own dependency; + OwnPromiseNode dependency; void dropDependency(); @@ -231,9 +537,10 @@ class AttachmentPromiseNode final: public AttachmentPromiseNodeBase { // object) until the promise resolves. public: - AttachmentPromiseNode(Own&& dependency, Attachment&& attachment) + AttachmentPromiseNode(OwnPromiseNode&& dependency, Attachment&& attachment) : AttachmentPromiseNodeBase(kj::mv(dependency)), attachment(kj::mv(attachment)) {} + void destroy() override { freePromise(this); } ~AttachmentPromiseNode() noexcept(false) { // We need to make sure the dependency is deleted before we delete the attachment because the @@ -247,12 +554,42 @@ class AttachmentPromiseNode final: public AttachmentPromiseNodeBase { // ------------------------------------------------------------------- +#if __GNUC__ >= 8 && !__clang__ +// GCC 8's class-memaccess warning rightly does not like the memcpy()'s below, but there's no +// "legal" way for us to extract the content of a PTMF so too bad. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#if __GNUC__ >= 11 +// GCC 11's array-bounds is similarly upset with us for digging into "private" implementation +// details. But the format is well-defined by the ABI which cannot change so please just let us +// do it kthx. +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif +#endif + +template +void* getMethodStartAddress(T& obj, ReturnType (T::*method)(ParamTypes...)); +template +void* getMethodStartAddress(const T& obj, ReturnType (T::*method)(ParamTypes...) const); +// Given an object and a pointer-to-method, return the start address of the method's code. The +// intent is that this address can be used in a trace; addr2line should map it to the start of +// the function's definition. For virtual methods, this does a vtable lookup on `obj` to determine +// the address of the specific implementation (otherwise, `obj` wouldn't be needed). +// +// Note that if the method is overloaded or is a template, you will need to explicitly specify +// the param and return types, otherwise the compiler won't know which overload / template +// specialization you are requesting. + class PtmfHelper { - // This class is a private helper for GetFunctorStartAddress. The class represents the internal - // representation of a pointer-to-member-function. + // This class is a private helper for GetFunctorStartAddress and getMethodStartAddress(). The + // class represents the internal representation of a pointer-to-member-function. template friend struct GetFunctorStartAddress; + template + friend void* getMethodStartAddress(T& obj, ReturnType (T::*method)(ParamTypes...)); + template + friend void* getMethodStartAddress(const T& obj, ReturnType (T::*method)(ParamTypes...) const); #if __GNUG__ @@ -260,7 +597,7 @@ class PtmfHelper { ptrdiff_t adj; // Layout of a pointer-to-member-function used by GCC and compatible compilers. - void* apply(void* obj) { + void* apply(const void* obj) { #if defined(__arm__) || defined(__mips__) || defined(__aarch64__) if (adj & 1) { ptrdiff_t voff = (ptrdiff_t)ptr; @@ -283,7 +620,7 @@ class PtmfHelper { #else // __GNUG__ - void* apply(void* obj) { return nullptr; } + void* apply(const void* obj) { return nullptr; } // TODO(port): PTMF instruction address extraction #define BODY return PtmfHelper{} @@ -308,6 +645,19 @@ class PtmfHelper { #undef BODY }; +#if __GNUC__ >= 8 && !__clang__ +#pragma GCC diagnostic pop +#endif + +template +void* getMethodStartAddress(T& obj, ReturnType (T::*method)(ParamTypes...)) { + return PtmfHelper::from(method).apply(&obj); +} +template +void* getMethodStartAddress(const T& obj, ReturnType (T::*method)(ParamTypes...) const) { + return PtmfHelper::from(method).apply(&obj); +} + template struct GetFunctorStartAddress { // Given a functor (any object defining operator()), return the start address of the function, @@ -336,14 +686,14 @@ struct GetFunctorStartAddress: public GetFunctorStartAddress<> {}; class TransformPromiseNodeBase: public PromiseNode { public: - TransformPromiseNodeBase(Own&& dependency, void* continuationTracePtr); + TransformPromiseNodeBase(OwnPromiseNode&& dependency, void* continuationTracePtr); - void onReady(Event& event) noexcept override; + void onReady(Event* event) noexcept override; void get(ExceptionOrValue& output) noexcept override; - PromiseNode* getInnerForTrace() override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: - Own dependency; + OwnPromiseNode dependency; void* continuationTracePtr; void dropDependency(); @@ -361,10 +711,11 @@ class TransformPromiseNode final: public TransformPromiseNodeBase { // function (implements `then()`). public: - TransformPromiseNode(Own&& dependency, Func&& func, ErrorFunc&& errorHandler) - : TransformPromiseNodeBase(kj::mv(dependency), - GetFunctorStartAddress::apply(func)), + TransformPromiseNode(OwnPromiseNode&& dependency, Func&& func, ErrorFunc&& errorHandler, + void* continuationTracePtr) + : TransformPromiseNodeBase(kj::mv(dependency), continuationTracePtr), func(kj::fwd(func)), errorHandler(kj::fwd(errorHandler)) {} + void destroy() override { freePromise(this); } ~TransformPromiseNode() noexcept(false) { // We need to make sure the dependency is deleted before we delete the continuations because it @@ -400,18 +751,19 @@ class TransformPromiseNode final: public TransformPromiseNodeBase { // ------------------------------------------------------------------- class ForkHubBase; +using OwnForkHubBase = Own; class ForkBranchBase: public PromiseNode { public: - ForkBranchBase(Own&& hub); + ForkBranchBase(OwnForkHubBase&& hub); ~ForkBranchBase() noexcept(false); void hubReady() noexcept; // Called by the hub to indicate that it is ready. // implements PromiseNode ------------------------------------------ - void onReady(Event& event) noexcept override; - PromiseNode* getInnerForTrace() override; + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; protected: inline ExceptionOrValue& getHubResultRef(); @@ -422,7 +774,7 @@ class ForkBranchBase: public PromiseNode { private: OnReadyEvent onReadyEvent; - Own hub; + OwnForkHubBase hub; ForkBranchBase* next = nullptr; ForkBranchBase** prevPtr = nullptr; @@ -431,6 +783,11 @@ class ForkBranchBase: public PromiseNode { template T copyOrAddRef(T& t) { return t; } template Own copyOrAddRef(Own& t) { return t->addRef(); } +template Maybe> copyOrAddRef(Maybe>& t) { + return t.map([](Own& ptr) { + return ptr->addRef(); + }); +} template class ForkBranch final: public ForkBranchBase { @@ -438,7 +795,8 @@ class ForkBranch final: public ForkBranchBase { // a const reference. public: - ForkBranch(Own&& hub): ForkBranchBase(kj::mv(hub)) {} + ForkBranch(OwnForkHubBase&& hub): ForkBranchBase(kj::mv(hub)) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { ExceptionOr& hubResult = getHubResultRef().template as(); @@ -458,7 +816,8 @@ class SplitBranch final: public ForkBranchBase { // a const reference. public: - SplitBranch(Own&& hub): ForkBranchBase(kj::mv(hub)) {} + SplitBranch(OwnForkHubBase&& hub): ForkBranchBase(kj::mv(hub)) {} + void destroy() override { freePromise(this); } typedef kj::Decay(kj::instance()))> Element; @@ -476,14 +835,31 @@ class SplitBranch final: public ForkBranchBase { // ------------------------------------------------------------------- -class ForkHubBase: public Refcounted, protected Event { +class ForkHubBase: public PromiseArenaMember, protected Event { public: - ForkHubBase(Own&& inner, ExceptionOrValue& resultRef); + ForkHubBase(OwnPromiseNode&& inner, ExceptionOrValue& resultRef, SourceLocation location); inline ExceptionOrValue& getResultRef() { return resultRef; } + inline bool isShared() const { return refcount > 1; } + + Own addRef() { + ++refcount; + return Own(this); + } + + static void dispose(ForkHubBase* obj) { + if (--obj->refcount == 0) { + PromiseDisposer::dispose(obj); + } + } + private: - Own inner; + uint refcount = 1; + // We manually implement refcounting for ForkHubBase so that we can use it together with + // PromiseDisposer's arena allocation. + + OwnPromiseNode inner; ExceptionOrValue& resultRef; ForkBranchBase* headBranch = nullptr; @@ -491,7 +867,7 @@ class ForkHubBase: public Refcounted, protected Event { // Tail becomes null once the inner promise is ready and all branches have been notified. Maybe> fire() override; - _::PromiseNode* getInnerForTrace() override; + void traceEvent(TraceBuilder& builder) override; friend class ForkBranchBase; }; @@ -503,29 +879,33 @@ class ForkHub final: public ForkHubBase { // possible). public: - ForkHub(Own&& inner): ForkHubBase(kj::mv(inner), result) {} + ForkHub(OwnPromiseNode&& inner, SourceLocation location) + : ForkHubBase(kj::mv(inner), result, location) {} + void destroy() override { freePromise(this); } Promise<_::UnfixVoid> addBranch() { - return Promise<_::UnfixVoid>(false, kj::heap>(addRef(*this))); + return _::PromiseNode::to>>( + allocPromise>(addRef())); } - _::SplitTuplePromise split() { - return splitImpl(MakeIndexes()>()); + _::SplitTuplePromise split(SourceLocation location) { + return splitImpl(MakeIndexes()>(), location); } private: ExceptionOr result; template - _::SplitTuplePromise splitImpl(Indexes) { - return kj::tuple(addSplit()...); + _::SplitTuplePromise splitImpl(Indexes, SourceLocation location) { + return kj::tuple(addSplit(location)...); } template - Promise::Element>> addSplit() { - return Promise::Element>>( - false, maybeChain(kj::heap>(addRef(*this)), - implicitCast::Element*>(nullptr))); + ReducePromises::Element> addSplit(SourceLocation location) { + return _::PromiseNode::to::Element>>( + maybeChain(allocPromise>(addRef()), + implicitCast::Element*>(nullptr), + location)); } }; @@ -542,13 +922,14 @@ class ChainPromiseNode final: public PromiseNode, public Event { // Own. Ugh, templates and private... public: - explicit ChainPromiseNode(Own inner); + explicit ChainPromiseNode(OwnPromiseNode inner, SourceLocation location); ~ChainPromiseNode() noexcept(false); + void destroy() override; - void onReady(Event& event) noexcept override; - void setSelfPointer(Own* selfPtr) noexcept override; + void onReady(Event* event) noexcept override; + void setSelfPointer(OwnPromiseNode* selfPtr) noexcept override; void get(ExceptionOrValue& output) noexcept override; - PromiseNode* getInnerForTrace() override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: enum State { @@ -558,52 +939,67 @@ class ChainPromiseNode final: public PromiseNode, public Event { State state; - Own inner; + OwnPromiseNode inner; // In STEP1, a PromiseNode for a Promise. // In STEP2, a PromiseNode for a T. Event* onReadyEvent = nullptr; - Own* selfPtr = nullptr; + OwnPromiseNode* selfPtr = nullptr; Maybe> fire() override; + void traceEvent(TraceBuilder& builder) override; }; template -Own maybeChain(Own&& node, Promise*) { - return heap(kj::mv(node)); +OwnPromiseNode maybeChain(OwnPromiseNode&& node, Promise*, SourceLocation location) { + return appendPromise(kj::mv(node), location); } template -Own&& maybeChain(Own&& node, T*) { +OwnPromiseNode&& maybeChain(OwnPromiseNode&& node, T*, SourceLocation location) { return kj::mv(node); } +template >()))> +inline Result maybeReduce(Promise&& promise, bool) { + return T::reducePromise(kj::mv(promise)); +} + +template +inline Promise maybeReduce(Promise&& promise, ...) { + return kj::mv(promise); +} + // ------------------------------------------------------------------- class ExclusiveJoinPromiseNode final: public PromiseNode { public: - ExclusiveJoinPromiseNode(Own left, Own right); + ExclusiveJoinPromiseNode(OwnPromiseNode left, OwnPromiseNode right, SourceLocation location); ~ExclusiveJoinPromiseNode() noexcept(false); + void destroy() override; - void onReady(Event& event) noexcept override; + void onReady(Event* event) noexcept override; void get(ExceptionOrValue& output) noexcept override; - PromiseNode* getInnerForTrace() override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: class Branch: public Event { public: - Branch(ExclusiveJoinPromiseNode& joinNode, Own dependency); + Branch(ExclusiveJoinPromiseNode& joinNode, OwnPromiseNode dependency, + SourceLocation location); ~Branch() noexcept(false); bool get(ExceptionOrValue& output); // Returns true if this is the side that finished. Maybe> fire() override; - _::PromiseNode* getInnerForTrace() override; + void traceEvent(TraceBuilder& builder) override; private: ExclusiveJoinPromiseNode& joinNode; - Own dependency; + OwnPromiseNode dependency; + + friend class ExclusiveJoinPromiseNode; }; Branch left; @@ -613,40 +1009,49 @@ class ExclusiveJoinPromiseNode final: public PromiseNode { // ------------------------------------------------------------------- +enum class ArrayJoinBehavior { + LAZY, + EAGER, +}; + class ArrayJoinPromiseNodeBase: public PromiseNode { public: - ArrayJoinPromiseNodeBase(Array> promises, - ExceptionOrValue* resultParts, size_t partSize); + ArrayJoinPromiseNodeBase(Array promises, + ExceptionOrValue* resultParts, size_t partSize, + SourceLocation location, + ArrayJoinBehavior joinBehavior); ~ArrayJoinPromiseNodeBase() noexcept(false); - void onReady(Event& event) noexcept override final; + void onReady(Event* event) noexcept override final; void get(ExceptionOrValue& output) noexcept override final; - PromiseNode* getInnerForTrace() override final; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override final; protected: virtual void getNoError(ExceptionOrValue& output) noexcept = 0; // Called to compile the result only in the case where there were no errors. private: + const ArrayJoinBehavior joinBehavior; + uint countLeft; OnReadyEvent onReadyEvent; + bool armed = false; class Branch final: public Event { public: - Branch(ArrayJoinPromiseNodeBase& joinNode, Own dependency, - ExceptionOrValue& output); + Branch(ArrayJoinPromiseNodeBase& joinNode, OwnPromiseNode dependency, + ExceptionOrValue& output, SourceLocation location); ~Branch() noexcept(false); Maybe> fire() override; - _::PromiseNode* getInnerForTrace() override; - - Maybe getPart(); - // Calls dependency->get(output). If there was an exception, return it. + void traceEvent(TraceBuilder& builder) override; private: ArrayJoinPromiseNodeBase& joinNode; - Own dependency; + OwnPromiseNode dependency; ExceptionOrValue& output; + + friend class ArrayJoinPromiseNodeBase; }; Array branches; @@ -655,10 +1060,14 @@ class ArrayJoinPromiseNodeBase: public PromiseNode { template class ArrayJoinPromiseNode final: public ArrayJoinPromiseNodeBase { public: - ArrayJoinPromiseNode(Array> promises, - Array> resultParts) - : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr)), + ArrayJoinPromiseNode(Array promises, + Array> resultParts, + SourceLocation location, + ArrayJoinBehavior joinBehavior) + : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr), + location, joinBehavior), resultParts(kj::mv(resultParts)) {} + void destroy() override { freePromise(this); } protected: void getNoError(ExceptionOrValue& output) noexcept override { @@ -678,9 +1087,12 @@ class ArrayJoinPromiseNode final: public ArrayJoinPromiseNodeBase { template <> class ArrayJoinPromiseNode final: public ArrayJoinPromiseNodeBase { public: - ArrayJoinPromiseNode(Array> promises, - Array> resultParts); + ArrayJoinPromiseNode(Array promises, + Array> resultParts, + SourceLocation location, + ArrayJoinBehavior joinBehavior); ~ArrayJoinPromiseNode(); + void destroy() override; protected: void getNoError(ExceptionOrValue& output) noexcept override; @@ -696,25 +1108,28 @@ class EagerPromiseNodeBase: public PromiseNode, protected Event { // evaluate it. public: - EagerPromiseNodeBase(Own&& dependency, ExceptionOrValue& resultRef); + EagerPromiseNodeBase(OwnPromiseNode&& dependency, ExceptionOrValue& resultRef, + SourceLocation location); - void onReady(Event& event) noexcept override; - PromiseNode* getInnerForTrace() override; + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; private: - Own dependency; + OwnPromiseNode dependency; OnReadyEvent onReadyEvent; ExceptionOrValue& resultRef; Maybe> fire() override; + void traceEvent(TraceBuilder& builder) override; }; template class EagerPromiseNode final: public EagerPromiseNodeBase { public: - EagerPromiseNode(Own&& dependency) - : EagerPromiseNodeBase(kj::mv(dependency), result) {} + EagerPromiseNode(OwnPromiseNode&& dependency, SourceLocation location) + : EagerPromiseNodeBase(kj::mv(dependency), result, location) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { output.as() = kj::mv(result); @@ -725,17 +1140,18 @@ class EagerPromiseNode final: public EagerPromiseNodeBase { }; template -Own spark(Own&& node) { +OwnPromiseNode spark(OwnPromiseNode&& node, SourceLocation location) { // Forces evaluation of the given node to begin as soon as possible, even if no one is waiting // on it. - return heap>(kj::mv(node)); + return appendPromise>(kj::mv(node), location); } // ------------------------------------------------------------------- class AdapterPromiseNodeBase: public PromiseNode { public: - void onReady(Event& event) noexcept override; + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; protected: inline void setReady() { @@ -755,6 +1171,7 @@ class AdapterPromiseNode final: public AdapterPromiseNodeBase, template AdapterPromiseNode(Params&&... params) : adapter(static_cast>&>(*this), kj::fwd(params)...) {} + void destroy() override { freePromise(this); } void get(ExceptionOrValue& output) noexcept override { KJ_IREQUIRE(!isWaiting()); @@ -787,28 +1204,102 @@ class AdapterPromiseNode final: public AdapterPromiseNodeBase, } }; +// ------------------------------------------------------------------- + +class FiberBase: public PromiseNode, private Event { + // Base class for the outer PromiseNode representing a fiber. + +public: + explicit FiberBase(size_t stackSize, _::ExceptionOrValue& result, SourceLocation location); + explicit FiberBase(const FiberPool& pool, _::ExceptionOrValue& result, SourceLocation location); + ~FiberBase() noexcept(false); + + void start() { armDepthFirst(); } + // Call immediately after construction to begin executing the fiber. + + class WaitDoneEvent; + + void onReady(_::Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; + +protected: + bool isFinished() { return state == FINISHED; } + void cancel(); + +private: + enum { WAITING, RUNNING, CANCELED, FINISHED } state; + + _::PromiseNode* currentInner = nullptr; + OnReadyEvent onReadyEvent; + Own stack; + _::ExceptionOrValue& result; + + void run(); + virtual void runImpl(WaitScope& waitScope) = 0; + + Maybe> fire() override; + void traceEvent(TraceBuilder& builder) override; + // Implements Event. Each time the event is fired, switchToFiber() is called. + + friend class FiberStack; + friend void _::waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, + WaitScope& waitScope, SourceLocation location); + friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); +}; + +template +class Fiber final: public FiberBase { +public: + explicit Fiber(size_t stackSize, Func&& func, SourceLocation location) + : FiberBase(stackSize, result, location), func(kj::fwd(func)) {} + explicit Fiber(const FiberPool& pool, Func&& func, SourceLocation location) + : FiberBase(pool, result, location), func(kj::fwd(func)) {} + ~Fiber() noexcept(false) { cancel(); } + void destroy() override { freePromise(this); } + + typedef FixVoid()(kj::instance()))> ResultType; + + void get(ExceptionOrValue& output) noexcept override { + KJ_IREQUIRE(isFinished()); + output.as() = kj::mv(result); + } + +private: + Func func; + ExceptionOr result; + + void runImpl(WaitScope& waitScope) override { + result.template as() = + MaybeVoidCaller::apply(func, waitScope); + } +}; + } // namespace _ (private) // ======================================================================================= template Promise::Promise(_::FixVoid value) - : PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid>>(kj::mv(value))) {} + : PromiseBase(_::allocPromise<_::ImmediatePromiseNode<_::FixVoid>>(kj::mv(value))) {} template Promise::Promise(kj::Exception&& exception) - : PromiseBase(heap<_::ImmediateBrokenPromiseNode>(kj::mv(exception))) {} + : PromiseBase(_::allocPromise<_::ImmediateBrokenPromiseNode>(kj::mv(exception))) {} template template -PromiseForResult Promise::then(Func&& func, ErrorFunc&& errorHandler) { +PromiseForResult Promise::then(Func&& func, ErrorFunc&& errorHandler, + SourceLocation location) { typedef _::FixVoid<_::ReturnType> ResultT; - Own<_::PromiseNode> intermediate = - heap<_::TransformPromiseNode, Func, ErrorFunc>>( - kj::mv(node), kj::fwd(func), kj::fwd(errorHandler)); - return PromiseForResult(false, - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + void* continuationTracePtr = _::GetFunctorStartAddress<_::FixVoid&&>::apply(func); + _::OwnPromiseNode intermediate = + _::appendPromise<_::TransformPromiseNode, Func, ErrorFunc>>( + kj::mv(node), kj::fwd(func), kj::fwd(errorHandler), + continuationTracePtr); + auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); + return _::maybeReduce(kj::mv(result), false); } namespace _ { // private @@ -841,57 +1332,41 @@ struct IdentityFunc> { template template -Promise Promise::catch_(ErrorFunc&& errorHandler) { +Promise Promise::catch_(ErrorFunc&& errorHandler, SourceLocation location) { // then()'s ErrorFunc can only return a Promise if Func also returns a Promise. In this case, // Func is being filled in automatically. We want to make sure ErrorFunc can return a Promise, // but we don't want the extra overhead of promise chaining if ErrorFunc doesn't actually // return a promise. So we make our Func return match ErrorFunc. - return then(_::IdentityFunc()))>(), - kj::fwd(errorHandler)); + typedef _::IdentityFunc()))> Func; + typedef _::FixVoid<_::ReturnType> ResultT; + + // The reason catch_() isn't simply implemented in terms of then() is because we want the trace + // pointer to be based on ErrorFunc rather than Func. + void* continuationTracePtr = _::GetFunctorStartAddress::apply(errorHandler); + _::OwnPromiseNode intermediate = + _::appendPromise<_::TransformPromiseNode, Func, ErrorFunc>>( + kj::mv(node), Func(), kj::fwd(errorHandler), continuationTracePtr); + auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); + return _::maybeReduce(kj::mv(result), false); } template -T Promise::wait(WaitScope& waitScope) { +T Promise::wait(WaitScope& waitScope, SourceLocation location) { _::ExceptionOr<_::FixVoid> result; - - waitImpl(kj::mv(node), result, waitScope); - - KJ_IF_MAYBE(value, result.value) { - KJ_IF_MAYBE(exception, result.exception) { - throwRecoverableException(kj::mv(*exception)); - } - return _::returnMaybeVoid(kj::mv(*value)); - } else KJ_IF_MAYBE(exception, result.exception) { - throwFatalException(kj::mv(*exception)); - } else { - // Result contained neither a value nor an exception? - KJ_UNREACHABLE; - } + _::waitImpl(kj::mv(node), result, waitScope, location); + return convertToReturn(kj::mv(result)); } -template <> -inline void Promise::wait(WaitScope& waitScope) { - // Override case to use throwRecoverableException(). - - _::ExceptionOr<_::Void> result; - - waitImpl(kj::mv(node), result, waitScope); - - if (result.value != nullptr) { - KJ_IF_MAYBE(exception, result.exception) { - throwRecoverableException(kj::mv(*exception)); - } - } else KJ_IF_MAYBE(exception, result.exception) { - throwRecoverableException(kj::mv(*exception)); - } else { - // Result contained neither a value nor an exception? - KJ_UNREACHABLE; - } +template +bool Promise::poll(WaitScope& waitScope, SourceLocation location) { + return _::pollImpl(*node, waitScope, location); } template -ForkedPromise Promise::fork() { - return ForkedPromise(false, refcounted<_::ForkHub<_::FixVoid>>(kj::mv(node))); +ForkedPromise Promise::fork(SourceLocation location) { + return ForkedPromise(false, + _::PromiseDisposer::alloc<_::ForkHub<_::FixVoid>, _::ForkHubBase>(kj::mv(node), location)); } template @@ -900,34 +1375,41 @@ Promise ForkedPromise::addBranch() { } template -_::SplitTuplePromise Promise::split() { - return refcounted<_::ForkHub<_::FixVoid>>(kj::mv(node))->split(); +bool ForkedPromise::hasBranches() { + return hub->isShared(); +} + +template +_::SplitTuplePromise Promise::split(SourceLocation location) { + return _::PromiseDisposer::alloc<_::ForkHub<_::FixVoid>, _::ForkHubBase>( + kj::mv(node), location)->split(location); } template -Promise Promise::exclusiveJoin(Promise&& other) { - return Promise(false, heap<_::ExclusiveJoinPromiseNode>(kj::mv(node), kj::mv(other.node))); +Promise Promise::exclusiveJoin(Promise&& other, SourceLocation location) { + return Promise(false, _::appendPromise<_::ExclusiveJoinPromiseNode>( + kj::mv(node), kj::mv(other.node), location)); } template template Promise Promise::attach(Attachments&&... attachments) { - return Promise(false, kj::heap<_::AttachmentPromiseNode>>( + return Promise(false, _::appendPromise<_::AttachmentPromiseNode>>( kj::mv(node), kj::tuple(kj::fwd(attachments)...))); } template template -Promise Promise::eagerlyEvaluate(ErrorFunc&& errorHandler) { +Promise Promise::eagerlyEvaluate(ErrorFunc&& errorHandler, SourceLocation location) { // See catch_() for commentary. return Promise(false, _::spark<_::FixVoid>(then( _::IdentityFunc()))>(), - kj::fwd(errorHandler)).node)); + kj::fwd(errorHandler)).node, location)); } template -Promise Promise::eagerlyEvaluate(decltype(nullptr)) { - return Promise(false, _::spark<_::FixVoid>(kj::mv(node))); +Promise Promise::eagerlyEvaluate(decltype(nullptr), SourceLocation location) { + return Promise(false, _::spark<_::FixVoid>(kj::mv(node), location)); } template @@ -935,11 +1417,22 @@ kj::String Promise::trace() { return PromiseBase::trace(); } +template +inline Promise constPromise() { + static _::ConstPromiseNode NODE; + return _::PromiseNode::to>(_::OwnPromiseNode(&NODE)); +} + template inline PromiseForResult evalLater(Func&& func) { return _::yield().then(kj::fwd(func), _::PropagateException()); } +template +inline PromiseForResult evalLast(Func&& func) { + return _::yieldHarder().then(kj::fwd(func), _::PropagateException()); +} + template inline PromiseForResult evalNow(Func&& func) { PromiseForResult result = nullptr; @@ -951,6 +1444,68 @@ inline PromiseForResult evalNow(Func&& func) { return result; } +template +struct RetryOnDisconnect_ { + static inline PromiseForResult apply(Func&& func) { + return evalLater([func = kj::mv(func)]() mutable -> PromiseForResult { + auto promise = evalNow(func); + return promise.catch_([func = kj::mv(func)](kj::Exception&& e) mutable -> PromiseForResult { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return func(); + } else { + return kj::mv(e); + } + }); + }); + } +}; +template +struct RetryOnDisconnect_ { + // Specialization for references. Needed because the syntax for capturing references in a + // lambda is different. :( + static inline PromiseForResult apply(Func& func) { + auto promise = evalLater(func); + return promise.catch_([&func](kj::Exception&& e) -> PromiseForResult { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return func(); + } else { + return kj::mv(e); + } + }); + } +}; + +template +inline PromiseForResult retryOnDisconnect(Func&& func) { + return RetryOnDisconnect_::apply(kj::fwd(func)); +} + +template +inline PromiseForResult startFiber( + size_t stackSize, Func&& func, SourceLocation location) { + typedef _::FixVoid<_::ReturnType> ResultT; + + auto intermediate = _::allocPromise<_::Fiber>( + stackSize, kj::fwd(func), location); + intermediate->start(); + auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); + return _::maybeReduce(kj::mv(result), false); +} + +template +inline PromiseForResult FiberPool::startFiber( + Func&& func, SourceLocation location) const { + typedef _::FixVoid<_::ReturnType> ResultT; + + auto intermediate = _::allocPromise<_::Fiber>( + *this, kj::fwd(func), location); + intermediate->start(); + auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType>>( + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); + return _::maybeReduce(kj::mv(result), false); +} + template template void Promise::detach(ErrorFunc&& errorHandler) { @@ -964,18 +1519,47 @@ void Promise::detach(ErrorFunc&& errorHandler) { } template -Promise> joinPromises(Array>&& promises) { - return Promise>(false, kj::heap<_::ArrayJoinPromiseNode>( - KJ_MAP(p, promises) { return kj::mv(p.node); }, - heapArray<_::ExceptionOr>(promises.size()))); +Promise> joinPromises(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>>(_::allocPromise<_::ArrayJoinPromiseNode>( + KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, + heapArray<_::ExceptionOr>(promises.size()), location, + _::ArrayJoinBehavior::LAZY)); +} + +template +Promise> joinPromisesFailFast(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>>(_::allocPromise<_::ArrayJoinPromiseNode>( + KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, + heapArray<_::ExceptionOr>(promises.size()), location, + _::ArrayJoinBehavior::EAGER)); } // ======================================================================================= namespace _ { // private +class WeakFulfillerBase: protected kj::Disposer { +protected: + WeakFulfillerBase(): inner(nullptr) {} + virtual ~WeakFulfillerBase() noexcept(false) {} + + template + inline PromiseFulfiller* getInner() { + return static_cast*>(inner); + }; + template + inline void setInner(PromiseFulfiller* ptr) { + inner = ptr; + }; + +private: + mutable PromiseRejector* inner; + + void disposeImpl(void* pointer) const override; +}; + template -class WeakFulfiller final: public PromiseFulfiller, private kj::Disposer { +class WeakFulfiller final: public PromiseFulfiller, public WeakFulfillerBase { // A wrapper around PromiseFulfiller which can be detached. // // There are a couple non-trivialities here: @@ -990,7 +1574,7 @@ class WeakFulfiller final: public PromiseFulfiller, private kj::Disposer { // fulfiller and detach() is called when the promise is destroyed. public: - KJ_DISALLOW_COPY(WeakFulfiller); + KJ_DISALLOW_COPY_AND_MOVE(WeakFulfiller); static kj::Own make() { WeakFulfiller* ptr = new WeakFulfiller; @@ -998,54 +1582,37 @@ class WeakFulfiller final: public PromiseFulfiller, private kj::Disposer { } void fulfill(FixVoid&& value) override { - if (inner != nullptr) { - inner->fulfill(kj::mv(value)); + if (getInner() != nullptr) { + getInner()->fulfill(kj::mv(value)); } } void reject(Exception&& exception) override { - if (inner != nullptr) { - inner->reject(kj::mv(exception)); + if (getInner() != nullptr) { + getInner()->reject(kj::mv(exception)); } } bool isWaiting() override { - return inner != nullptr && inner->isWaiting(); + return getInner() != nullptr && getInner()->isWaiting(); } void attach(PromiseFulfiller& newInner) { - inner = &newInner; + setInner(&newInner); } void detach(PromiseFulfiller& from) { - if (inner == nullptr) { + if (getInner() == nullptr) { // Already disposed. delete this; } else { - KJ_IREQUIRE(inner == &from); - inner = nullptr; + KJ_IREQUIRE(getInner() == &from); + setInner(nullptr); } } private: - mutable PromiseFulfiller* inner; - - WeakFulfiller(): inner(nullptr) {} - - void disposeImpl(void* pointer) const override { - // TODO(perf): Factor some of this out so it isn't regenerated for every fulfiller type? - - if (inner == nullptr) { - // Already detached. - delete this; - } else { - if (inner->isWaiting()) { - inner->reject(kj::Exception(kj::Exception::Type::FAILED, __FILE__, __LINE__, - kj::heapString("PromiseFulfiller was destroyed without fulfilling the promise."))); - } - inner = nullptr; - } - } + WeakFulfiller() {} }; template @@ -1090,23 +1657,672 @@ bool PromiseFulfiller::rejectIfThrows(Func&& func) { } template -Promise newAdaptedPromise(Params&&... adapterConstructorParams) { - return Promise(false, heap<_::AdapterPromiseNode<_::FixVoid, Adapter>>( - kj::fwd(adapterConstructorParams)...)); +_::ReducePromises newAdaptedPromise(Params&&... adapterConstructorParams) { + _::OwnPromiseNode intermediate( + _::allocPromise<_::AdapterPromiseNode<_::FixVoid, Adapter>>( + kj::fwd(adapterConstructorParams)...)); + // We can't capture SourceLocation in this function's arguments since it is a vararg template. :( + return _::PromiseNode::to<_::ReducePromises>( + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), SourceLocation())); } template -PromiseFulfillerPair newPromiseAndFulfiller() { +PromiseFulfillerPair newPromiseAndFulfiller(SourceLocation location) { auto wrapper = _::WeakFulfiller::make(); - Own<_::PromiseNode> intermediate( - heap<_::AdapterPromiseNode<_::FixVoid, _::PromiseAndFulfillerAdapter>>(*wrapper)); - Promise<_::JoinPromises> promise(false, - _::maybeChain(kj::mv(intermediate), implicitCast(nullptr))); + _::OwnPromiseNode intermediate( + _::allocPromise<_::AdapterPromiseNode< + _::FixVoid, _::PromiseAndFulfillerAdapter>>(*wrapper)); + auto promise = _::PromiseNode::to<_::ReducePromises>( + _::maybeChain(kj::mv(intermediate), implicitCast(nullptr), location)); return PromiseFulfillerPair { kj::mv(promise), kj::mv(wrapper) }; } +// ======================================================================================= +// cross-thread stuff + +namespace _ { // (private) + +class XThreadEvent: public PromiseNode, // it's a PromiseNode in the requesting thread + private Event { // it's an event in the target thread +public: + XThreadEvent(ExceptionOrValue& result, const Executor& targetExecutor, EventLoop& loop, + void* funcTracePtr, SourceLocation location); + + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; + +protected: + void ensureDoneOrCanceled(); + // MUST be called in destructor of subclasses to make sure the object is not destroyed while + // still being accessed by the other thread. (This can't be placed in ~XThreadEvent() because + // that destructor doesn't run until the subclass has already been destroyed.) + + virtual kj::Maybe execute() = 0; + // Run the function. If the function returns a promise, returns the inner PromiseNode, otherwise + // returns null. + + // implements PromiseNode ---------------------------------------------------- + void onReady(Event* event) noexcept override; + +private: + ExceptionOrValue& result; + void* funcTracePtr; + + kj::Own targetExecutor; + Maybe replyExecutor; // If executeAsync() was used. + + kj::Maybe promiseNode; + // Accessed only in target thread. + + ListLink targetLink; + // Membership in one of the linked lists in the target Executor's work list or cancel list. These + // fields are protected by the target Executor's mutex. + + enum { + UNUSED, + // Object was never queued on another thread. + + QUEUED, + // Target thread has not yet dequeued the event from the state.start list. The requesting + // thread can cancel execution by removing the event from the list. + + EXECUTING, + // Target thread has dequeued the event from state.start and moved it to state.executing. To + // cancel, the requesting thread must add the event to the state.cancel list and change the + // state to CANCELING. + + CANCELING, + // Requesting thread is trying to cancel this event. The target thread will change the state to + // `DONE` once canceled. + + DONE + // Target thread has completed handling this event and will not touch it again. The requesting + // thread can safely delete the object. The `state` is updated to `DONE` using an atomic + // release operation after ensuring that the event will not be touched again, so that the + // requesting can safely skip locking if it observes the state is already DONE. + } state = UNUSED; + // State, which is also protected by `targetExecutor`'s mutex. + + ListLink replyLink; + // Membership in `replyExecutor`'s reply list. Protected by `replyExecutor`'s mutex. The + // executing thread places the event in the reply list near the end of the `EXECUTING` state. + // Because the thread cannot lock two mutexes at once, it's possible that the reply executor + // will receive the reply while the event is still listed in the EXECUTING state, but it can + // ignore the state and proceed with the result. + + OnReadyEvent onReadyEvent; + // Accessed only in requesting thread. + + friend class kj::Executor; + + void done(); + // Sets the state to `DONE` and notifies the originating thread that this event is done. Do NOT + // call under lock. + + void sendReply(); + // Notifies the originating thread that this event is done, but doesn't set the state to DONE + // yet. Do NOT call under lock. + + void setDoneState(); + // Assigns `state` to `DONE`, being careful to use an atomic-release-store if needed. This must + // only be called in the destination thread, and must either be called under lock, or the thread + // must take the lock and release it again shortly after setting the state (because some threads + // may be waiting on the DONE state using a conditional wait on the mutex). After calling + // setDoneState(), the destination thread MUST NOT touch this object ever again; it now belongs + // solely to the requesting thread. + + void setDisconnected(); + // Sets the result to a DISCONNECTED exception indicating that the target event loop exited. + + class DelayedDoneHack; + + // implements Event ---------------------------------------------------------- + Maybe> fire() override; + // If called with promiseNode == nullptr, it's time to call execute(). If promiseNode != nullptr, + // then it just indicated readiness and we need to get its result. + + void traceEvent(TraceBuilder& builder) override; +}; + +template >> +class XThreadEventImpl final: public XThreadEvent { + // Implementation for a function that does not return a Promise. +public: + XThreadEventImpl(Func&& func, const Executor& target, EventLoop& loop, SourceLocation location) + : XThreadEvent(result, target, loop, GetFunctorStartAddress<>::apply(func), location), + func(kj::fwd(func)) {} + ~XThreadEventImpl() noexcept(false) { ensureDoneOrCanceled(); } + void destroy() override { freePromise(this); } + + typedef _::FixVoid<_::ReturnType> ResultT; + + kj::Maybe<_::OwnPromiseNode> execute() override { + result.value = MaybeVoidCaller>::apply(func, Void()); + return nullptr; + } + + // implements PromiseNode ---------------------------------------------------- + void get(ExceptionOrValue& output) noexcept override { + output.as() = kj::mv(result); + } + +private: + Func func; + ExceptionOr result; + friend Executor; +}; + +template +class XThreadEventImpl> final: public XThreadEvent { + // Implementation for a function that DOES return a Promise. +public: + XThreadEventImpl(Func&& func, const Executor& target, EventLoop& loop, SourceLocation location) + : XThreadEvent(result, target, loop, GetFunctorStartAddress<>::apply(func), location), + func(kj::fwd(func)) {} + ~XThreadEventImpl() noexcept(false) { ensureDoneOrCanceled(); } + void destroy() override { freePromise(this); } + + typedef _::FixVoid<_::UnwrapPromise>> ResultT; + + kj::Maybe<_::OwnPromiseNode> execute() override { + auto result = _::PromiseNode::from(func()); + KJ_IREQUIRE(result.get() != nullptr); + return kj::mv(result); + } + + // implements PromiseNode ---------------------------------------------------- + void get(ExceptionOrValue& output) noexcept override { + output.as() = kj::mv(result); + } + +private: + Func func; + ExceptionOr result; + friend Executor; +}; + +} // namespace _ (private) + +template +_::UnwrapPromise> Executor::executeSync( + Func&& func, SourceLocation location) const { + _::XThreadEventImpl event(kj::fwd(func), *this, getLoop(), location); + send(event, true); + return convertToReturn(kj::mv(event.result)); +} + +template +PromiseForResult Executor::executeAsync(Func&& func, SourceLocation location) const { + // HACK: We call getLoop() here, rather than have XThreadEvent's constructor do it, so that if it + // throws we don't crash due to `allocPromise()` being `noexcept`. + auto event = _::allocPromise<_::XThreadEventImpl>( + kj::fwd(func), *this, getLoop(), location); + send(*event, false); + return _::PromiseNode::to>(kj::mv(event)); +} + +// ----------------------------------------------------------------------------- + +namespace _ { // (private) + +template +class XThreadFulfiller; + +class XThreadPaf: public PromiseNode { +public: + XThreadPaf(); + virtual ~XThreadPaf() noexcept(false); + void destroy() override; + + // implements PromiseNode ---------------------------------------------------- + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; + +private: + enum { + WAITING, + // Not yet fulfilled, and the waiter is still waiting. + // + // Starting from this state, the state may transition to either FULFILLING or CANCELED + // using an atomic compare-and-swap. + + FULFILLING, + // The fulfiller thread atomically transitions the state from WAITING to FULFILLING when it + // wishes to fulfill the promise. By doing so, it guarantees that the `executor` will not + // disappear out from under it. It then fills in the result value, locks the executor mutex, + // adds the object to the executor's list of fulfilled XThreadPafs, changes the state to + // FULFILLED, and finally unlocks the mutex. + // + // If the waiting thread tries to cancel but discovers the object in this state, then it + // must perform a conditional wait on the executor mutex to await the state becoming FULFILLED. + // It can then delete the object. + + FULFILLED, + // The fulfilling thread has completed filling in the result value and inserting the object + // into the waiting thread's executor event queue. Moreover, the fulfilling thread no longer + // holds any pointers to this object. The waiting thread is responsible for deleting it. + + DISPATCHED, + // The object reached FULFILLED state, and then was dispatched from the waiting thread's + // executor's event queue. Therefore, the object is completely owned by the waiting thread with + // no need to lock anything. + + CANCELED + // The waiting thread atomically transitions the state from WAITING to CANCELED if it is no + // longer listening. In this state, it is the fulfiller thread's responsibility to destroy the + // object. + } state; + + const Executor& executor; + // Executor of the waiting thread. Only guaranteed to be valid when state is `WAITING` or + // `FULFILLING`. After any other state has been reached, this reference may be invalidated. + + ListLink link; + // In the FULFILLING/FULFILLED states, the object is placed in a linked list within the waiting + // thread's executor. In those states, these pointers are guarded by said executor's mutex. + + OnReadyEvent onReadyEvent; + + class FulfillScope; + + static kj::Exception unfulfilledException(); + // Construct appropriate exception to use to reject an unfulfilled XThreadPaf. + + template + friend class XThreadFulfiller; + friend Executor; +}; + +template +class XThreadPafImpl final: public XThreadPaf { +public: + // implements PromiseNode ---------------------------------------------------- + void get(ExceptionOrValue& output) noexcept override { + output.as>() = kj::mv(result); + } + +private: + ExceptionOr> result; + + friend class XThreadFulfiller; +}; + +class XThreadPaf::FulfillScope { + // Create on stack while setting `XThreadPafImpl::result`. + // + // This ensures that: + // - Only one call is carried out, even if multiple threads try to fulfill concurrently. + // - The waiting thread is correctly signaled. +public: + FulfillScope(XThreadPaf** pointer); + // Atomically nulls out *pointer and takes ownership of the pointer. + + ~FulfillScope() noexcept(false); + + KJ_DISALLOW_COPY_AND_MOVE(FulfillScope); + + bool shouldFulfill() { return obj != nullptr; } + + template + XThreadPafImpl* getTarget() { return static_cast*>(obj); } + +private: + XThreadPaf* obj; +}; + +template +class XThreadFulfiller final: public CrossThreadPromiseFulfiller { +public: + XThreadFulfiller(XThreadPafImpl* target): target(target) {} + + ~XThreadFulfiller() noexcept(false) { + if (target != nullptr) { + reject(XThreadPaf::unfulfilledException()); + } + } + void fulfill(FixVoid&& value) const override { + XThreadPaf::FulfillScope scope(&target); + if (scope.shouldFulfill()) { + scope.getTarget()->result = kj::mv(value); + } + } + void reject(Exception&& exception) const override { + XThreadPaf::FulfillScope scope(&target); + if (scope.shouldFulfill()) { + scope.getTarget()->result.addException(kj::mv(exception)); + } + } + bool isWaiting() const override { + KJ_IF_MAYBE(t, target) { +#if _MSC_VER && !__clang__ + // Just assume 1-byte loads are atomic... on what kind of absurd platform would they not be? + return t->state == XThreadPaf::WAITING; +#else + return __atomic_load_n(&t->state, __ATOMIC_RELAXED) == XThreadPaf::WAITING; +#endif + } else { + return false; + } + } + +private: + mutable XThreadPaf* target; // accessed using atomic ops +}; + +template +class XThreadFulfiller> { +public: + static_assert(sizeof(T) < 0, + "newCrosssThreadPromiseAndFulfiller>() is not currently supported"); + // TODO(someday): Is this worth supporting? Presumably, when someone calls `fulfill(somePromise)`, + // then `somePromise` should be assumed to be a promise owned by the fulfilling thread, not + // the waiting thread. +}; + +} // namespace _ (private) + +template +PromiseCrossThreadFulfillerPair newPromiseAndCrossThreadFulfiller() { + kj::Own<_::XThreadPafImpl, _::PromiseDisposer> node(new _::XThreadPafImpl); + auto fulfiller = kj::heap<_::XThreadFulfiller>(node); + return { _::PromiseNode::to<_::ReducePromises>(kj::mv(node)), kj::mv(fulfiller) }; +} + } // namespace kj -#endif // KJ_ASYNC_INL_H_ +#if KJ_HAS_COROUTINE + +// ======================================================================================= +// Coroutines TS integration with kj::Promise. +// +// Here's a simple coroutine: +// +// Promise> connectToService(Network& n) { +// auto a = co_await n.parseAddress(IP, PORT); +// auto c = co_await a->connect(); +// co_return kj::mv(c); +// } +// +// The presence of the co_await and co_return keywords tell the compiler it is a coroutine. +// Although it looks similar to a function, it has a couple large differences. First, everything +// that would normally live in the stack frame lives instead in a heap-based coroutine frame. +// Second, the coroutine has the ability to return from its scope without deallocating this frame +// (to suspend, in other words), and the ability to resume from its last suspension point. +// +// In order to know how to suspend, resume, and return from a coroutine, the compiler looks up a +// coroutine implementation type via a traits class parameterized by the coroutine return and +// parameter types. We'll name our coroutine implementation `kj::_::Coroutine`, + +namespace kj::_ { template class Coroutine; } + +// Specializing the appropriate traits class tells the compiler about `kj::_::Coroutine`. + +namespace KJ_COROUTINE_STD_NAMESPACE { + +template +struct coroutine_traits, Args...> { + // `Args...` are the coroutine's parameter types. + + using promise_type = kj::_::Coroutine; + // The Coroutines TS calls this the "promise type". This makes sense when thinking of coroutines + // returning `std::future`, since the coroutine implementation would be a wrapper around + // a `std::promise`. It's extremely confusing from a KJ perspective, however, so I call it + // the "coroutine implementation type" instead. +}; + +} // namespace KJ_COROUTINE_STD_NAMESPACE + +// Now when the compiler sees our `connectToService()` coroutine above, it default-constructs a +// `coroutine_traits>, Network&>::promise_type`, or +// `kj::_::Coroutine>`. +// +// The implementation object lives in the heap-allocated coroutine frame. It gets destroyed and +// deallocated when the frame does. + +namespace kj::_ { + +namespace stdcoro = KJ_COROUTINE_STD_NAMESPACE; + +class CoroutineBase: public PromiseNode, + public Event { +public: + CoroutineBase(stdcoro::coroutine_handle<> coroutine, ExceptionOrValue& resultRef, + SourceLocation location); + ~CoroutineBase() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(CoroutineBase); + void destroy() override; + + auto initial_suspend() { return stdcoro::suspend_never(); } + auto final_suspend() noexcept { return stdcoro::suspend_always(); } + // These adjust the suspension behavior of coroutines immediately upon initiation, and immediately + // after completion. + // + // The initial suspension point could allow us to defer the initial synchronous execution of a + // coroutine -- everything before its first co_await, that is. + // + // The final suspension point is useful to delay deallocation of the coroutine frame to match the + // lifetime of the enclosing promise. + + void unhandled_exception(); + +protected: + class AwaiterBase; + + bool isWaiting() { return waiting; } + void scheduleResumption() { + onReadyEvent.arm(); + waiting = false; + } + +private: + // ------------------------------------------------------- + // PromiseNode implementation + + void onReady(Event* event) noexcept override; + void tracePromise(TraceBuilder& builder, bool stopAtNextEvent) override; + + // ------------------------------------------------------- + // Event implementation + + Maybe> fire() override; + void traceEvent(TraceBuilder& builder) override; + + stdcoro::coroutine_handle<> coroutine; + ExceptionOrValue& resultRef; + + OnReadyEvent onReadyEvent; + bool waiting = true; + + bool hasSuspendedAtLeastOnce = false; + + Maybe promiseNodeForTrace; + // Whenever this coroutine is suspended waiting on another promise, we keep a reference to that + // promise so tracePromise()/traceEvent() can trace into it. + + UnwindDetector unwindDetector; + + struct DisposalResults { + bool destructorRan = false; + Maybe exception; + }; + Maybe maybeDisposalResults; + // Only non-null during destruction. Before calling coroutine.destroy(), our disposer sets this + // to point to a DisposalResults on the stack so unhandled_exception() will have some place to + // store unwind exceptions. We can't store them in this Coroutine, because we'll be destroyed once + // coroutine.destroy() has returned. Our disposer then rethrows as needed. +}; + +template +class CoroutineMixin; +// CRTP mixin, covered later. + +template +class Coroutine final: public CoroutineBase, + public CoroutineMixin, T> { + // The standard calls this the `promise_type` object. We can call this the "coroutine + // implementation object" since the word promise means different things in KJ and std styles. This + // is where we implement how a `kj::Promise` is returned from a coroutine, and how that promise + // is later fulfilled. We also fill in a few lifetime-related details. + // + // The implementation object is also where we can customize memory allocation of coroutine frames, + // by implementing a member `operator new(size_t, Args...)` (same `Args...` as in + // coroutine_traits). + // + // We can also customize how await-expressions are transformed within `kj::Promise`-based + // coroutines by implementing an `await_transform(P)` member function, where `P` is some type for + // which we want to implement co_await support, e.g. `kj::Promise`. This feature allows us to + // provide an optimized `kj::EventLoop` integration when the coroutine's return type and the + // await-expression's type are both `kj::Promise` instantiations -- see further comments under + // `await_transform()`. + +public: + using Handle = stdcoro::coroutine_handle>; + + Coroutine(SourceLocation location = {}) + : CoroutineBase(Handle::from_promise(*this), result, location) {} + + Promise get_return_object() { + // Called after coroutine frame construction and before initial_suspend() to create the + // coroutine's return object. `this` itself lives inside the coroutine frame, and we arrange for + // the returned Promise to own `this` via a custom Disposer and by always leaving the + // coroutine in a suspended state. + return PromiseNode::to>(OwnPromiseNode(this)); + } + +public: + template + class Awaiter; + + template + Awaiter await_transform(kj::Promise& promise) { return Awaiter(kj::mv(promise)); } + template + Awaiter await_transform(kj::Promise&& promise) { return Awaiter(kj::mv(promise)); } + // Called when someone writes `co_await promise`, where `promise` is a kj::Promise. We return + // an Awaiter, which implements coroutine suspension and resumption in terms of the KJ async + // event system. + // + // There is another hook we could implement: an `operator co_await()` free function. However, a + // free function would be unaware of the type of the enclosing coroutine. Since Awaiter is a + // member class template of Coroutine, it is able to implement an + // `await_suspend(Coroutine::Handle)` override, providing it type-safe access to our enclosing + // coroutine's PromiseNode. An `operator co_await()` free function would have to implement + // a type-erased `await_suspend(stdcoro::coroutine_handle)` override, and implement + // suspension and resumption in terms of .then(). Yuck! + +private: + // ------------------------------------------------------- + // PromiseNode implementation + + void get(ExceptionOrValue& output) noexcept override { + output.as>() = kj::mv(result); + } + + void fulfill(FixVoid&& value) { + // Called by the return_value()/return_void() functions in our mixin class. + + if (isWaiting()) { + result = kj::mv(value); + scheduleResumption(); + } + } + + ExceptionOr> result; + + friend class CoroutineMixin, T>; +}; + +template +class CoroutineMixin { +public: + void return_value(T value) { + static_cast(this)->fulfill(kj::mv(value)); + } +}; +template +class CoroutineMixin { +public: + void return_void() { + static_cast(this)->fulfill(_::Void()); + } +}; +// The Coroutines spec has no `_::FixVoid` equivalent to unify valueful and valueless co_return +// statements, and programs are ill-formed if the coroutine implementation object (Coroutine) has +// both a `return_value()` and `return_void()`. No amount of EnableIffery can get around it, so +// these return_* functions live in a CRTP mixin. + +class CoroutineBase::AwaiterBase { +public: + explicit AwaiterBase(OwnPromiseNode node); + AwaiterBase(AwaiterBase&&); + ~AwaiterBase() noexcept(false); + KJ_DISALLOW_COPY(AwaiterBase); + + bool await_ready() const { return false; } + // This could return "`node->get()` is safe to call" instead, which would make suspension-less + // co_awaits possible for immediately-fulfilled promises. However, we need an Event to figure that + // out, and we won't have access to the Coroutine Event until await_suspend() is called. So, we + // must return false here. Fortunately, await_suspend() has a trick up its sleeve to enable + // suspension-less co_awaits. + +protected: + void getImpl(ExceptionOrValue& result, void* awaitedAt); + bool awaitSuspendImpl(CoroutineBase& coroutineEvent); + +private: + UnwindDetector unwindDetector; + OwnPromiseNode node; + + Maybe maybeCoroutineEvent; + // If we do suspend waiting for our wrapped promise, we store a reference to `node` in our + // enclosing Coroutine for tracing purposes. To guard against any edge cases where an async stack + // trace is generated when an Awaiter was destroyed without Coroutine::fire() having been called, + // we need our own reference to the enclosing Coroutine. (I struggle to think up any such + // scenarios, but perhaps they could occur when destroying a suspended coroutine.) +}; + +template +template +class Coroutine::Awaiter: public AwaiterBase { + // Wrapper around a co_await'ed promise and some storage space for the result of that promise. + // The compiler arranges to call our await_suspend() to suspend, which arranges to be woken up + // when the awaited promise is settled. Once that happens, the enclosing coroutine's Event + // implementation resumes the coroutine, which transitively calls await_resume() to unwrap the + // awaited promise result. + +public: + explicit Awaiter(Promise promise): AwaiterBase(PromiseNode::from(kj::mv(promise))) {} + + U await_resume() KJ_NOINLINE { + // This is marked noinline in order to ensure __builtin_return_address() is accurate for stack + // trace purposes. In my experimentation, this method was not inlined anyway even in opt + // builds, but I want to make sure it doesn't suddenly start being inlined later causing stack + // traces to break. (I also tried always-inline, but this did not appear to cause the compiler + // to inline the method -- perhaps a limitation of coroutines?) +#if __GNUC__ + getImpl(result, __builtin_return_address(0)); +#elif _MSC_VER + getImpl(result, _ReturnAddress()); +#else + #error "please implement for your compiler" +#endif + auto value = kj::_::readMaybe(result.value); + KJ_IASSERT(value != nullptr, "Neither exception nor value present."); + return U(kj::mv(*value)); + } + + bool await_suspend(Coroutine::Handle coroutine) { + return awaitSuspendImpl(coroutine.promise()); + } + +private: + ExceptionOr> result; +}; + +#undef KJ_COROUTINE_STD_NAMESPACE + +} // namespace kj::_ (private) + +#endif // KJ_HAS_COROUTINE + +KJ_END_HEADER diff --git a/c++/src/kj/async-io-internal.h b/c++/src/kj/async-io-internal.h new file mode 100644 index 0000000000..d030ad9577 --- /dev/null +++ b/c++/src/kj/async-io-internal.h @@ -0,0 +1,70 @@ +// Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "string.h" +#include "vector.h" +#include "async-io.h" +#include +#include "one-of.h" +#include "cidr.h" + +KJ_BEGIN_HEADER + +struct sockaddr; +struct sockaddr_un; + +namespace kj { +namespace _ { // private + +// ======================================================================================= + +#if !_WIN32 +kj::ArrayPtr safeUnixPath(const struct sockaddr_un* addr, uint addrlen); +// sockaddr_un::sun_path is not required to have a NUL terminator! Thus to be safe unix address +// paths MUST be read using this function. +#endif + +class NetworkFilter: public LowLevelAsyncIoProvider::NetworkFilter { +public: + NetworkFilter(); + NetworkFilter(ArrayPtr allow, ArrayPtr deny, + NetworkFilter& next); + + bool shouldAllow(const struct sockaddr* addr, uint addrlen) override; + bool shouldAllowParse(const struct sockaddr* addr, uint addrlen); + +private: + Vector allowCidrs; + Vector denyCidrs; + bool allowUnix; + bool allowAbstractUnix; + bool allowPublic = false; + bool allowNetwork = false; + + kj::Maybe next; +}; + +} // namespace _ (private) +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/async-io-test.c++ b/c++/src/kj/async-io-test.c++ index a7c8631aa8..e8892b79e6 100644 --- a/c++/src/kj/async-io-test.c++ +++ b/c++/src/kj/async-io-test.c++ @@ -19,15 +19,35 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if _WIN32 +// Request Vista-level APIs. +#include "win32-api-version.h" +#elif !defined(_GNU_SOURCE) +#define _GNU_SOURCE +#endif + #include "async-io.h" +#include "async-io-internal.h" #include "debug.h" +#include "io.h" +#include "cidr.h" +#include "miniposix.h" #include +#include #include +#include #if _WIN32 #include #include "windows-sanity.h" +#define inet_pton InetPtonA +#define inet_ntop InetNtopA #else #include +#include +#include +#include +#include +#include #endif namespace kj { @@ -71,15 +91,301 @@ TEST(AsyncIo, SimpleNetwork) { EXPECT_EQ("foo", result); } +#if !_WIN32 // TODO(someday): Implement NetworkPeerIdentity for Win32. +TEST(AsyncIo, SimpleNetworkAuthentication) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + + Own listener; + Own server; + Own client; + + char receiveBuffer[4]; + + auto port = newPromiseAndFulfiller(); + + port.promise.then([&](uint portnum) { + return network.parseAddress("localhost", portnum); + }).then([&](Own&& addr) { + auto promise = addr->connectAuthenticated(); + return promise.then([&,addr=kj::mv(addr)](AuthenticatedStream result) mutable { + auto id = result.peerIdentity.downcast(); + + // `addr` was resolved from `localhost` and may contain multiple addresses, but + // result.peerIdentity tells us the specific address that was used. So it should be one + // of the ones on the list, but only one. + KJ_EXPECT(strstr(addr->toString().cStr(), id->getAddress().toString().cStr()) != nullptr); + KJ_EXPECT(id->getAddress().toString().findFirst(',') == nullptr); + + client = kj::mv(result.stream); + + // `id` should match client->getpeername(). + union { + struct sockaddr generic; + struct sockaddr_in ip4; + struct sockaddr_in6 ip6; + } rawAddr; + uint len = sizeof(rawAddr); + client->getpeername(&rawAddr.generic, &len); + auto peername = network.getSockaddr(&rawAddr.generic, len); + KJ_EXPECT(id->toString() == peername->toString()); + + return client->write("foo", 3); + }); + }).detach([](kj::Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + kj::String result = network.parseAddress("*").then([&](Own&& result) { + listener = result->listen(); + port.fulfiller->fulfill(listener->getPort()); + return listener->acceptAuthenticated(); + }).then([&](AuthenticatedStream result) { + auto id = result.peerIdentity.downcast(); + server = kj::mv(result.stream); + + // `id` should match server->getpeername(). + union { + struct sockaddr generic; + struct sockaddr_in ip4; + struct sockaddr_in6 ip6; + } addr; + uint len = sizeof(addr); + server->getpeername(&addr.generic, &len); + auto peername = network.getSockaddr(&addr.generic, len); + KJ_EXPECT(id->toString() == peername->toString()); + + return server->tryRead(receiveBuffer, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer, n); + }).wait(ioContext.waitScope); + + EXPECT_EQ("foo", result); +} +#endif + +#if !_WIN32 && !__CYGWIN__ // TODO(someday): Debug why this deadlocks on Cygwin. + +#if __ANDROID__ +#define TMPDIR "/data/local/tmp" +#else +#define TMPDIR "/tmp" +#endif + +TEST(AsyncIo, UnixSocket) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + + auto path = kj::str(TMPDIR "/kj-async-io-test.", getpid()); + KJ_DEFER(unlink(path.cStr())); + + Own listener; + Own server; + Own client; + + char receiveBuffer[4]; + + auto ready = newPromiseAndFulfiller(); + + ready.promise.then([&]() { + return network.parseAddress(kj::str("unix:", path)); + }).then([&](Own&& addr) { + auto promise = addr->connectAuthenticated(); + return promise.then([&,addr=kj::mv(addr)](AuthenticatedStream result) mutable { + auto id = result.peerIdentity.downcast(); + auto creds = id->getCredentials(); + KJ_IF_MAYBE(p, creds.pid) { + KJ_EXPECT(*p == getpid()); +#if __linux__ || __APPLE__ + } else { + KJ_FAIL_EXPECT("LocalPeerIdentity for unix socket had null PID"); +#endif + } + KJ_IF_MAYBE(u, creds.uid) { + KJ_EXPECT(*u == getuid()); + } else { + KJ_FAIL_EXPECT("LocalPeerIdentity for unix socket had null UID"); + } + + client = kj::mv(result.stream); + return client->write("foo", 3); + }); + }).detach([](kj::Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + kj::String result = network.parseAddress(kj::str("unix:", path)) + .then([&](Own&& result) { + listener = result->listen(); + ready.fulfiller->fulfill(); + return listener->acceptAuthenticated(); + }).then([&](AuthenticatedStream result) { + auto id = result.peerIdentity.downcast(); + auto creds = id->getCredentials(); + KJ_IF_MAYBE(p, creds.pid) { + KJ_EXPECT(*p == getpid()); +#if __linux__ || __APPLE__ + } else { + KJ_FAIL_EXPECT("LocalPeerIdentity for unix socket had null PID"); +#endif + } + KJ_IF_MAYBE(u, creds.uid) { + KJ_EXPECT(*u == getuid()); + } else { + KJ_FAIL_EXPECT("LocalPeerIdentity for unix socket had null UID"); + } + + server = kj::mv(result.stream); + return server->tryRead(receiveBuffer, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer, n); + }).wait(ioContext.waitScope); + + EXPECT_EQ("foo", result); +} + +TEST(AsyncIo, AncillaryMessageHandlerNoMsg) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + + Own listener; + Own server; + Own client; + + char receiveBuffer[4]; + + bool clientHandlerCalled = false; + kj::Function)> clientHandler = + [&](kj::ArrayPtr) { + clientHandlerCalled = true; + }; + bool serverHandlerCalled = false; + kj::Function)> serverHandler = + [&](kj::ArrayPtr) { + serverHandlerCalled = true; + }; + + auto port = newPromiseAndFulfiller(); + + port.promise.then([&](uint portnum) { + return network.parseAddress("localhost", portnum); + }).then([&](Own&& addr) { + auto promise = addr->connectAuthenticated(); + return promise.then([&,addr=kj::mv(addr)](AuthenticatedStream result) mutable { + client = kj::mv(result.stream); + client->registerAncillaryMessageHandler(kj::mv(clientHandler)); + return client->write("foo", 3); + }); + }).detach([](kj::Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + kj::String result = network.parseAddress("*").then([&](Own&& result) { + listener = result->listen(); + port.fulfiller->fulfill(listener->getPort()); + return listener->acceptAuthenticated(); + }).then([&](AuthenticatedStream result) { + server = kj::mv(result.stream); + server->registerAncillaryMessageHandler(kj::mv(serverHandler)); + return server->tryRead(receiveBuffer, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer, n); + }).wait(ioContext.waitScope); + + EXPECT_EQ("foo", result); + EXPECT_FALSE(clientHandlerCalled); + EXPECT_FALSE(serverHandlerCalled); +} +#endif + +// This test uses SO_TIMESTAMP on a SOCK_STREAM, which is only supported by Linux. Ideally we'd +// rewrite the test to use some other message type that is widely supported on streams. But for +// now we just limit the test to Linux. Also, it doesn't work on Android for some reason, and it +// isn't worth investigating, so we skip it there. +#if __linux__ && !__ANDROID__ +TEST(AsyncIo, AncillaryMessageHandler) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + + Own listener; + Own server; + Own client; + + char receiveBuffer[4]; + + bool clientHandlerCalled = false; + kj::Function)> clientHandler = + [&](kj::ArrayPtr) { + clientHandlerCalled = true; + }; + bool serverHandlerCalled = false; + kj::Function)> serverHandler = + [&](kj::ArrayPtr msgs) { + serverHandlerCalled = true; + EXPECT_EQ(1, msgs.size()); + EXPECT_EQ(SOL_SOCKET, msgs[0].getLevel()); + EXPECT_EQ(SO_TIMESTAMP, msgs[0].getType()); + }; + + auto port = newPromiseAndFulfiller(); + + port.promise.then([&](uint portnum) { + return network.parseAddress("localhost", portnum); + }).then([&](Own&& addr) { + auto promise = addr->connectAuthenticated(); + return promise.then([&,addr=kj::mv(addr)](AuthenticatedStream result) mutable { + client = kj::mv(result.stream); + client->registerAncillaryMessageHandler(kj::mv(clientHandler)); + return client->write("foo", 3); + }); + }).detach([](kj::Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + kj::String result = network.parseAddress("*").then([&](Own&& result) { + listener = result->listen(); + // Register interest in having the timestamp delivered via cmsg on each recvmsg. + int yes = 1; + listener->setsockopt(SOL_SOCKET, SO_TIMESTAMP, &yes, sizeof(yes)); + port.fulfiller->fulfill(listener->getPort()); + return listener->acceptAuthenticated(); + }).then([&](AuthenticatedStream result) { + server = kj::mv(result.stream); + server->registerAncillaryMessageHandler(kj::mv(serverHandler)); + return server->tryRead(receiveBuffer, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer, n); + }).wait(ioContext.waitScope); + + EXPECT_EQ("foo", result); + EXPECT_FALSE(clientHandlerCalled); + EXPECT_TRUE(serverHandlerCalled); +} +#endif + String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint portHint = 0) { return network.parseAddress(text, portHint).wait(waitScope)->toString(); } -bool hasIpv6() { - // Can getaddrinfo() parse ipv6 addresses? This is only true if ipv6 is configured on at least - // one interface. (The loopback interface usually has it even if others don't... but not always.) +bool systemSupportsAddress(StringPtr addr, StringPtr service = nullptr) { + // Can getaddrinfo() parse this addresses? This is only true if the address family (e.g., ipv6) + // is configured on at least one interface. (The loopback interface usually has both ipv4 and + // ipv6 configured, but not always.) + struct addrinfo hints; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = 0; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; + hints.ai_protocol = 0; + hints.ai_canonname = nullptr; + hints.ai_addr = nullptr; + hints.ai_next = nullptr; struct addrinfo* list; - int status = getaddrinfo("::", nullptr, nullptr, &list); + int status = getaddrinfo( + addr.cStr(), service == nullptr ? nullptr : service.cStr(), &hints, &list); if (status == 0) { freeaddrinfo(list); return true; @@ -100,23 +406,32 @@ TEST(AsyncIo, AddressParsing) { #if !_WIN32 EXPECT_EQ("unix:foo/bar/baz", tryParse(w, network, "unix:foo/bar/baz")); + EXPECT_EQ("unix-abstract:foo/bar/baz", tryParse(w, network, "unix-abstract:foo/bar/baz")); #endif // We can parse services by name... -#if !__ANDROID__ // Service names not supported on Android for some reason? - EXPECT_EQ("1.2.3.4:80", tryParse(w, network, "1.2.3.4:http", 5678)); - EXPECT_EQ("*:80", tryParse(w, network, "*:http", 5678)); -#endif + // + // For some reason, Android and some various Linux distros do not support service names. + if (systemSupportsAddress("1.2.3.4", "http")) { + EXPECT_EQ("1.2.3.4:80", tryParse(w, network, "1.2.3.4:http", 5678)); + EXPECT_EQ("*:80", tryParse(w, network, "*:http", 5678)); + } else { + KJ_LOG(WARNING, "system does not support resolving service names on ipv4; skipping tests"); + } // IPv6 tests. Annoyingly, these don't work on machines that don't have IPv6 configured on any // interfaces. - if (hasIpv6()) { + if (systemSupportsAddress("::")) { EXPECT_EQ("[::]:123", tryParse(w, network, "0::0", 123)); EXPECT_EQ("[12ab:cd::34]:321", tryParse(w, network, "[12ab:cd:0::0:34]:321", 432)); -#if !__ANDROID__ // Service names not supported on Android for some reason? - EXPECT_EQ("[::]:80", tryParse(w, network, "[::]:http", 5678)); - EXPECT_EQ("[12ab:cd::34]:80", tryParse(w, network, "[12ab:cd::34]:http", 5678)); -#endif + if (systemSupportsAddress("12ab:cd::34", "http")) { + EXPECT_EQ("[::]:80", tryParse(w, network, "[::]:http", 5678)); + EXPECT_EQ("[12ab:cd::34]:80", tryParse(w, network, "[12ab:cd::34]:http", 5678)); + } else { + KJ_LOG(WARNING, "system does not support resolving service names on ipv6; skipping tests"); + } + } else { + KJ_LOG(WARNING, "system does not support ipv6; skipping tests"); } // It would be nice to test DNS lookup here but the test would not be very hermetic. Even @@ -169,6 +484,320 @@ TEST(AsyncIo, TwoWayPipe) { EXPECT_EQ("bar", result2); } +TEST(AsyncIo, InMemoryCapabilityPipe) { + EventLoop loop; + WaitScope waitScope(loop); + + auto pipe = newCapabilityPipe(); + auto pipe2 = newCapabilityPipe(); + char receiveBuffer1[4]; + char receiveBuffer2[4]; + + // Expect to receive a stream, then read "foo" from it, then write "bar" to it. + Own receivedStream; + auto promise = pipe2.ends[1]->receiveStream() + .then([&](Own stream) { + receivedStream = kj::mv(stream); + return receivedStream->tryRead(receiveBuffer2, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return receivedStream->write("bar", 3).then([&receiveBuffer2,n]() { + return heapString(receiveBuffer2, n); + }); + }); + + // Send a stream, then write "foo" to the other end of the sent stream, then receive "bar" + // from it. + kj::String result = pipe2.ends[0]->sendStream(kj::mv(pipe.ends[1])) + .then([&]() { + return pipe.ends[0]->write("foo", 3); + }).then([&]() { + return pipe.ends[0]->tryRead(receiveBuffer1, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer1, n); + }).wait(waitScope); + + kj::String result2 = promise.wait(waitScope); + + EXPECT_EQ("bar", result); + EXPECT_EQ("foo", result2); +} + +#if !_WIN32 && !__CYGWIN__ +TEST(AsyncIo, CapabilityPipe) { + auto ioContext = setupAsyncIo(); + + auto pipe = ioContext.provider->newCapabilityPipe(); + auto pipe2 = ioContext.provider->newCapabilityPipe(); + char receiveBuffer1[4]; + char receiveBuffer2[4]; + + // Expect to receive a stream, then write "bar" to it, then receive "foo" from it. + Own receivedStream; + auto promise = pipe2.ends[1]->receiveStream() + .then([&](Own stream) { + receivedStream = kj::mv(stream); + return receivedStream->write("bar", 3); + }).then([&]() { + return receivedStream->tryRead(receiveBuffer2, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer2, n); + }); + + // Send a stream, then write "foo" to the other end of the sent stream, then receive "bar" + // from it. + kj::String result = pipe2.ends[0]->sendStream(kj::mv(pipe.ends[1])) + .then([&]() { + return pipe.ends[0]->write("foo", 3); + }).then([&]() { + return pipe.ends[0]->tryRead(receiveBuffer1, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer1, n); + }).wait(ioContext.waitScope); + + kj::String result2 = promise.wait(ioContext.waitScope); + + EXPECT_EQ("bar", result); + EXPECT_EQ("foo", result2); +} + +TEST(AsyncIo, CapabilityPipeBlockedSendStream) { + // Check for a bug that existed at one point where if a sendStream() call couldn't complete + // immediately, it would fail. + + auto io = setupAsyncIo(); + + auto pipe = io.provider->newCapabilityPipe(); + + Promise promise = nullptr; + Own endpoint1; + uint nonBlockedCount = 0; + for (;;) { + auto pipe2 = io.provider->newCapabilityPipe(); + promise = pipe.ends[0]->sendStream(kj::mv(pipe2.ends[0])); + if (promise.poll(io.waitScope)) { + // Send completed immediately, because there was enough space in the stream. + ++nonBlockedCount; + promise.wait(io.waitScope); + } else { + // Send blocked! Let's continue with this promise then! + endpoint1 = kj::mv(pipe2.ends[1]); + break; + } + } + + for (uint i KJ_UNUSED: kj::zeroTo(nonBlockedCount)) { + // Receive and ignore all the streams that were sent without blocking. + pipe.ends[1]->receiveStream().wait(io.waitScope); + } + + // Now that write that blocked should have been able to complete. + promise.wait(io.waitScope); + + // Now get the one that blocked. + auto endpoint2 = pipe.ends[1]->receiveStream().wait(io.waitScope); + + endpoint1->write("foo", 3).wait(io.waitScope); + endpoint1->shutdownWrite(); + KJ_EXPECT(endpoint2->readAllText().wait(io.waitScope) == "foo"); +} + +TEST(AsyncIo, CapabilityPipeMultiStreamMessage) { + auto ioContext = setupAsyncIo(); + + auto pipe = ioContext.provider->newCapabilityPipe(); + auto pipe2 = ioContext.provider->newCapabilityPipe(); + auto pipe3 = ioContext.provider->newCapabilityPipe(); + + auto streams = heapArrayBuilder>(2); + streams.add(kj::mv(pipe2.ends[0])); + streams.add(kj::mv(pipe3.ends[0])); + + ArrayPtr secondBuf = "bar"_kj.asBytes(); + pipe.ends[0]->writeWithStreams("foo"_kj.asBytes(), arrayPtr(&secondBuf, 1), streams.finish()) + .wait(ioContext.waitScope); + + char receiveBuffer[7]; + Own receiveStreams[3]; + auto result = pipe.ends[1]->tryReadWithStreams(receiveBuffer, 6, 7, receiveStreams, 3) + .wait(ioContext.waitScope); + + KJ_EXPECT(result.byteCount == 6); + receiveBuffer[6] = '\0'; + KJ_EXPECT(kj::StringPtr(receiveBuffer) == "foobar"); + + KJ_ASSERT(result.capCount == 2); + + receiveStreams[0]->write("baz", 3).wait(ioContext.waitScope); + receiveStreams[0] = nullptr; + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ioContext.waitScope) == "baz"); + + pipe3.ends[1]->write("qux", 3).wait(ioContext.waitScope); + pipe3.ends[1] = nullptr; + KJ_EXPECT(receiveStreams[1]->readAllText().wait(ioContext.waitScope) == "qux"); +} + +TEST(AsyncIo, ScmRightsTruncatedOdd) { + // Test that if we send two FDs over a unix socket, but the receiving end only receives one, we + // don't leak the other FD. + + auto io = setupAsyncIo(); + + auto capPipe = io.provider->newCapabilityPipe(); + + int pipeFds[2]; + KJ_SYSCALL(miniposix::pipe(pipeFds)); + kj::AutoCloseFd in1(pipeFds[0]); + kj::AutoCloseFd out1(pipeFds[1]); + + KJ_SYSCALL(miniposix::pipe(pipeFds)); + kj::AutoCloseFd in2(pipeFds[0]); + kj::AutoCloseFd out2(pipeFds[1]); + + { + AutoCloseFd sendFds[2] = { kj::mv(out1), kj::mv(out2) }; + capPipe.ends[0]->writeWithFds("foo"_kj.asBytes(), nullptr, sendFds).wait(io.waitScope); + } + + { + char buffer[4]; + AutoCloseFd fdBuffer[1]; + auto result = capPipe.ends[1]->tryReadWithFds(buffer, 3, 3, fdBuffer, 1).wait(io.waitScope); + KJ_ASSERT(result.capCount == 1); + kj::FdOutputStream(fdBuffer[0].get()).write("bar", 3); + } + + // We want to carefully verify that out1 and out2 were closed, without deadlocking if they + // weren't. So we manually set nonblocking mode and then issue read()s. + KJ_SYSCALL(fcntl(in1, F_SETFL, O_NONBLOCK)); + KJ_SYSCALL(fcntl(in2, F_SETFL, O_NONBLOCK)); + + char buffer[4]; + ssize_t n; + + // First we read "bar" from in1. + KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4)); + KJ_ASSERT(n == 3); + buffer[3] = '\0'; + KJ_ASSERT(kj::StringPtr(buffer) == "bar"); + + // Now it should be EOF. + KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4)); + if (n < 0) { + KJ_FAIL_ASSERT("out1 was not closed"); + } + KJ_ASSERT(n == 0); + + // Second pipe should have been closed implicitly because we didn't provide space to receive it. + KJ_NONBLOCKING_SYSCALL(n = read(in2, buffer, 4)); + if (n < 0) { + KJ_FAIL_ASSERT("out2 was not closed. This could indicate that your operating system kernel is " + "buggy and leaks file descriptors when an SCM_RIGHTS message is truncated. FreeBSD was " + "known to do this until late 2018, while MacOS still has this bug as of this writing in " + "2019. However, KJ works around the problem on those platforms. You need to enable the " + "same work-around for your OS -- search for 'SCM_RIGHTS' in src/kj/async-io-unix.c++."); + } + KJ_ASSERT(n == 0); +} + +#if !__aarch64__ +// This test fails under qemu-user, probably due to a bug in qemu's syscall emulation rather than +// a bug in the kernel. We don't have a good way to detect qemu so we just skip the test on aarch64 +// in general. + +TEST(AsyncIo, ScmRightsTruncatedEven) { + // Test that if we send three FDs over a unix socket, but the receiving end only receives two, we + // don't leak the third FD. This is different from the send-two-receive-one case in that + // CMSG_SPACE() on many systems rounds up such that there is always space for an even number of + // FDs. In that case the other test only verifies that our userspace code to close unwanted FDs + // is correct, whereas *this* test really verifies that the *kernel* properly closes truncated + // FDs. + + auto io = setupAsyncIo(); + + auto capPipe = io.provider->newCapabilityPipe(); + + int pipeFds[2]; + KJ_SYSCALL(miniposix::pipe(pipeFds)); + kj::AutoCloseFd in1(pipeFds[0]); + kj::AutoCloseFd out1(pipeFds[1]); + + KJ_SYSCALL(miniposix::pipe(pipeFds)); + kj::AutoCloseFd in2(pipeFds[0]); + kj::AutoCloseFd out2(pipeFds[1]); + + KJ_SYSCALL(miniposix::pipe(pipeFds)); + kj::AutoCloseFd in3(pipeFds[0]); + kj::AutoCloseFd out3(pipeFds[1]); + + { + AutoCloseFd sendFds[3] = { kj::mv(out1), kj::mv(out2), kj::mv(out3) }; + capPipe.ends[0]->writeWithFds("foo"_kj.asBytes(), nullptr, sendFds).wait(io.waitScope); + } + + { + char buffer[4]; + AutoCloseFd fdBuffer[2]; + auto result = capPipe.ends[1]->tryReadWithFds(buffer, 3, 3, fdBuffer, 2).wait(io.waitScope); + KJ_ASSERT(result.capCount == 2); + kj::FdOutputStream(fdBuffer[0].get()).write("bar", 3); + kj::FdOutputStream(fdBuffer[1].get()).write("baz", 3); + } + + // We want to carefully verify that out1, out2, and out3 were closed, without deadlocking if they + // weren't. So we manually set nonblocking mode and then issue read()s. + KJ_SYSCALL(fcntl(in1, F_SETFL, O_NONBLOCK)); + KJ_SYSCALL(fcntl(in2, F_SETFL, O_NONBLOCK)); + KJ_SYSCALL(fcntl(in3, F_SETFL, O_NONBLOCK)); + + char buffer[4]; + ssize_t n; + + // First we read "bar" from in1. + KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4)); + KJ_ASSERT(n == 3); + buffer[3] = '\0'; + KJ_ASSERT(kj::StringPtr(buffer) == "bar"); + + // Now it should be EOF. + KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4)); + if (n < 0) { + KJ_FAIL_ASSERT("out1 was not closed"); + } + KJ_ASSERT(n == 0); + + // Next we read "baz" from in2. + KJ_NONBLOCKING_SYSCALL(n = read(in2, buffer, 4)); + KJ_ASSERT(n == 3); + buffer[3] = '\0'; + KJ_ASSERT(kj::StringPtr(buffer) == "baz"); + + // Now it should be EOF. + KJ_NONBLOCKING_SYSCALL(n = read(in2, buffer, 4)); + if (n < 0) { + KJ_FAIL_ASSERT("out2 was not closed"); + } + KJ_ASSERT(n == 0); + + // Third pipe should have been closed implicitly because we didn't provide space to receive it. + KJ_NONBLOCKING_SYSCALL(n = read(in3, buffer, 4)); + if (n < 0) { + KJ_FAIL_ASSERT("out3 was not closed. This could indicate that your operating system kernel is " + "buggy and leaks file descriptors when an SCM_RIGHTS message is truncated. FreeBSD was " + "known to do this until late 2018, while MacOS still has this bug as of this writing in " + "2019. However, KJ works around the problem on those platforms. You need to enable the " + "same work-around for your OS -- search for 'SCM_RIGHTS' in src/kj/async-io-unix.c++."); + } + KJ_ASSERT(n == 0); +} + +#endif // !__aarch64__ + +#endif // !_WIN32 && !__CYGWIN__ + TEST(AsyncIo, PipeThread) { auto ioContext = setupAsyncIo(); @@ -227,7 +856,50 @@ TEST(AsyncIo, Timeouts) { #if !_WIN32 // datagrams not implemented on win32 yet +bool isMsgTruncBroken() { + // Detect if the kernel fails to set MSG_TRUNC on recvmsg(). This seems to be the case at least + // when running an arm64 binary under qemu. + + int fd; + KJ_SYSCALL(fd = socket(AF_INET, SOCK_DGRAM, 0)); + KJ_DEFER(close(fd)); + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(0x7f000001); + KJ_SYSCALL(bind(fd, reinterpret_cast(&addr), sizeof(addr))); + + // Read back the assigned port. + socklen_t len = sizeof(addr); + KJ_SYSCALL(getsockname(fd, reinterpret_cast(&addr), &len)); + KJ_ASSERT(len == sizeof(addr)); + + const char* message = "foobar"; + KJ_SYSCALL(sendto(fd, message, strlen(message), 0, + reinterpret_cast(&addr), sizeof(addr))); + + char buf[4]; + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = 3; + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + ssize_t n; + KJ_SYSCALL(n = recvmsg(fd, &msg, 0)); + KJ_ASSERT(n == 3); + + buf[3] = 0; + KJ_ASSERT(kj::StringPtr(buf) == "foo"); + + return (msg.msg_flags & MSG_TRUNC) == 0; +} + TEST(AsyncIo, Udp) { + bool msgTruncBroken = isMsgTruncBroken(); + auto ioContext = setupAsyncIo(); auto addr = ioContext.provider->getNetwork().parseAddress("127.0.0.1").wait(ioContext.waitScope); @@ -285,7 +957,7 @@ TEST(AsyncIo, Udp) { { auto content = recv1->getContent(); EXPECT_EQ("01234567", kj::heapString(content.value.asChars())); - EXPECT_TRUE(content.isTruncated); + EXPECT_TRUE(content.isTruncated || msgTruncBroken); } EXPECT_EQ(addr2->toString(), recv1->getSource().toString()); { @@ -294,10 +966,14 @@ TEST(AsyncIo, Udp) { EXPECT_FALSE(ancillary.isTruncated); } -#if defined(IP_PKTINFO) && !__CYGWIN__ +#if defined(IP_PKTINFO) && !__CYGWIN__ && !__aarch64__ // Set IP_PKTINFO header and try to receive it. + // // Doesn't work on Cygwin; see: https://cygwin.com/ml/cygwin/2009-01/msg00350.html // TODO(someday): Might work on more-recent Cygwin; I'm still testing against 1.7. + // + // Doesn't work when running arm64 binaries under QEMU -- in fact, it crashes QEMU. We don't + // have a good way to test if we're under QEMU so we just skip this test on aarch64. int one = 1; port1->setsockopt(IPPROTO_IP, IP_PKTINFO, &one, sizeof(one)); @@ -338,7 +1014,7 @@ TEST(AsyncIo, Udp) { EXPECT_EQ(addr2->toString(), recv1->getSource().toString()); { auto ancillary = recv1->getAncillary(); - EXPECT_TRUE(ancillary.isTruncated); + EXPECT_TRUE(ancillary.isTruncated || msgTruncBroken); // We might get a message, but it will be truncated. if (ancillary.value.size() != 0) { @@ -353,6 +1029,10 @@ TEST(AsyncIo, Udp) { } } +#if __APPLE__ +// On MacOS, `CMSG_SPACE(0)` triggers a bogus warning. +#pragma GCC diagnostic ignored "-Wnull-pointer-arithmetic" +#endif // See what happens if there's not enough space even for the cmsghdr. capacity.ancillary = CMSG_SPACE(0) - 8; recv1 = port1->makeReceiver(capacity); @@ -377,5 +1057,2370 @@ TEST(AsyncIo, Udp) { #endif // !_WIN32 +#ifdef __linux__ // Abstract unix sockets are only supported on Linux + +TEST(AsyncIo, AbstractUnixSocket) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + auto elapsedSinceEpoch = systemPreciseMonotonicClock().now() - kj::origin(); + auto address = kj::str("unix-abstract:foo", getpid(), elapsedSinceEpoch / kj::NANOSECONDS); + + Own addr = network.parseAddress(address).wait(ioContext.waitScope); + + Own listener = addr->listen(); + // chdir proves no filesystem dependence. Test fails for regular unix socket + // but passes for abstract unix socket. + int originalDirFd; + KJ_SYSCALL(originalDirFd = open(".", O_RDONLY | O_DIRECTORY | O_CLOEXEC)); + KJ_DEFER(close(originalDirFd)); + KJ_SYSCALL(chdir("/")); + KJ_DEFER(KJ_SYSCALL(fchdir(originalDirFd))); + + addr->connect().attach(kj::mv(listener)).wait(ioContext.waitScope); +} + +#endif // __linux__ + +KJ_TEST("CIDR parsing") { + KJ_EXPECT(CidrRange("1.2.3.4/16").toString() == "1.2.0.0/16"); + KJ_EXPECT(CidrRange("1.2.255.4/18").toString() == "1.2.192.0/18"); + KJ_EXPECT(CidrRange("1234::abcd:ffff:ffff/98").toString() == "1234::abcd:c000:0/98"); + + KJ_EXPECT(CidrRange::inet4({1,2,255,4}, 18).toString() == "1.2.192.0/18"); + KJ_EXPECT(CidrRange::inet6({0x1234, 0x5678}, {0xabcd, 0xffff, 0xffff}, 98).toString() == + "1234:5678::abcd:c000:0/98"); + + union { + struct sockaddr addr; + struct sockaddr_in addr4; + struct sockaddr_in6 addr6; + }; + memset(&addr6, 0, sizeof(addr6)); + + { + addr4.sin_family = AF_INET; + addr4.sin_addr.s_addr = htonl(0x0102dfff); + KJ_EXPECT(CidrRange("1.2.255.255/18").matches(&addr)); + KJ_EXPECT(!CidrRange("1.2.255.255/19").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.0.0/16").matches(&addr)); + KJ_EXPECT(!CidrRange("1.3.0.0/16").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.223.255/32").matches(&addr)); + KJ_EXPECT(CidrRange("0.0.0.0/0").matches(&addr)); + KJ_EXPECT(!CidrRange("::/0").matches(&addr)); + } + + { + addr4.sin_family = AF_INET6; + byte bytes[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + memcpy(addr6.sin6_addr.s6_addr, bytes, 16); + KJ_EXPECT(CidrRange("0102:03ff::/24").matches(&addr)); + KJ_EXPECT(!CidrRange("0102:02ff::/24").matches(&addr)); + KJ_EXPECT(CidrRange("0102:02ff::/23").matches(&addr)); + KJ_EXPECT(CidrRange("0102:0304:0506:0708:090a:0b0c:0d0e:0f10/128").matches(&addr)); + KJ_EXPECT(CidrRange("::/0").matches(&addr)); + KJ_EXPECT(!CidrRange("0.0.0.0/0").matches(&addr)); + } + + { + addr4.sin_family = AF_INET6; + inet_pton(AF_INET6, "::ffff:1.2.223.255", &addr6.sin6_addr); + KJ_EXPECT(CidrRange("1.2.255.255/18").matches(&addr)); + KJ_EXPECT(!CidrRange("1.2.255.255/19").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.0.0/16").matches(&addr)); + KJ_EXPECT(!CidrRange("1.3.0.0/16").matches(&addr)); + KJ_EXPECT(CidrRange("1.2.223.255/32").matches(&addr)); + KJ_EXPECT(CidrRange("0.0.0.0/0").matches(&addr)); + KJ_EXPECT(CidrRange("::/0").matches(&addr)); + } +} + +bool allowed4(_::NetworkFilter& filter, StringPtr addrStr) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + inet_pton(AF_INET, addrStr.cStr(), &addr.sin_addr); + return filter.shouldAllow(reinterpret_cast(&addr), sizeof(addr)); +} + +bool allowed6(_::NetworkFilter& filter, StringPtr addrStr) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + inet_pton(AF_INET6, addrStr.cStr(), &addr.sin6_addr); + return filter.shouldAllow(reinterpret_cast(&addr), sizeof(addr)); +} + +KJ_TEST("NetworkFilter") { + _::NetworkFilter base; + + KJ_EXPECT(allowed4(base, "8.8.8.8")); + KJ_EXPECT(!allowed4(base, "240.1.2.3")); + + { + _::NetworkFilter filter({"public"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(!allowed4(filter, "192.168.0.1")); + KJ_EXPECT(!allowed4(filter, "10.1.2.3")); + KJ_EXPECT(!allowed4(filter, "127.0.0.1")); + KJ_EXPECT(!allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(!allowed6(filter, "fc00::1234")); + KJ_EXPECT(!allowed6(filter, "::1")); + KJ_EXPECT(!allowed6(filter, "::")); + } + + { + _::NetworkFilter filter({"private"}, {"local"}, base); + + KJ_EXPECT(!allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(allowed4(filter, "192.168.0.1")); + KJ_EXPECT(allowed4(filter, "10.1.2.3")); + KJ_EXPECT(!allowed4(filter, "127.0.0.1")); + KJ_EXPECT(!allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(!allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(allowed6(filter, "fc00::1234")); + KJ_EXPECT(!allowed6(filter, "::1")); + KJ_EXPECT(!allowed6(filter, "::")); + } + + { + _::NetworkFilter filter({"1.0.0.0/8", "1.2.3.0/24"}, {"1.2.0.0/16", "1.2.3.4/32"}, base); + + KJ_EXPECT(!allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(allowed4(filter, "1.0.0.1")); + KJ_EXPECT(!allowed4(filter, "1.2.2.1")); + KJ_EXPECT(allowed4(filter, "1.2.3.1")); + KJ_EXPECT(!allowed4(filter, "1.2.3.4")); + } + + // Test combinations of public/private/network/local. At one point these were buggy. + { + _::NetworkFilter filter({"public", "private"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(allowed4(filter, "192.168.0.1")); + KJ_EXPECT(allowed4(filter, "10.1.2.3")); + KJ_EXPECT(allowed4(filter, "127.0.0.1")); + KJ_EXPECT(allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(allowed6(filter, "fc00::1234")); + KJ_EXPECT(allowed6(filter, "::1")); + KJ_EXPECT(allowed6(filter, "::")); + } + + { + _::NetworkFilter filter({"network", "local"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(allowed4(filter, "192.168.0.1")); + KJ_EXPECT(allowed4(filter, "10.1.2.3")); + KJ_EXPECT(allowed4(filter, "127.0.0.1")); + KJ_EXPECT(allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(allowed6(filter, "fc00::1234")); + KJ_EXPECT(allowed6(filter, "::1")); + KJ_EXPECT(allowed6(filter, "::")); + } + + { + _::NetworkFilter filter({"public", "local"}, {}, base); + + KJ_EXPECT(allowed4(filter, "8.8.8.8")); + KJ_EXPECT(!allowed4(filter, "240.1.2.3")); + + KJ_EXPECT(!allowed4(filter, "192.168.0.1")); + KJ_EXPECT(!allowed4(filter, "10.1.2.3")); + KJ_EXPECT(allowed4(filter, "127.0.0.1")); + KJ_EXPECT(allowed4(filter, "0.0.0.0")); + + KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2")); + KJ_EXPECT(!allowed6(filter, "fc00::1234")); + KJ_EXPECT(allowed6(filter, "::1")); + KJ_EXPECT(allowed6(filter, "::")); + } +} + +KJ_TEST("Network::restrictPeers()") { + auto ioContext = setupAsyncIo(); + auto& w = ioContext.waitScope; + auto& network = ioContext.provider->getNetwork(); + auto restrictedNetwork = network.restrictPeers({"public"}); + + KJ_EXPECT(tryParse(w, *restrictedNetwork, "8.8.8.8") == "8.8.8.8:0"); +#if !_WIN32 + KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "unix:/foo")); +#endif + + auto addr = restrictedNetwork->parseAddress("127.0.0.1").wait(w); + + auto listener = addr->listen(); + auto acceptTask = listener->accept() + .then([](kj::Own) { + KJ_FAIL_EXPECT("should not have received connection"); + }).eagerlyEvaluate(nullptr); + + KJ_EXPECT_THROW_MESSAGE("restrictPeers", addr->connect().wait(w)); + + // We can connect to the listener but the connection will be immediately closed. + auto addr2 = network.parseAddress("127.0.0.1", listener->getPort()).wait(w); + auto conn = addr2->connect().wait(w); + KJ_EXPECT(conn->readAllText().wait(w) == ""); +} + +kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { + if (expected.size() == 0) return kj::READY_NOW; + + auto buffer = kj::heapArray(expected.size()); + + auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { + if (amount == 0) { + KJ_FAIL_ASSERT("expected data never sent", expected); + } + + auto actual = buffer.slice(0, amount); + if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) { + KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual); + } + + return expectRead(in, expected.slice(amount)); + }); +} + +class MockAsyncInputStream final: public AsyncInputStream { +public: + MockAsyncInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +KJ_TEST("AsyncInputStream::readAllText() / readAllBytes()") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + size_t inputSizes[] = { 0, 1, 256, 4096, 8191, 8192, 8193, 10000, bigText.size() }; + size_t blockSizes[] = { 1, 4, 256, 4096, 8192, bigText.size() }; + uint64_t limits[] = { + 0, 1, 256, + bigText.size() / 2, + bigText.size() - 1, + bigText.size(), + bigText.size() + 1, + kj::maxValue + }; + + for (size_t inputSize: inputSizes) { + for (size_t blockSize: blockSizes) { + for (uint64_t limit: limits) { + KJ_CONTEXT(inputSize, blockSize, limit); + auto textSlice = bigText.asBytes().slice(0, inputSize); + auto readAllText = [&]() { + MockAsyncInputStream input(textSlice, blockSize); + return input.readAllText(limit).wait(ws); + }; + auto readAllBytes = [&]() { + MockAsyncInputStream input(textSlice, blockSize); + return input.readAllBytes(limit).wait(ws); + }; + if (limit > inputSize) { + KJ_EXPECT(readAllText().asBytes() == textSlice); + KJ_EXPECT(readAllBytes() == textSlice); + } else { + KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllText()); + KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllBytes()); + } + } + } + } +} + +KJ_TEST("Userland pipe") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + auto promise = pipe.out->write("foo", 3); + KJ_EXPECT(!promise.poll(ws)); + + char buf[4]; + KJ_EXPECT(pipe.in->tryRead(buf, 1, 4).wait(ws) == 3); + buf[3] = '\0'; + KJ_EXPECT(buf == "foo"_kj); + + promise.wait(ws); + + auto promise2 = pipe.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(promise2.wait(ws) == ""); +} + +KJ_TEST("Userland pipe cancel write") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + auto promise = pipe.out->write("foobar", 6); + KJ_EXPECT(!promise.poll(ws)); + + expectRead(*pipe.in, "foo").wait(ws); + KJ_EXPECT(!promise.poll(ws)); + promise = nullptr; + + promise = pipe.out->write("baz", 3); + expectRead(*pipe.in, "baz").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pipe.in->readAllText().wait(ws) == ""); +} + +KJ_TEST("Userland pipe cancel read") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + auto writeOp = pipe.out->write("foo", 3); + auto readOp = expectRead(*pipe.in, "foobar"); + writeOp.wait(ws); + KJ_EXPECT(!readOp.poll(ws)); + readOp = nullptr; + + auto writeOp2 = pipe.out->write("baz", 3); + expectRead(*pipe.in, "baz").wait(ws); +} + +KJ_TEST("Userland pipe pumpTo") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + auto promise = pipe.out->write("foo", 3); + KJ_EXPECT(!promise.poll(ws)); + + expectRead(*pipe2.in, "foo").wait(ws); + + promise.wait(ws); + + auto promise2 = pipe2.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 3); +} + +KJ_TEST("Userland pipe tryPumpFrom") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + auto promise = pipe.out->write("foo", 3); + KJ_EXPECT(!promise.poll(ws)); + + expectRead(*pipe2.in, "foo").wait(ws); + + promise.wait(ws); + + auto promise2 = pipe2.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(!promise2.poll(ws)); + KJ_EXPECT(pumpPromise.wait(ws) == 3); +} + +KJ_TEST("Userland pipe pumpTo cancel") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + auto promise = pipe.out->write("foobar", 3); + KJ_EXPECT(!promise.poll(ws)); + + expectRead(*pipe2.in, "foo").wait(ws); + + // Cancel pump. + pumpPromise = nullptr; + + auto promise3 = pipe2.out->write("baz", 3); + expectRead(*pipe2.in, "baz").wait(ws); +} + +KJ_TEST("Userland pipe tryPumpFrom cancel") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + auto promise = pipe.out->write("foobar", 3); + KJ_EXPECT(!promise.poll(ws)); + + expectRead(*pipe2.in, "foo").wait(ws); + + // Cancel pump. + pumpPromise = nullptr; + + auto promise3 = pipe2.out->write("baz", 3); + expectRead(*pipe2.in, "baz").wait(ws); +} + +KJ_TEST("Userland pipe with limit") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(6); + + { + auto promise = pipe.out->write("foo", 3); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe.in, "foo").wait(ws); + promise.wait(ws); + } + + { + auto promise = pipe.in->readAllText(); + KJ_EXPECT(!promise.poll(ws)); + auto promise2 = pipe.out->write("barbaz", 6); + KJ_EXPECT(promise.wait(ws) == "bar"); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("read end of pipe was aborted", promise2.wait(ws)); + } + + // Further writes throw and reads return EOF. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "abortRead() has been called", pipe.out->write("baz", 3).wait(ws)); + KJ_EXPECT(pipe.in->readAllText().wait(ws) == ""); +} + +KJ_TEST("Userland pipe pumpTo with limit") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(6); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + { + auto promise = pipe.out->write("foo", 3); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foo").wait(ws); + promise.wait(ws); + } + + { + auto promise = expectRead(*pipe2.in, "bar"); + KJ_EXPECT(!promise.poll(ws)); + auto promise2 = pipe.out->write("barbaz", 6); + promise.wait(ws); + pumpPromise.wait(ws); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("read end of pipe was aborted", promise2.wait(ws)); + } + + // Further writes throw. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "abortRead() has been called", pipe.out->write("baz", 3).wait(ws)); +} + +KJ_TEST("Userland pipe pump into zero-limited pipe, no data to pump") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(uint64_t(0)); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + expectRead(*pipe2.in, ""); + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 0); +} + +KJ_TEST("Userland pipe pump into zero-limited pipe, data is pumped") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(uint64_t(0)); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + expectRead(*pipe2.in, ""); + auto writePromise = pipe.out->write("foo", 3); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("abortRead() has been called", pumpPromise.wait(ws)); +} + +KJ_TEST("Userland pipe gather write") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe.in, "foobar").wait(ws); + promise.wait(ws); + + auto promise2 = pipe.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(promise2.wait(ws) == ""); +} + +KJ_TEST("Userland pipe gather write split on buffer boundary") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe.in, "foo").wait(ws); + expectRead(*pipe.in, "bar").wait(ws); + promise.wait(ws); + + auto promise2 = pipe.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(promise2.wait(ws) == ""); +} + +KJ_TEST("Userland pipe gather write split mid-first-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe.in, "fo").wait(ws); + expectRead(*pipe.in, "obar").wait(ws); + promise.wait(ws); + + auto promise2 = pipe.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(promise2.wait(ws) == ""); +} + +KJ_TEST("Userland pipe gather write split mid-second-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe.in, "foob").wait(ws); + expectRead(*pipe.in, "ar").wait(ws); + promise.wait(ws); + + auto promise2 = pipe.in->readAllText(); + KJ_EXPECT(!promise2.poll(ws)); + + pipe.out = nullptr; + KJ_EXPECT(promise2.wait(ws) == ""); +} + +KJ_TEST("Userland pipe gather write pump") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 6); +} + +KJ_TEST("Userland pipe gather write pump split on buffer boundary") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foo").wait(ws); + expectRead(*pipe2.in, "bar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 6); +} + +KJ_TEST("Userland pipe gather write pump split mid-first-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "fo").wait(ws); + expectRead(*pipe2.in, "obar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 6); +} + +KJ_TEST("Userland pipe gather write pump split mid-second-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foob").wait(ws); + expectRead(*pipe2.in, "ar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 6); +} + +KJ_TEST("Userland pipe gather write split pump on buffer boundary") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out, 3) + .then([&](uint64_t i) { + KJ_EXPECT(i == 3); + return pipe.in->pumpTo(*pipe2.out, 3); + }); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 3); +} + +KJ_TEST("Userland pipe gather write split pump mid-first-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out, 2) + .then([&](uint64_t i) { + KJ_EXPECT(i == 2); + return pipe.in->pumpTo(*pipe2.out, 4); + }); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 4); +} + +KJ_TEST("Userland pipe gather write split pump mid-second-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out, 4) + .then([&](uint64_t i) { + KJ_EXPECT(i == 4); + return pipe.in->pumpTo(*pipe2.out, 2); + }); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 2); +} + +KJ_TEST("Userland pipe gather write pumpFrom") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + char c; + auto eofPromise = pipe2.in->tryRead(&c, 1, 1); + eofPromise.poll(ws); // force pump to notice EOF + KJ_EXPECT(pumpPromise.wait(ws) == 6); + pipe2.out = nullptr; + KJ_EXPECT(eofPromise.wait(ws) == 0); +} + +KJ_TEST("Userland pipe gather write pumpFrom split on buffer boundary") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foo").wait(ws); + expectRead(*pipe2.in, "bar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + char c; + auto eofPromise = pipe2.in->tryRead(&c, 1, 1); + eofPromise.poll(ws); // force pump to notice EOF + KJ_EXPECT(pumpPromise.wait(ws) == 6); + pipe2.out = nullptr; + KJ_EXPECT(eofPromise.wait(ws) == 0); +} + +KJ_TEST("Userland pipe gather write pumpFrom split mid-first-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "fo").wait(ws); + expectRead(*pipe2.in, "obar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + char c; + auto eofPromise = pipe2.in->tryRead(&c, 1, 1); + eofPromise.poll(ws); // force pump to notice EOF + KJ_EXPECT(pumpPromise.wait(ws) == 6); + pipe2.out = nullptr; + KJ_EXPECT(eofPromise.wait(ws) == 0); +} + +KJ_TEST("Userland pipe gather write pumpFrom split mid-second-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foob").wait(ws); + expectRead(*pipe2.in, "ar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + char c; + auto eofPromise = pipe2.in->tryRead(&c, 1, 1); + eofPromise.poll(ws); // force pump to notice EOF + KJ_EXPECT(pumpPromise.wait(ws) == 6); + pipe2.out = nullptr; + KJ_EXPECT(eofPromise.wait(ws) == 0); +} + +KJ_TEST("Userland pipe gather write split pumpFrom on buffer boundary") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 3)) + .then([&](uint64_t i) { + KJ_EXPECT(i == 3); + return KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 3)); + }); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 3); +} + +KJ_TEST("Userland pipe gather write split pumpFrom mid-first-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 2)) + .then([&](uint64_t i) { + KJ_EXPECT(i == 2); + return KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 4)); + }); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 4); +} + +KJ_TEST("Userland pipe gather write split pumpFrom mid-second-buffer") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 4)) + .then([&](uint64_t i) { + KJ_EXPECT(i == 4); + return KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 2)); + }); + + ArrayPtr parts[] = { "foo"_kj.asBytes(), "bar"_kj.asBytes() }; + auto promise = pipe.out->write(parts); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 2); +} + +KJ_TEST("Userland pipe pumpTo less than write amount") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = pipe.in->pumpTo(*pipe2.out, 1); + + auto pieces = kj::heapArray>(2); + byte a[1] = { 'a' }; + byte b[1] = { 'b' }; + pieces[0] = arrayPtr(a, 1); + pieces[1] = arrayPtr(b, 1); + + auto writePromise = pipe.out->write(pieces); + KJ_EXPECT(!writePromise.poll(ws)); + + expectRead(*pipe2.in, "a").wait(ws); + KJ_EXPECT(pumpPromise.wait(ws) == 1); + KJ_EXPECT(!writePromise.poll(ws)); + + pumpPromise = pipe.in->pumpTo(*pipe2.out, 1); + + expectRead(*pipe2.in, "b").wait(ws); + KJ_EXPECT(pumpPromise.wait(ws) == 1); + writePromise.wait(ws); +} + +KJ_TEST("Userland pipe pumpFrom EOF on abortRead()") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + auto promise = pipe.out->write("foobar", 6); + KJ_EXPECT(!promise.poll(ws)); + expectRead(*pipe2.in, "foobar").wait(ws); + promise.wait(ws); + + KJ_EXPECT(!pumpPromise.poll(ws)); + pipe.out = nullptr; + pipe2.in = nullptr; // force pump to notice EOF + KJ_EXPECT(pumpPromise.wait(ws) == 6); + pipe2.out = nullptr; +} + +KJ_TEST("Userland pipe EOF fulfills pumpFrom promise") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + auto writePromise = pipe.out->write("foobar", 6); + KJ_EXPECT(!writePromise.poll(ws)); + auto pipe3 = newOneWayPipe(); + auto pumpPromise2 = pipe2.in->pumpTo(*pipe3.out); + KJ_EXPECT(!pumpPromise2.poll(ws)); + expectRead(*pipe3.in, "foobar").wait(ws); + writePromise.wait(ws); + + KJ_EXPECT(!pumpPromise.poll(ws)); + pipe.out = nullptr; + KJ_EXPECT(pumpPromise.wait(ws) == 6); + + KJ_EXPECT(!pumpPromise2.poll(ws)); + pipe2.out = nullptr; + KJ_EXPECT(pumpPromise2.wait(ws) == 6); +} + +KJ_TEST("Userland pipe tryPumpFrom to pumpTo for same amount fulfills simultaneously") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in, 6)); + + auto writePromise = pipe.out->write("foobar", 6); + KJ_EXPECT(!writePromise.poll(ws)); + auto pipe3 = newOneWayPipe(); + auto pumpPromise2 = pipe2.in->pumpTo(*pipe3.out, 6); + KJ_EXPECT(!pumpPromise2.poll(ws)); + expectRead(*pipe3.in, "foobar").wait(ws); + writePromise.wait(ws); + + KJ_EXPECT(pumpPromise.wait(ws) == 6); + KJ_EXPECT(pumpPromise2.wait(ws) == 6); +} + +KJ_TEST("Userland pipe multi-part write doesn't quit early") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + auto readPromise = expectRead(*pipe.in, "foo"); + + kj::ArrayPtr pieces[2] = { "foobar"_kj.asBytes(), "baz"_kj.asBytes() }; + auto writePromise = pipe.out->write(pieces); + + readPromise.wait(ws); + KJ_EXPECT(!writePromise.poll(ws)); + expectRead(*pipe.in, "bar").wait(ws); + KJ_EXPECT(!writePromise.poll(ws)); + expectRead(*pipe.in, "baz").wait(ws); + writePromise.wait(ws); +} + +KJ_TEST("Userland pipe BlockedRead gets empty tryPumpFrom") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto pipe2 = newOneWayPipe(); + + // First start a read from the back end. + char buffer[4]; + auto readPromise = pipe2.in->tryRead(buffer, 1, 4); + + // Now arrange a pump between the pipes, using tryPumpFrom(). + auto pumpPromise = KJ_ASSERT_NONNULL(pipe2.out->tryPumpFrom(*pipe.in)); + + // Disconnect the front pipe, causing EOF on the pump. + pipe.out = nullptr; + + // The pump should have produced zero bytes. + KJ_EXPECT(pumpPromise.wait(ws) == 0); + + // The read is incomplete. + KJ_EXPECT(!readPromise.poll(ws)); + + // A subsequent write() completes the read. + pipe2.out->write("foo", 3).wait(ws); + KJ_EXPECT(readPromise.wait(ws) == 3); + buffer[3] = '\0'; + KJ_EXPECT(kj::StringPtr(buffer, 3) == "foo"); +} + +constexpr static auto TEE_MAX_CHUNK_SIZE = 1 << 14; +// AsyncTee::MAX_CHUNK_SIZE, 16k as of this writing + +KJ_TEST("Userland tee") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto writePromise = pipe.out->write("foobar", 6); + + expectRead(*left, "foobar").wait(ws); + writePromise.wait(ws); + expectRead(*right, "foobar").wait(ws); +} + +KJ_TEST("Userland nested tee") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto tee2 = newTee(kj::mv(right)); + auto rightLeft = kj::mv(tee2.branches[0]); + auto rightRight = kj::mv(tee2.branches[1]); + + auto writePromise = pipe.out->write("foobar", 6); + + expectRead(*left, "foobar").wait(ws); + writePromise.wait(ws); + expectRead(*rightLeft, "foobar").wait(ws); + expectRead(*rightRight, "foo").wait(ws); + + auto tee3 = newTee(kj::mv(rightRight)); + auto rightRightLeft = kj::mv(tee3.branches[0]); + auto rightRightRight = kj::mv(tee3.branches[1]); + expectRead(*rightRightLeft, "bar").wait(ws); + expectRead(*rightRightRight, "b").wait(ws); + + auto tee4 = newTee(kj::mv(rightRightRight)); + auto rightRightRightLeft = kj::mv(tee4.branches[0]); + auto rightRightRightRight = kj::mv(tee4.branches[1]); + expectRead(*rightRightRightLeft, "ar").wait(ws); + expectRead(*rightRightRightRight, "ar").wait(ws); +} + +KJ_TEST("Userland tee concurrent read") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + uint8_t leftBuf[6] = { 0 }; + uint8_t rightBuf[6] = { 0 }; + auto leftPromise = left->tryRead(leftBuf, 6, 6); + auto rightPromise = right->tryRead(rightBuf, 6, 6); + KJ_EXPECT(!leftPromise.poll(ws)); + KJ_EXPECT(!rightPromise.poll(ws)); + + pipe.out->write("foobar", 6).wait(ws); + + KJ_EXPECT(leftPromise.wait(ws) == 6); + KJ_EXPECT(rightPromise.wait(ws) == 6); + + KJ_EXPECT(memcmp(leftBuf, "foobar", 6) == 0); + KJ_EXPECT(memcmp(leftBuf, "foobar", 6) == 0); +} + +KJ_TEST("Userland tee cancel and restart read") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto writePromise = pipe.out->write("foobar", 6); + + { + // Initiate a read and immediately cancel it. + uint8_t buf[6] = { 0 }; + auto promise = left->tryRead(buf, 6, 6); + } + + // Subsequent reads still see the full data. + expectRead(*left, "foobar").wait(ws); + writePromise.wait(ws); + expectRead(*right, "foobar").wait(ws); +} + +KJ_TEST("Userland tee cancel read and destroy branch then read other branch") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto writePromise = pipe.out->write("foobar", 6); + + { + // Initiate a read and immediately cancel it. + uint8_t buf[6] = { 0 }; + auto promise = left->tryRead(buf, 6, 6); + } + + // And destroy the branch for good measure. + left = nullptr; + + // Subsequent reads on the other branch still see the full data. + expectRead(*right, "foobar").wait(ws); + writePromise.wait(ws); +} + +KJ_TEST("Userland tee subsequent other-branch reads are READY_NOW") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto tee = newTee(kj::mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + uint8_t leftBuf[6] = { 0 }; + auto leftPromise = left->tryRead(leftBuf, 6, 6); + // This is the first read, so there should NOT be buffered data. + KJ_EXPECT(!leftPromise.poll(ws)); + pipe.out->write("foobar", 6).wait(ws); + leftPromise.wait(ws); + KJ_EXPECT(memcmp(leftBuf, "foobar", 6) == 0); + + uint8_t rightBuf[6] = { 0 }; + auto rightPromise = right->tryRead(rightBuf, 6, 6); + // The left read promise was fulfilled, so there SHOULD be buffered data. + KJ_EXPECT(rightPromise.poll(ws)); + rightPromise.wait(ws); + KJ_EXPECT(memcmp(rightBuf, "foobar", 6) == 0); +} + +KJ_TEST("Userland tee read EOF propagation") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + auto writePromise = pipe.out->write("foobar", 6); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + // Lengthless pipe, so ... + KJ_EXPECT(left->tryGetLength() == nullptr); + KJ_EXPECT(right->tryGetLength() == nullptr); + + uint8_t leftBuf[7] = { 0 }; + auto leftPromise = left->tryRead(leftBuf, size(leftBuf), size(leftBuf)); + writePromise.wait(ws); + // Destroying the output side should force a short read. + pipe.out = nullptr; + + KJ_EXPECT(leftPromise.wait(ws) == 6); + KJ_EXPECT(memcmp(leftBuf, "foobar", 6) == 0); + + // And we should see a short read here, too. + uint8_t rightBuf[7] = { 0 }; + auto rightPromise = right->tryRead(rightBuf, size(rightBuf), size(rightBuf)); + KJ_EXPECT(rightPromise.wait(ws) == 6); + KJ_EXPECT(memcmp(rightBuf, "foobar", 6) == 0); + + // Further reads should all be short. + KJ_EXPECT(left->tryRead(leftBuf, 1, size(leftBuf)).wait(ws) == 0); + KJ_EXPECT(right->tryRead(rightBuf, 1, size(rightBuf)).wait(ws) == 0); +} + +KJ_TEST("Userland tee read exception propagation") { + kj::EventLoop loop; + WaitScope ws(loop); + + // Make a pipe expecting to read more than we're actually going to write. This will force a "pipe + // ended prematurely" exception when we destroy the output side early. + auto pipe = newOneWayPipe(7); + auto writePromise = pipe.out->write("foobar", 6); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + // Test tryGetLength() while we're at it. + KJ_EXPECT(KJ_ASSERT_NONNULL(left->tryGetLength()) == 7); + KJ_EXPECT(KJ_ASSERT_NONNULL(right->tryGetLength()) == 7); + + uint8_t leftBuf[7] = { 0 }; + auto leftPromise = left->tryRead(leftBuf, 6, size(leftBuf)); + writePromise.wait(ws); + // Destroying the output side should force a fulfillment of the read (since we reached minBytes). + pipe.out = nullptr; + KJ_EXPECT(leftPromise.wait(ws) == 6); + KJ_EXPECT(memcmp(leftBuf, "foobar", 6) == 0); + + // The next read sees the exception. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("pipe ended prematurely", + left->tryRead(leftBuf, 1, size(leftBuf)).ignoreResult().wait(ws)); + + // Test tryGetLength() here -- the unread branch still sees the original length value. + KJ_EXPECT(KJ_ASSERT_NONNULL(left->tryGetLength()) == 1); + KJ_EXPECT(KJ_ASSERT_NONNULL(right->tryGetLength()) == 7); + + // We should see the buffered data on the other side, even though we don't reach our minBytes. + uint8_t rightBuf[7] = { 0 }; + auto rightPromise = right->tryRead(rightBuf, size(rightBuf), size(rightBuf)); + KJ_EXPECT(rightPromise.wait(ws) == 6); + KJ_EXPECT(memcmp(rightBuf, "foobar", 6) == 0); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("pipe ended prematurely", + right->tryRead(rightBuf, 1, size(leftBuf)).ignoreResult().wait(ws)); + + // Further reads should all see the exception again. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("pipe ended prematurely", + left->tryRead(leftBuf, 1, size(leftBuf)).ignoreResult().wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("pipe ended prematurely", + right->tryRead(rightBuf, 1, size(leftBuf)).ignoreResult().wait(ws)); +} + +KJ_TEST("Userland tee read exception propagation w/ data loss") { + kj::EventLoop loop; + WaitScope ws(loop); + + // Make a pipe expecting to read more than we're actually going to write. This will force a "pipe + // ended prematurely" exception once the pipe sees a short read. + auto pipe = newOneWayPipe(7); + auto writePromise = pipe.out->write("foobar", 6); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + uint8_t leftBuf[7] = { 0 }; + auto leftPromise = left->tryRead(leftBuf, 7, 7); + writePromise.wait(ws); + // Destroying the output side should force an exception, since we didn't reach our minBytes. + pipe.out = nullptr; + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "pipe ended prematurely", leftPromise.ignoreResult().wait(ws)); + + // And we should see a short read here, too. In fact, we shouldn't see anything: the short read + // above read all of the pipe's data, but then failed to buffer it because it encountered an + // exception. It buffered the exception, instead. + uint8_t rightBuf[7] = { 0 }; + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("pipe ended prematurely", + right->tryRead(rightBuf, 1, 1).ignoreResult().wait(ws)); +} + +KJ_TEST("Userland tee read into different buffer sizes") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto tee = newTee(heap("foo bar baz"_kj.asBytes(), 11)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + uint8_t leftBuf[5] = { 0 }; + uint8_t rightBuf[11] = { 0 }; + + auto leftPromise = left->tryRead(leftBuf, 5, 5); + auto rightPromise = right->tryRead(rightBuf, 11, 11); + + KJ_EXPECT(leftPromise.wait(ws) == 5); + KJ_EXPECT(rightPromise.wait(ws) == 11); +} + +KJ_TEST("Userland tee reads see max(minBytes...) and min(maxBytes...)") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto tee = newTee(heap("foo bar baz"_kj.asBytes(), 11)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + { + uint8_t leftBuf[5] = { 0 }; + uint8_t rightBuf[11] = { 0 }; + + // Subrange of another range. The smaller maxBytes should win. + auto leftPromise = left->tryRead(leftBuf, 3, 5); + auto rightPromise = right->tryRead(rightBuf, 1, 11); + + KJ_EXPECT(leftPromise.wait(ws) == 5); + KJ_EXPECT(rightPromise.wait(ws) == 5); + } + + { + uint8_t leftBuf[5] = { 0 }; + uint8_t rightBuf[11] = { 0 }; + + // Disjoint ranges. The larger minBytes should win. + auto leftPromise = left->tryRead(leftBuf, 3, 5); + auto rightPromise = right->tryRead(rightBuf, 6, 11); + + KJ_EXPECT(leftPromise.wait(ws) == 5); + KJ_EXPECT(rightPromise.wait(ws) == 6); + + KJ_EXPECT(left->tryRead(leftBuf, 1, 2).wait(ws) == 1); + } +} + +KJ_TEST("Userland tee read stress test") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + + auto tee = newTee(heap(bigText.asBytes(), bigText.size())); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftBuffer = heapArray(bigText.size()); + + { + auto leftSlice = leftBuffer.slice(0, leftBuffer.size()); + while (leftSlice.size() > 0) { + for (size_t blockSize: { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59 }) { + if (leftSlice.size() == 0) break; + auto maxBytes = min(blockSize, leftSlice.size()); + auto amount = left->tryRead(leftSlice.begin(), 1, maxBytes).wait(ws); + leftSlice = leftSlice.slice(amount, leftSlice.size()); + } + } + } + + KJ_EXPECT(memcmp(leftBuffer.begin(), bigText.begin(), leftBuffer.size()) == 0); + KJ_EXPECT(right->readAllText().wait(ws) == bigText); +} + +KJ_TEST("Userland tee pump") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + + auto tee = newTee(heap(bigText.asBytes(), bigText.size())); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + + auto leftPumpPromise = left->pumpTo(*leftPipe.out, 7); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + + auto rightPumpPromise = right->pumpTo(*rightPipe.out); + // Neither are ready yet, because the left pump's backpressure has blocked the AsyncTee's pull + // loop until we read from leftPipe. + KJ_EXPECT(!leftPumpPromise.poll(ws)); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + expectRead(*leftPipe.in, "foo bar").wait(ws); + KJ_EXPECT(leftPumpPromise.wait(ws) == 7); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + // We should be able to read up to how far the left side pumped, and beyond. The left side will + // now have data in its buffer. + expectRead(*rightPipe.in, "foo bar baz,foo bar baz,foo").wait(ws); + + // Consume the left side buffer. + expectRead(*left, " baz,foo bar").wait(ws); + + // We can destroy the left branch entirely and the right branch will still see all data. + left = nullptr; + KJ_EXPECT(!rightPumpPromise.poll(ws)); + auto allTextPromise = rightPipe.in->readAllText(); + KJ_EXPECT(rightPumpPromise.wait(ws) == bigText.size()); + // Need to force an EOF in the right pipe to check the result. + rightPipe.out = nullptr; + KJ_EXPECT(allTextPromise.wait(ws) == bigText.slice(27)); +} + +KJ_TEST("Userland tee pump slows down reads") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + + auto tee = newTee(heap(bigText.asBytes(), bigText.size())); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + + // The left pump will cause some data to be buffered on the right branch, which we can read. + auto rightExpectation0 = kj::str(bigText.slice(0, TEE_MAX_CHUNK_SIZE)); + expectRead(*right, rightExpectation0).wait(ws); + + // But the next right branch read is blocked by the left pipe's backpressure. + auto rightExpectation1 = kj::str(bigText.slice(TEE_MAX_CHUNK_SIZE, TEE_MAX_CHUNK_SIZE + 10)); + auto rightPromise = expectRead(*right, rightExpectation1); + KJ_EXPECT(!rightPromise.poll(ws)); + + // The right branch read finishes when we relieve the pressure in the left pipe. + auto allTextPromise = leftPipe.in->readAllText(); + rightPromise.wait(ws); + KJ_EXPECT(leftPumpPromise.wait(ws) == bigText.size()); + leftPipe.out = nullptr; + KJ_EXPECT(allTextPromise.wait(ws) == bigText); +} + +KJ_TEST("Userland tee pump EOF propagation") { + kj::EventLoop loop; + WaitScope ws(loop); + + { + // EOF encountered by two pump operations. + auto pipe = newOneWayPipe(); + auto writePromise = pipe.out->write("foo bar", 7); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + + // Pump the first bit, and block. + + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + auto rightPumpPromise = right->pumpTo(*rightPipe.out); + writePromise.wait(ws); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + // Induce an EOF. We should see it propagated to both pump promises. + + pipe.out = nullptr; + + // Relieve backpressure. + auto leftAllPromise = leftPipe.in->readAllText(); + auto rightAllPromise = rightPipe.in->readAllText(); + KJ_EXPECT(leftPumpPromise.wait(ws) == 7); + KJ_EXPECT(rightPumpPromise.wait(ws) == 7); + + // Make sure we got the data on the pipes that were being pumped to. + KJ_EXPECT(!leftAllPromise.poll(ws)); + KJ_EXPECT(!rightAllPromise.poll(ws)); + leftPipe.out = nullptr; + rightPipe.out = nullptr; + KJ_EXPECT(leftAllPromise.wait(ws) == "foo bar"); + KJ_EXPECT(rightAllPromise.wait(ws) == "foo bar"); + } + + { + // EOF encountered by a read and pump operation. + auto pipe = newOneWayPipe(); + auto writePromise = pipe.out->write("foo bar", 7); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + + // Pump one branch, read another. + + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + expectRead(*right, "foo bar").wait(ws); + writePromise.wait(ws); + uint8_t dummy = 0; + auto rightReadPromise = right->tryRead(&dummy, 1, 1); + + // Induce an EOF. We should see it propagated to both the read and pump promises. + + pipe.out = nullptr; + + // Relieve backpressure in the tee to see the EOF. + auto leftAllPromise = leftPipe.in->readAllText(); + KJ_EXPECT(leftPumpPromise.wait(ws) == 7); + KJ_EXPECT(rightReadPromise.wait(ws) == 0); + + // Make sure we got the data on the pipe that was being pumped to. + KJ_EXPECT(!leftAllPromise.poll(ws)); + leftPipe.out = nullptr; + KJ_EXPECT(leftAllPromise.wait(ws) == "foo bar"); + } +} + +KJ_TEST("Userland tee pump EOF on chunk boundary") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + + // Conjure an EOF right on the boundary of the tee's internal chunk. + auto chunkText = kj::str(bigText.slice(0, TEE_MAX_CHUNK_SIZE)); + auto tee = newTee(heap(chunkText.asBytes(), chunkText.size())); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + auto rightPumpPromise = right->pumpTo(*rightPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + auto leftAllPromise = leftPipe.in->readAllText(); + auto rightAllPromise = rightPipe.in->readAllText(); + + // The pumps should see the EOF and stop. + KJ_EXPECT(leftPumpPromise.wait(ws) == TEE_MAX_CHUNK_SIZE); + KJ_EXPECT(rightPumpPromise.wait(ws) == TEE_MAX_CHUNK_SIZE); + + // Verify that we saw the data on the other end of the destination pipes. + leftPipe.out = nullptr; + rightPipe.out = nullptr; + KJ_EXPECT(leftAllPromise.wait(ws) == chunkText); + KJ_EXPECT(rightAllPromise.wait(ws) == chunkText); +} + +KJ_TEST("Userland tee pump read exception propagation") { + kj::EventLoop loop; + WaitScope ws(loop); + + { + // Exception encountered by two pump operations. + auto pipe = newOneWayPipe(14); + auto writePromise = pipe.out->write("foo bar", 7); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + + // Pump the first bit, and block. + + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + auto rightPumpPromise = right->pumpTo(*rightPipe.out); + writePromise.wait(ws); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + // Induce a read exception. We should see it propagated to both pump promises. + + pipe.out = nullptr; + + // Both promises must exist before the backpressure in the tee is relieved, and the tee pull + // loop actually sees the exception. + auto leftAllPromise = leftPipe.in->readAllText(); + auto rightAllPromise = rightPipe.in->readAllText(); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "pipe ended prematurely", leftPumpPromise.ignoreResult().wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "pipe ended prematurely", rightPumpPromise.ignoreResult().wait(ws)); + + // Make sure we got the data on the destination pipes. + KJ_EXPECT(!leftAllPromise.poll(ws)); + KJ_EXPECT(!rightAllPromise.poll(ws)); + leftPipe.out = nullptr; + rightPipe.out = nullptr; + KJ_EXPECT(leftAllPromise.wait(ws) == "foo bar"); + KJ_EXPECT(rightAllPromise.wait(ws) == "foo bar"); + } + + { + // Exception encountered by a read and pump operation. + auto pipe = newOneWayPipe(14); + auto writePromise = pipe.out->write("foo bar", 7); + auto tee = newTee(mv(pipe.in)); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + + // Pump one branch, read another. + + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + expectRead(*right, "foo bar").wait(ws); + writePromise.wait(ws); + uint8_t dummy = 0; + auto rightReadPromise = right->tryRead(&dummy, 1, 1); + + // Induce a read exception. We should see it propagated to both the read and pump promises. + + pipe.out = nullptr; + + // Relieve backpressure in the tee to see the exceptions. + auto leftAllPromise = leftPipe.in->readAllText(); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "pipe ended prematurely", leftPumpPromise.ignoreResult().wait(ws)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "pipe ended prematurely", rightReadPromise.ignoreResult().wait(ws)); + + // Make sure we got the data on the destination pipe. + KJ_EXPECT(!leftAllPromise.poll(ws)); + leftPipe.out = nullptr; + KJ_EXPECT(leftAllPromise.wait(ws) == "foo bar"); + } +} + +KJ_TEST("Userland tee pump write exception propagation") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + + auto tee = newTee(heap(bigText.asBytes(), bigText.size())); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + // Set up two pumps and let them block. + auto leftPipe = newOneWayPipe(); + auto rightPipe = newOneWayPipe(); + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + auto rightPumpPromise = right->pumpTo(*rightPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + // Induce a write exception in the right branch pump. It should propagate to the right pump + // promise. + rightPipe.in = nullptr; + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "read end of pipe was aborted", rightPumpPromise.ignoreResult().wait(ws)); + + // The left pump promise does not see the right branch's write exception. + KJ_EXPECT(!leftPumpPromise.poll(ws)); + auto allTextPromise = leftPipe.in->readAllText(); + KJ_EXPECT(leftPumpPromise.wait(ws) == bigText.size()); + leftPipe.out = nullptr; + KJ_EXPECT(allTextPromise.wait(ws) == bigText); +} + +KJ_TEST("Userland tee pump cancellation implies write cancellation") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto text = "foo bar baz"_kj; + + auto tee = newTee(heap(text.asBytes(), text.size())); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + auto leftPipe = newOneWayPipe(); + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + + // Arrange to block the left pump on its write operation. + expectRead(*right, "foo ").wait(ws); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + + // Then cancel the pump, while it's still blocked. + leftPumpPromise = nullptr; + // It should cancel its write operations, so it should now be safe to destroy the output stream to + // which it was pumping. + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + leftPipe.out = nullptr; + })) { + KJ_FAIL_EXPECT("write promises were not canceled", *exception); + } +} + +KJ_TEST("Userland tee buffer size limit") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto text = "foo bar baz"_kj; + + { + // We can carefully read data to stay under our ridiculously low limit. + + auto tee = newTee(heap(text.asBytes(), text.size()), 2); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + expectRead(*left, "fo").wait(ws); + expectRead(*right, "foo ").wait(ws); + expectRead(*left, "o ba").wait(ws); + expectRead(*right, "bar ").wait(ws); + expectRead(*left, "r ba").wait(ws); + expectRead(*right, "baz").wait(ws); + expectRead(*left, "z").wait(ws); + } + + { + // Exceeding the limit causes both branches to see the exception after exhausting their buffers. + + auto tee = newTee(heap(text.asBytes(), text.size()), 2); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + + expectRead(*left, "fo").wait(ws); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("tee buffer size limit exceeded", + expectRead(*left, "o").wait(ws)); + expectRead(*right, "fo").wait(ws); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("tee buffer size limit exceeded", + expectRead(*right, "o").wait(ws)); + } + + { + // We guarantee that two pumps started simultaneously will never exceed our buffer size limit. + + auto tee = newTee(heap(text.asBytes(), text.size()), 2); + auto left = kj::mv(tee.branches[0]); + auto right = kj::mv(tee.branches[1]); + auto leftPipe = kj::newOneWayPipe(); + auto rightPipe = kj::newOneWayPipe(); + + auto leftPumpPromise = left->pumpTo(*leftPipe.out); + auto rightPumpPromise = right->pumpTo(*rightPipe.out); + KJ_EXPECT(!leftPumpPromise.poll(ws)); + KJ_EXPECT(!rightPumpPromise.poll(ws)); + + uint8_t leftBuf[11] = { 0 }; + uint8_t rightBuf[11] = { 0 }; + + // The first read on the left pipe will succeed. + auto leftPromise = leftPipe.in->tryRead(leftBuf, 1, 11); + KJ_EXPECT(leftPromise.wait(ws) == 2); + KJ_EXPECT(memcmp(leftBuf, text.begin(), 2) == 0); + + // But the second will block until we relieve pressure on the right pipe. + leftPromise = leftPipe.in->tryRead(leftBuf + 2, 1, 9); + KJ_EXPECT(!leftPromise.poll(ws)); + + // Relieve the right pipe pressure ... + auto rightPromise = rightPipe.in->tryRead(rightBuf, 1, 11); + KJ_EXPECT(rightPromise.wait(ws) == 2); + KJ_EXPECT(memcmp(rightBuf, text.begin(), 2) == 0); + + // Now the second left pipe read will complete. + KJ_EXPECT(leftPromise.wait(ws) == 2); + KJ_EXPECT(memcmp(leftBuf, text.begin(), 4) == 0); + + // Leapfrog the left branch with the right. There should be 2 bytes in the buffer, so we can + // demand a total of 4. + rightPromise = rightPipe.in->tryRead(rightBuf + 2, 4, 9); + KJ_EXPECT(rightPromise.wait(ws) == 4); + KJ_EXPECT(memcmp(rightBuf, text.begin(), 6) == 0); + + // Leapfrog the right with the left. We demand the entire rest of the stream, so this should + // block. Note that a regular read for this amount on one of the tee branches directly would + // exceed our buffer size limit, but this one does not, because we have the pipe to regulate + // backpressure for us. + leftPromise = leftPipe.in->tryRead(leftBuf + 4, 7, 7); + KJ_EXPECT(!leftPromise.poll(ws)); + + // Ask for the entire rest of the stream on the right branch and wrap things up. + rightPromise = rightPipe.in->tryRead(rightBuf + 6, 5, 5); + + KJ_EXPECT(leftPromise.wait(ws) == 7); + KJ_EXPECT(memcmp(leftBuf, text.begin(), 11) == 0); + + KJ_EXPECT(rightPromise.wait(ws) == 5); + KJ_EXPECT(memcmp(rightBuf, text.begin(), 11) == 0); + } +} + +KJ_TEST("Userspace OneWayPipe whenWriteDisconnected()") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newOneWayPipe(); + + auto abortedPromise = pipe.out->whenWriteDisconnected(); + KJ_ASSERT(!abortedPromise.poll(ws)); + + pipe.in = nullptr; + + KJ_ASSERT(abortedPromise.poll(ws)); + abortedPromise.wait(ws); +} + +KJ_TEST("Userspace TwoWayPipe whenWriteDisconnected()") { + kj::EventLoop loop; + WaitScope ws(loop); + + auto pipe = newTwoWayPipe(); + + auto abortedPromise = pipe.ends[0]->whenWriteDisconnected(); + KJ_ASSERT(!abortedPromise.poll(ws)); + + pipe.ends[1] = nullptr; + + KJ_ASSERT(abortedPromise.poll(ws)); + abortedPromise.wait(ws); +} + +#if !_WIN32 // We don't currently support detecting disconnect with IOCP. +#if !__CYGWIN__ // TODO(someday): Figure out why whenWriteDisconnected() doesn't work on Cygwin. + +KJ_TEST("OS OneWayPipe whenWriteDisconnected()") { + auto io = setupAsyncIo(); + + auto pipe = io.provider->newOneWayPipe(); + + pipe.out->write("foo", 3).wait(io.waitScope); + auto abortedPromise = pipe.out->whenWriteDisconnected(); + KJ_ASSERT(!abortedPromise.poll(io.waitScope)); + + pipe.in = nullptr; + + KJ_ASSERT(abortedPromise.poll(io.waitScope)); + abortedPromise.wait(io.waitScope); +} + +KJ_TEST("OS TwoWayPipe whenWriteDisconnected()") { + auto io = setupAsyncIo(); + + auto pipe = io.provider->newTwoWayPipe(); + + pipe.ends[0]->write("foo", 3).wait(io.waitScope); + pipe.ends[1]->write("bar", 3).wait(io.waitScope); + + auto abortedPromise = pipe.ends[0]->whenWriteDisconnected(); + KJ_ASSERT(!abortedPromise.poll(io.waitScope)); + + pipe.ends[1] = nullptr; + + KJ_ASSERT(abortedPromise.poll(io.waitScope)); + abortedPromise.wait(io.waitScope); + + char buffer[4]; + KJ_ASSERT(pipe.ends[0]->tryRead(&buffer, 3, 3).wait(io.waitScope) == 3); + buffer[3] = '\0'; + KJ_EXPECT(buffer == "bar"_kj); + + // Note: Reading any further in pipe.ends[0] would throw "connection reset". +} + +KJ_TEST("import socket FD that's already broken") { + auto io = setupAsyncIo(); + + int fds[2]; + KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); + KJ_SYSCALL(write(fds[1], "foo", 3)); + KJ_SYSCALL(close(fds[1])); + + auto stream = io.lowLevelProvider->wrapSocketFd(fds[0], LowLevelAsyncIoProvider::TAKE_OWNERSHIP); + + auto abortedPromise = stream->whenWriteDisconnected(); + KJ_ASSERT(abortedPromise.poll(io.waitScope)); + abortedPromise.wait(io.waitScope); + + char buffer[4]; + KJ_ASSERT(stream->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3); + buffer[3] = '\0'; + KJ_EXPECT(buffer == "foo"_kj); +} + +#endif // !__CYGWIN__ +#endif // !_WIN32 + +KJ_TEST("AggregateConnectionReceiver") { + EventLoop loop; + WaitScope ws(loop); + + auto pipe1 = newCapabilityPipe(); + auto pipe2 = newCapabilityPipe(); + + auto receiversBuilder = kj::heapArrayBuilder>(2); + receiversBuilder.add(kj::heap(*pipe1.ends[0])); + receiversBuilder.add(kj::heap(*pipe2.ends[0])); + + auto aggregate = newAggregateConnectionReceiver(receiversBuilder.finish()); + + CapabilityStreamNetworkAddress connector1(nullptr, *pipe1.ends[1]); + CapabilityStreamNetworkAddress connector2(nullptr, *pipe2.ends[1]); + + auto connectAndWrite = [&](NetworkAddress& addr, kj::StringPtr text) { + return addr.connect() + .then([text](Own stream) { + auto promise = stream->write(text.begin(), text.size()); + return promise.attach(kj::mv(stream)); + }).eagerlyEvaluate([](kj::Exception&& e) { + KJ_LOG(ERROR, e); + }); + }; + + auto acceptAndRead = [&](ConnectionReceiver& socket, kj::StringPtr expected) { + return socket + .accept().then([](Own stream) { + auto promise = stream->readAllText(); + return promise.attach(kj::mv(stream)); + }).then([expected](kj::String actual) { + KJ_EXPECT(actual == expected); + }).eagerlyEvaluate([](kj::Exception&& e) { + KJ_LOG(ERROR, e); + }); + }; + + auto connectPromise1 = connectAndWrite(connector1, "foo"); + KJ_EXPECT(!connectPromise1.poll(ws)); + auto connectPromise2 = connectAndWrite(connector2, "bar"); + KJ_EXPECT(!connectPromise2.poll(ws)); + + acceptAndRead(*aggregate, "foo").wait(ws); + + auto connectPromise3 = connectAndWrite(connector1, "baz"); + KJ_EXPECT(!connectPromise3.poll(ws)); + + acceptAndRead(*aggregate, "bar").wait(ws); + acceptAndRead(*aggregate, "baz").wait(ws); + + connectPromise1.wait(ws); + connectPromise2.wait(ws); + connectPromise3.wait(ws); + + auto acceptPromise1 = acceptAndRead(*aggregate, "qux"); + auto acceptPromise2 = acceptAndRead(*aggregate, "corge"); + auto acceptPromise3 = acceptAndRead(*aggregate, "grault"); + + KJ_EXPECT(!acceptPromise1.poll(ws)); + KJ_EXPECT(!acceptPromise2.poll(ws)); + KJ_EXPECT(!acceptPromise3.poll(ws)); + + // Cancel one of the acceptors... + { auto drop = kj::mv(acceptPromise2); } + + connectAndWrite(connector2, "qux").wait(ws); + connectAndWrite(connector1, "grault").wait(ws); + + acceptPromise1.wait(ws); + acceptPromise3.wait(ws); +} + +// ======================================================================================= +// Tests for optimized pumpTo() between OS handles. Note that this is only even optimized on +// some OSes (only Linux as of this writing), but the behavior should still be the same on all +// OSes, so we run the tests regardless. + +kj::String bigString(size_t size) { + auto result = kj::heapString(size); + for (auto i: kj::zeroTo(size)) { + result[i] = 'a' + i % 26; + } + return result; +} + +KJ_TEST("OS handle pumpTo") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + + { + auto readPromise = expectRead(*pipe2.ends[1], "foo"); + pipe1.ends[0]->write("foo", 3).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], "bar"); + pipe1.ends[0]->write("bar", 3).wait(ws); + readPromise.wait(ws); + } + + auto two = bigString(2000); + auto four = bigString(4000); + auto eight = bigString(8000); + auto fiveHundred = bigString(500'000); + + { + auto readPromise = expectRead(*pipe2.ends[1], two); + pipe1.ends[0]->write(two.begin(), two.size()).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], four); + pipe1.ends[0]->write(four.begin(), four.size()).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], eight); + pipe1.ends[0]->write(eight.begin(), eight.size()).wait(ws); + readPromise.wait(ws); + } + + { + auto readPromise = expectRead(*pipe2.ends[1], fiveHundred); + pipe1.ends[0]->write(fiveHundred.begin(), fiveHundred.size()).wait(ws); + readPromise.wait(ws); + } + + KJ_EXPECT(!pump.poll(ws)) + pipe1.ends[0]->shutdownWrite(); + KJ_EXPECT(pump.wait(ws) == 6 + two.size() + four.size() + eight.size() + fiveHundred.size()); +} + +KJ_TEST("OS handle pumpTo small limit") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 500); + + auto text = bigString(1000); + + auto expected = kj::str(text.slice(0, 500)); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + pipe1.ends[0]->write(text.begin(), text.size()).wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 500); + + expectRead(*pipe1.ends[1], text.slice(500)).wait(ws); +} + +KJ_TEST("OS handle pumpTo small limit -- write first then read") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto text = bigString(1000); + + auto expected = kj::str(text.slice(0, 500)); + + // Initiate the write first and let it put as much in the buffer as possible. + auto writePromise = pipe1.ends[0]->write(text.begin(), text.size()); + writePromise.poll(ws); + + // Now start the pump. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 500); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + writePromise.wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 500); + + expectRead(*pipe1.ends[1], text.slice(500)).wait(ws); +} + +KJ_TEST("OS handle pumpTo large limit") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 750'000); + + auto text = bigString(500'000); + + auto expected = kj::str(text, text.slice(0, 250'000)); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + pipe1.ends[0]->write(text.begin(), text.size()).wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 750'000); + + expectRead(*pipe1.ends[1], text.slice(250'000)).wait(ws); +} + +KJ_TEST("OS handle pumpTo large limit -- write first then read") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto text = bigString(500'000); + + auto expected = kj::str(text, text.slice(0, 250'000)); + + // Initiate the write first and let it put as much in the buffer as possible. + auto writePromise = pipe1.ends[0]->write(text.begin(), text.size()); + writePromise.poll(ws); + + // Now start the pump. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 750'000); + + auto readPromise = expectRead(*pipe2.ends[1], expected); + writePromise.wait(ws); + auto secondWritePromise = pipe1.ends[0]->write(text.begin(), text.size()); + readPromise.wait(ws); + KJ_EXPECT(pump.wait(ws) == 750'000); + + expectRead(*pipe1.ends[1], text.slice(250'000)).wait(ws); +} + +#if !_WIN32 +kj::String fillWriteBuffer(int fd) { + // Fill up the write buffer of the given FD and return the contents written. We need to use the + // raw syscalls to do this because KJ doesn't have a way to know how many bytes made it into the + // socket buffer. + auto huge = bigString(2'000'000); + + size_t pos = 0; + for (;;) { + KJ_ASSERT(pos < huge.size(), "whoa, big buffer"); + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::write(fd, huge.begin() + pos, huge.size() - pos)); + if (n < 0) break; + pos += n; + } + + return kj::str(huge.slice(0, pos)); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes. + auto writePromise = pipe1.ends[0]->write("foo", 3); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + KJ_EXPECT(!pump.poll(ws)); + + // Queue another write, even. + writePromise = writePromise + .then([&]() { return pipe1.ends[0]->write("bar", 3); }); + writePromise.poll(ws); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], "foobar").wait(ws); + + writePromise.wait(ws); + + pipe1.ends[0]->shutdownWrite(); + KJ_EXPECT(pump.wait(ws) == 6); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump -- and pump ends early") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes followed by EOF. + auto writePromise = pipe1.ends[0]->write("foo", 3) + .then([&]() { pipe1.ends[0]->shutdownWrite(); }); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + KJ_EXPECT(!pump.poll(ws)); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], "foo").wait(ws); + + writePromise.wait(ws); + + KJ_EXPECT(pump.wait(ws) == 3); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump -- and pump hits limit early") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes followed by EOF. + auto writePromise = pipe1.ends[0]->write("foo", 3); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0], 3); + KJ_EXPECT(!pump.poll(ws)); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], "foo").wait(ws); + + writePromise.wait(ws); + + KJ_EXPECT(pump.wait(ws) == 3); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} + +KJ_TEST("OS handle pumpTo write buffer is full before pump -- and a lot of data is pumped") { + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto pipe1 = ioContext.provider->newTwoWayPipe(); + auto pipe2 = ioContext.provider->newTwoWayPipe(); + + auto bufferContent = fillWriteBuffer(KJ_ASSERT_NONNULL(pipe2.ends[0]->getFd())); + + // Also prime the input pipe with some buffered bytes followed by EOF. + auto text = bigString(500'000); + auto writePromise = pipe1.ends[0]->write(text.begin(), text.size()); + writePromise.poll(ws); + + // Start the pump and let it get blocked. + auto pump = pipe1.ends[1]->pumpTo(*pipe2.ends[0]); + KJ_EXPECT(!pump.poll(ws)); + + // See it all go through. + expectRead(*pipe2.ends[1], bufferContent).wait(ws); + expectRead(*pipe2.ends[1], text).wait(ws); + + writePromise.wait(ws); + + pipe1.ends[0]->shutdownWrite(); + KJ_EXPECT(pump.wait(ws) == text.size()); + pipe2.ends[0]->shutdownWrite(); + KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ws) == ""); +} +#endif + +KJ_TEST("pump file to socket") { + // Tests sendfile() optimization + + auto ioContext = setupAsyncIo(); + auto& ws = ioContext.waitScope; + + auto doTest = [&](kj::Own file) { + file->writeAll("foobar"_kj.asBytes()); + + { + FileInputStream input(*file); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0]).wait(ws); + pipe.ends[0]->shutdownWrite(); + KJ_EXPECT(readPromise.wait(ws) == "foobar"); + KJ_EXPECT(input.getOffset() == 6); + } + + { + FileInputStream input(*file); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0], 3).wait(ws); + pipe.ends[0]->shutdownWrite(); + KJ_EXPECT(readPromise.wait(ws) == "foo"); + KJ_EXPECT(input.getOffset() == 3); + } + + { + FileInputStream input(*file, 3); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0]).wait(ws); + pipe.ends[0]->shutdownWrite(); + KJ_EXPECT(readPromise.wait(ws) == "bar"); + KJ_EXPECT(input.getOffset() == 6); + } + + auto big = bigString(500'000); + file->writeAll(big); + + { + FileInputStream input(*file); + auto pipe = ioContext.provider->newTwoWayPipe(); + auto readPromise = pipe.ends[1]->readAllText(); + input.pumpTo(*pipe.ends[0]).wait(ws); + pipe.ends[0]->shutdownWrite(); + // Extra parens here so that we don't write the big string to the console on failure... + KJ_EXPECT((readPromise.wait(ws) == big)); + KJ_EXPECT(input.getOffset() == big.size()); + } + }; + + // Try with an in-memory file. No optimization is possible. + doTest(kj::newInMemoryFile(kj::nullClock())); + + // Try with a disk file. Should use sendfile(). + auto fs = kj::newDiskFilesystem(); + doTest(fs->getCurrent().createTemporary()); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/async-io-unix.c++ b/c++/src/kj/async-io-unix.c++ index ec53b3a9d0..62ce21323e 100644 --- a/c++/src/kj/async-io-unix.c++ +++ b/c++/src/kj/async-io-unix.c++ @@ -22,7 +22,18 @@ #if !_WIN32 // For Win32 implementation, see async-io-win32.c++. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#ifndef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 +// Request 64-bit off_t for sendfile(). (The code will still work if we get 32-bit off_t as long +// as actual files are under 4GB.) +#endif + #include "async-io.h" +#include "async-io-internal.h" #include "async-unix.h" #include "debug.h" #include "thread.h" @@ -44,25 +55,49 @@ #include #include #include +#include +#include + +#if __linux__ +#include +#endif + +#if !defined(SO_PEERCRED) && defined(LOCAL_PEERCRED) +#include +#endif + +#if !defined(SOL_LOCAL) && (__FreeBSD__ || __DragonflyBSD__ || __APPLE__) +// On DragonFly, FreeBSD < 12.2 and older Darwin you're supposed to use 0 for SOL_LOCAL. +#define SOL_LOCAL 0 +#endif namespace kj { namespace { void setNonblocking(int fd) { +#ifdef FIONBIO + int opt = 1; + KJ_SYSCALL(ioctl(fd, FIONBIO, &opt)); +#else int flags; KJ_SYSCALL(flags = fcntl(fd, F_GETFL)); if ((flags & O_NONBLOCK) == 0) { KJ_SYSCALL(fcntl(fd, F_SETFL, flags | O_NONBLOCK)); } +#endif } void setCloseOnExec(int fd) { +#ifdef FIOCLEX + KJ_SYSCALL(ioctl(fd, FIOCLEX)); +#else int flags; KJ_SYSCALL(flags = fcntl(fd, F_GETFD)); if ((flags & FD_CLOEXEC) == 0) { KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC)); } +#endif } static constexpr uint NEW_FD_FLAGS = @@ -111,20 +146,44 @@ private: // ======================================================================================= -class AsyncStreamFd: public OwnedFileDescriptor, public AsyncIoStream { +class AsyncStreamFd: public OwnedFileDescriptor, public AsyncCapabilityStream { public: - AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags) + AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags, uint observerFlags) : OwnedFileDescriptor(fd, flags), - observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {} + eventPort(eventPort), + observer(eventPort, fd, observerFlags) {} virtual ~AsyncStreamFd() noexcept(false) {} Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return tryReadInternal(buffer, minBytes, maxBytes, 0); + return tryReadInternal(buffer, minBytes, maxBytes, nullptr, 0, {0,0}) + .then([](ReadResult r) { return r.byteCount; }); + } + + Promise tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, {0,0}); + } + + Promise tryReadWithStreams( + void* buffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + auto fdBuffer = kj::heapArray(maxStreams); + auto promise = tryReadInternal(buffer, minBytes, maxBytes, fdBuffer.begin(), maxStreams, {0,0}); + + return promise.then([this, fdBuffer = kj::mv(fdBuffer), streamBuffer] + (ReadResult result) mutable { + for (auto i: kj::zeroTo(result.capCount)) { + streamBuffer[i] = kj::heap(eventPort, fdBuffer[i].release(), + LowLevelAsyncIoProvider::TAKE_OWNERSHIP | LowLevelAsyncIoProvider::ALREADY_CLOEXEC, + UnixEventPort::FdObserver::OBSERVE_READ_WRITE); + } + return result; + }); } Promise write(const void* buffer, size_t size) override { - ssize_t writeResult; - KJ_NONBLOCKING_SYSCALL(writeResult = ::write(fd, buffer, size)) { + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::write(fd, buffer, size)) { // Error. // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to @@ -138,28 +197,291 @@ public: return kj::READY_NOW; } - // A negative result means EAGAIN, which we can treat the same as having written zero bytes. - size_t n = writeResult < 0 ? 0 : writeResult; - - if (n == size) { + if (n < 0) { + // EAGAIN -- need to wait for writability and try again. + return observer.whenBecomesWritable().then([=]() { + return write(buffer, size); + }); + } else if (n == size) { + // All done. return READY_NOW; - } - - // Fewer than `size` bytes were written, therefore we must be out of buffer space. Wait until - // the fd becomes writable again. - buffer = reinterpret_cast(buffer) + n; - size -= n; - - return observer.whenBecomesWritable().then([=]() { + } else { + // Fewer than `size` bytes were written, but we CANNOT assume we're out of buffer space, as + // Linux is known to return partial reads/writes when interrupted by a signal -- yes, even + // for non-blocking operations. So, we'll need to write() again now, even though it will + // almost certainly fail with EAGAIN. See comments in the read path for more info. + buffer = reinterpret_cast(buffer) + n; + size -= n; return write(buffer, size); - }); + } } Promise write(ArrayPtr> pieces) override { if (pieces.size() == 0) { - return writeInternal(nullptr, nullptr); + return writeInternal(nullptr, nullptr, nullptr); + } else { + return writeInternal(pieces[0], pieces.slice(1, pieces.size()), nullptr); + } + } + + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + return writeInternal(data, moreData, fds); + } + + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + auto fds = KJ_MAP(stream, streams) { + return downcast(*stream).fd; + }; + auto promise = writeInternal(data, moreData, fds); + return promise.attach(kj::mv(fds), kj::mv(streams)); + } + + Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount = kj::maxValue) override { +#if __linux__ && !__ANDROID__ + KJ_IF_MAYBE(sock, kj::dynamicDowncastIfAvailable(input)) { + return pumpFromOther(*sock, amount); + } +#endif + +#if __linux__ + KJ_IF_MAYBE(file, kj::dynamicDowncastIfAvailable(input)) { + KJ_IF_MAYBE(fd, file->getUnderlyingFile().getFd()) { + return pumpFromFile(*file, *fd, amount, 0); + } + } +#endif + + return nullptr; + } + +#if __linux__ + // TODO(someday): Support sendfile on other OS's... unfortunately, it works differently on + // different systems. + +private: + Promise pumpFromFile(FileInputStream& input, int fileFd, + uint64_t amount, uint64_t soFar) { + while (soFar < amount) { + off_t offset = input.getOffset(); + ssize_t n; + + // Although sendfile()'s last argument has type size_t, on Linux it seems to cause EINVAL + // if we pass an amount that is greater than UINT32_MAX, so make sure to clamp to that. In + // practice, of course, we'll be limited to the socket buffer size. + size_t requested = kj::min(amount - soFar, (uint32_t)kj::maxValue); + + KJ_SYSCALL_HANDLE_ERRORS(n = sendfile(fd, fileFd, &offset, requested)) { + case EINVAL: + case ENOSYS: + // Fall back to regular pump + return unoptimizedPumpTo(input, *this, amount, soFar); + + case EAGAIN: + return observer.whenBecomesWritable() + .then([this, &input, fileFd, amount, soFar]() { + return pumpFromFile(input, fileFd, amount, soFar); + }); + + default: + KJ_FAIL_SYSCALL("sendfile", error); + } + + if (n == 0) break; + + input.seek(offset); // NOTE: sendfile() updated `offset` in-place. + soFar += n; + } + + return soFar; + } + +public: +#endif // __linux__ + +#if __linux__ && !__ANDROID__ +// Linux's splice() syscall lets us optimize pumping of bytes between file descriptors. +// +// TODO(someday): splice()-based pumping hangs in unit tests on Android for some reason. We should +// figure out why, but for now I'm just disabling it... + +private: + Maybe> pumpFromOther(AsyncStreamFd& input, uint64_t amount) { + // The input is another AsyncStreamFd, so perhaps we can do an optimized pump with splice(). + + // Before we resort to a bunch of syscalls, let's try to see if the pump is small and able to + // be fully satisfied immediately. This optimizes for the case of small streams, e.g. a short + // HTTP body. + + byte buffer[4096]; + size_t pos = 0; + size_t initialAmount = kj::min(sizeof(buffer), amount); + + bool eof = false; + + // Read into the buffer until it's full or there are no bytes available. Note that we'd expect + // one call to read() will pull as much data out of the socket as possible (up to our buffer + // size), so you might think the loop is unnecessary. The reason we want to do a second read(), + // though, is to find out if we're at EOF or merely waiting for more data. In the EOF case, + // we can end the pump early without splicing. + while (pos < initialAmount) { + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::read(input.fd, buffer + pos, initialAmount - pos)); + if (n <= 0) { + eof = n == 0; + break; + } + pos += n; + } + + // Write the bytes that we just read back out to the output. + { + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = ::write(fd, buffer, pos)); + if (n < 0) n = 0; // treat EAGAIN as "zero bytes written" + if (size_t(n) < pos) { + // Oh crap, the output buffer is full. This should be rare. But, now we're going to have + // to copy the remaining bytes into the heap to do an async write. + auto leftover = kj::heapArray(buffer + n, pos - n); + auto promise = write(leftover.begin(), leftover.size()); + promise = promise.attach(kj::mv(leftover)); + if (eof || pos == amount) { + return promise.then([pos]() -> uint64_t { return pos; }); + } else { + return promise.then([&input, this, pos, amount]() { + return splicePumpFrom(input, pos, amount); + }); + } + } + } + + if (eof || pos == amount) { + // We finished the pump in one go, so don't splice. + return Promise(uint64_t(pos)); } else { - return writeInternal(pieces[0], pieces.slice(1, pieces.size())); + // Use splice for the rest of the pump. + return splicePumpFrom(input, pos, amount); + } + } + + static constexpr size_t MAX_SPLICE_LEN = 1 << 20; + // Maximum value we'll pass for the `len` argument of `splice()`. Linux does not like it when we + // use `kj::maxValue` here so we clamp it. Note that the actual value of this constant is + // irrelevanta as long as it is more than the pipe buffer size (typically 64k) and less than + // whatever value makes Linux unhappy. All actual operations will be clamped to the buffer size. + // (And if the buffer size is for some reason larger than this, that's OK too, we just won't + // end up using the whole buffer.) + + Promise splicePumpFrom(AsyncStreamFd& input, uint64_t readSoFar, uint64_t limit) { + // splice() requires that either its input or its output is a pipe. But chances are neither + // `input.fd` nor `this->fd` is a pipe -- in most use cases they are sockets. In order to take + // advantage of splice(), then, we need to allocate a pipe to act as the middleman, so we can + // splice() from the input to the pipe, and then from the pipe to the output. + // + // You might wonder why this pipe middleman is required. Why can't splice() go directly from + // a socket to a socket? Linus Torvalds attempts to explain here: + // https://yarchive.net/comp/linux/splice.html + // + // The short version is that the pipe itself is equivalent to an in-memory buffer. In a naive + // pump implementation, we allocate a buffer, read() into it and write() out. With splice(), + // we allocate a kernelspace buffer by allocating a pipe, then we splice() into the pipe and + // splice() back out. + + // Linux normally allocates pipe buffers of 64k (16 pages of 4k each). However, when + // /proc/sys/fs/pipe-user-pages-soft is hit, then Linux will start allocating 4k (1 page) + // buffers instead, and will give an error if we try to increase it. + // + // The soft limit defaults to 16384 pages, which we'd hit after 1024 pipes -- totally possible + // in a big server. 64k is a nice buffer size, but even 4k is better than not using splice, so + // we'll live with whatever buffer size the kernel gives us. + // + // There is a second, "hard" limit, /proc/sys/fs/pipe-user-pages-hard, at which point Linux + // will start refusing to allocate pipes at all. In this case we fall back to an unoptimized + // pump. However, this limit defaults to unlimited, so this won't ever happen unless someone + // has manually changed the limit. That's probably dangerous since if the app allocates pipes + // anywhere else in its codebase, it probably doesn't have any fallbacks in those places, so + // things will break anyway... to avoid that we'd need to self-regulate the number of pipes + // we allocate here to avoid coming close to the hard limit, but that's a lot of effort so I'm + // not going to bother! + + int pipeFds[2]; + KJ_SYSCALL_HANDLE_ERRORS(pipe2(pipeFds, O_NONBLOCK | O_CLOEXEC)) { + case ENFILE: + // Probably hit the limit on pipe buffers, fall back to unoptimized pump. + return unoptimizedPumpTo(input, *this, limit, readSoFar); + default: + KJ_FAIL_SYSCALL("pipe2()", error); + } + + AutoCloseFd pipeIn(pipeFds[0]), pipeOut(pipeFds[1]); + + return splicePumpLoop(input, pipeFds[0], pipeFds[1], readSoFar, limit, 0) + .attach(kj::mv(pipeIn), kj::mv(pipeOut)); + } + + Promise splicePumpLoop(AsyncStreamFd& input, int pipeIn, int pipeOut, + uint64_t readSoFar, uint64_t limit, size_t bufferedAmount) { + for (;;) { + while (bufferedAmount > 0) { + // First flush out whatever is in the pipe buffer. + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = splice(pipeIn, nullptr, fd, nullptr, + MAX_SPLICE_LEN, SPLICE_F_MOVE | SPLICE_F_NONBLOCK)); + if (n > 0) { + KJ_ASSERT(n <= bufferedAmount, "splice pipe larger than bufferedAmount?"); + bufferedAmount -= n; + } else { + KJ_ASSERT(n < 0, "splice pipe empty before bufferedAmount reached?", bufferedAmount); + return observer.whenBecomesWritable() + .then([this, &input, pipeIn, pipeOut, readSoFar, limit, bufferedAmount]() { + return splicePumpLoop(input, pipeIn, pipeOut, readSoFar, limit, bufferedAmount); + }); + } + } + + // Now the pipe buffer is empty, so we can try to read some more. + { + if (readSoFar >= limit) { + // Hit the limit, we're done. + KJ_ASSERT(readSoFar == limit); + return readSoFar; + } + + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = splice(input.fd, nullptr, pipeOut, nullptr, + kj::min(limit - readSoFar, MAX_SPLICE_LEN), SPLICE_F_MOVE | SPLICE_F_NONBLOCK)); + if (n == 0) { + // EOF. + return readSoFar; + } else if (n < 0) { + // No data available, wait. + return input.observer.whenBecomesReadable() + .then([this, &input, pipeIn, pipeOut, readSoFar, limit]() { + return splicePumpLoop(input, pipeIn, pipeOut, readSoFar, limit, 0); + }); + } + + readSoFar += n; + bufferedAmount = n; + } + } + } + +public: +#endif // __linux__ && !__ANDROID__ + + Promise whenWriteDisconnected() override { + KJ_IF_MAYBE(p, writeDisconnectedPromise) { + return p->addBranch(); + } else { + auto fork = observer.whenWriteDisconnected().fork(); + auto result = fork.addBranch(); + writeDisconnectedPromise = kj::mv(fork); + return kj::mv(result); } } @@ -197,6 +519,15 @@ public: *length = socklen; } + kj::Maybe getFd() const override { + return fd; + } + + void registerAncillaryMessageHandler( + kj::Function)> fn) override { + ancillaryMsgCallback = kj::mv(fn); + } + Promise waitConnected() { // Wait until initial connection has completed. This actually just waits until it is writable. @@ -221,24 +552,164 @@ public: } private: + UnixEventPort& eventPort; UnixEventPort::FdObserver observer; + Maybe> writeDisconnectedPromise; + Maybe)>> ancillaryMsgCallback; - Promise tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, - size_t alreadyRead) { + Promise tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds, + ReadResult alreadyRead) { // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes, // maxBytes, and buffer have already been adjusted to account for them, but this count must // be included in the final return value. ssize_t n; - KJ_NONBLOCKING_SYSCALL(n = ::read(fd, buffer, maxBytes)) { - // Error. + if (maxFds == 0 && ancillaryMsgCallback == nullptr) { + KJ_NONBLOCKING_SYSCALL(n = ::read(fd, buffer, maxBytes)) { + // Error. + + // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to + // a bug that exists in both Clang and GCC: + // http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 + // http://llvm.org/bugs/show_bug.cgi?id=12286 + goto error; + } + } else { + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); + + struct iovec iov; + memset(&iov, 0, sizeof(iov)); + iov.iov_base = buffer; + iov.iov_len = maxBytes; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + // Allocate space to receive a cmsg. + size_t msgBytes; + if (ancillaryMsgCallback == nullptr) { +#if __APPLE__ || __FreeBSD__ + // Until very recently (late 2018 / early 2019), FreeBSD suffered from a bug in which when + // an SCM_RIGHTS message was truncated on delivery, it would not close the FDs that weren't + // delivered -- they would simply leak: https://bugs.freebsd.org/131876 + // + // My testing indicates that MacOS has this same bug as of today (April 2019). I don't know + // if they plan to fix it or are even aware of it. + // + // To handle both cases, we will always provide space to receive 512 FDs. Hopefully, this is + // greater than the maximum number of FDs that these kernels will transmit in one message + // PLUS enough space for any other ancillary messages that could be sent before the + // SCM_RIGHTS message to push it back in the buffer. I couldn't find any firm documentation + // on these limits, though -- I only know that Linux is limited to 253, and I saw a hint in + // a comment in someone else's application that suggested FreeBSD is the same. Hopefully, + // then, this is sufficient to prevent attacks. But if not, there's nothing more we can do; + // it's really up to the kernel to fix this. + msgBytes = CMSG_SPACE(sizeof(int) * 512); +#else + msgBytes = CMSG_SPACE(sizeof(int) * maxFds); +#endif + } else { + // If we want room for ancillary messages instead of or in addition to FDs, just use the + // same amount of cushion as in the MacOS/FreeBSD case above. + // Someday we may want to allow customization here, but there's no immediate use for it. + msgBytes = CMSG_SPACE(sizeof(int) * 512); + } - // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to - // a bug that exists in both Clang and GCC: - // http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 - // http://llvm.org/bugs/show_bug.cgi?id=12286 - goto error; + // On Linux, CMSG_SPACE will align to a word-size boundary, but on Mac it always aligns to a + // 32-bit boundary. I guess aligning to 32 bits helps avoid the problem where you + // surprisingly end up with space for two file descriptors when you only wanted one. However, + // cmsghdr's preferred alignment is word-size (it contains a size_t). If we stack-allocate + // the buffer, we need to make sure it is aligned properly (maybe not on x64, but maybe on + // other platforms), so we want to allocate an array of words (we use void*). So... we use + // CMSG_SPACE() and then additionally round up to deal with Mac. + size_t msgWords = (msgBytes + sizeof(void*) - 1) / sizeof(void*); + KJ_STACK_ARRAY(void*, cmsgSpace, msgWords, 16, 256); + auto cmsgBytes = cmsgSpace.asBytes(); + memset(cmsgBytes.begin(), 0, cmsgBytes.size()); + msg.msg_control = cmsgBytes.begin(); + msg.msg_controllen = msgBytes; + +#ifdef MSG_CMSG_CLOEXEC + static constexpr int RECVMSG_FLAGS = MSG_CMSG_CLOEXEC; +#else + static constexpr int RECVMSG_FLAGS = 0; +#endif + + KJ_NONBLOCKING_SYSCALL(n = ::recvmsg(fd, &msg, RECVMSG_FLAGS)) { + // Error. + + // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to + // a bug that exists in both Clang and GCC: + // http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 + // http://llvm.org/bugs/show_bug.cgi?id=12286 + goto error; + } + + if (n >= 0) { + // Process all messages. + // + // WARNING DANGER: We have to be VERY careful not to miss a file descriptor here, because + // if we do, then that FD will never be closed, and a malicious peer could exploit this to + // fill up our FD table, creating a DoS attack. Some things to keep in mind: + // - CMSG_SPACE() could have rounded up the space for alignment purposes, and this could + // mean we permitted the kernel to deliver more file descriptors than `maxFds`. We need + // to close the extras. + // - We can receive multiple ancillary messages at once. In particular, there is also + // SCM_CREDENTIALS. The sender decides what to send. They could send SCM_CREDENTIALS + // first followed by SCM_RIGHTS. We need to make sure we see both. + size_t nfds = 0; + size_t spaceLeft = msg.msg_controllen; + Vector ancillaryMessages; + for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (spaceLeft >= CMSG_LEN(0) && + cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + // Some operating systems (like MacOS) do not adjust csmg_len when the message is + // truncated. We must do so ourselves or risk overrunning the buffer. + auto len = kj::min(cmsg->cmsg_len, spaceLeft); + auto data = arrayPtr(reinterpret_cast(CMSG_DATA(cmsg)), + (len - CMSG_LEN(0)) / sizeof(int)); + kj::Vector trashFds; + for (auto fd: data) { + kj::AutoCloseFd ownFd(fd); + if (nfds < maxFds) { + fdBuffer[nfds++] = kj::mv(ownFd); + } else { + trashFds.add(kj::mv(ownFd)); + } + } + } else if (spaceLeft >= CMSG_LEN(0) && ancillaryMsgCallback != nullptr) { + auto len = kj::min(cmsg->cmsg_len, spaceLeft); + auto data = ArrayPtr(CMSG_DATA(cmsg), len - CMSG_LEN(0)); + ancillaryMessages.add(cmsg->cmsg_level, cmsg->cmsg_type, data); + } + + if (spaceLeft >= CMSG_LEN(0) && spaceLeft >= cmsg->cmsg_len) { + spaceLeft -= cmsg->cmsg_len; + } else { + spaceLeft = 0; + } + } + +#ifndef MSG_CMSG_CLOEXEC + for (size_t i = 0; i < nfds; i++) { + setCloseOnExec(fdBuffer[i]); + } +#endif + + if (ancillaryMessages.size() > 0) { + KJ_IF_MAYBE(fn, ancillaryMsgCallback) { + (*fn)(ancillaryMessages.asPtr()); + } + } + + alreadyRead.capCount += nfds; + fdBuffer += nfds; + maxFds -= nfds; + } } + if (false) { error: return alreadyRead; @@ -247,51 +718,41 @@ private: if (n < 0) { // Read would block. return observer.whenBecomesReadable().then([=]() { - return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); + return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, alreadyRead); }); } else if (n == 0) { // EOF -OR- maxBytes == 0. return alreadyRead; } else if (implicitCast(n) >= minBytes) { // We read enough to stop here. - return alreadyRead + n; + alreadyRead.byteCount += n; + return alreadyRead; } else { // The kernel returned fewer bytes than we asked for (and fewer than we need). buffer = reinterpret_cast(buffer) + n; minBytes -= n; maxBytes -= n; - alreadyRead += n; - - KJ_IF_MAYBE(atEnd, observer.atEndHint()) { - if (*atEnd) { - // We've already received an indication that the next read() will return EOF, so there's - // nothing to wait for. - return alreadyRead; - } else { - // As of the last time the event queue was checked, the kernel reported that we were - // *not* at the end of the stream. It's unlikely that this has changed in the short time - // it took to handle the event, therefore calling read() now will almost certainly fail - // with EAGAIN. Moreover, since EOF had not been received as of the last check, we know - // that even if it was received since then, whenBecomesReadable() will catch that. So, - // let's go ahead and skip calling read() here and instead go straight to waiting for - // more input. - return observer.whenBecomesReadable().then([=]() { - return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); - }); - } - } else { - // The kernel has not indicated one way or the other whether we are likely to be at EOF. - // In this case we *must* keep calling read() until we either get a return of zero or - // EAGAIN. - return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); - } + alreadyRead.byteCount += n; + + // According to David Klempner, who works on Stubby at Google, we sadly CANNOT assume that + // we've consumed the whole read buffer here. If a signal is delivered in the middle of a + // read() -- yes, even a non-blocking read -- it can cause the kernel to return a partial + // result, with data still in the buffer. + // https://bugzilla.kernel.org/show_bug.cgi?id=199131 + // https://twitter.com/CaptainSegfault/status/1112622245531144194 + // + // Unfortunately, we have no choice but to issue more read()s until it either tells us EOF + // or EAGAIN. We used to have an optimization here using observer.atEndHint() (when it is + // non-null) to avoid a redundant call to read(). Alas... + return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, alreadyRead); } } Promise writeInternal(ArrayPtr firstPiece, - ArrayPtr> morePieces) { - const size_t iovmax = kj::miniposix::iovMax(1 + morePieces.size()); + ArrayPtr> morePieces, + ArrayPtr fds) { + const size_t iovmax = kj::miniposix::iovMax(); // If there are more than IOV_MAX pieces, we'll only write the first IOV_MAX for now, and // then we'll loop later. KJ_STACK_ARRAY(struct iovec, iov, kj::min(1 + morePieces.size(), iovmax), 16, 128); @@ -307,23 +768,87 @@ private: iovTotal += iov[i].iov_len; } - ssize_t writeResult; - KJ_NONBLOCKING_SYSCALL(writeResult = ::writev(fd, iov.begin(), iov.size())) { - // Error. + if (iovTotal == 0) { + KJ_REQUIRE(fds.size() == 0, "can't write FDs without bytes"); + return kj::READY_NOW; + } - // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to - // a bug that exists in both Clang and GCC: - // http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 - // http://llvm.org/bugs/show_bug.cgi?id=12286 - goto error; + ssize_t n; + if (fds.size() == 0) { + KJ_NONBLOCKING_SYSCALL(n = ::writev(fd, iov.begin(), iov.size()), iovTotal, iov.size()) { + // Error. + + // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to + // a bug that exists in both Clang and GCC: + // http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 + // http://llvm.org/bugs/show_bug.cgi?id=12286 + goto error; + } + } else { + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = iov.begin(); + msg.msg_iovlen = iov.size(); + + // Allocate space to send a cmsg. + size_t msgBytes = CMSG_SPACE(sizeof(int) * fds.size()); + // On Linux, CMSG_SPACE will align to a word-size boundary, but on Mac it always aligns to a + // 32-bit boundary. I guess aligning to 32 bits helps avoid the problem where you + // surprisingly end up with space for two file descriptors when you only wanted one. However, + // cmsghdr's preferred alignment is word-size (it contains a size_t). If we stack-allocate + // the buffer, we need to make sure it is aligned properly (maybe not on x64, but maybe on + // other platforms), so we want to allocate an array of words (we use void*). So... we use + // CMSG_SPACE() and then additionally round up to deal with Mac. + size_t msgWords = (msgBytes + sizeof(void*) - 1) / sizeof(void*); + KJ_STACK_ARRAY(void*, cmsgSpace, msgWords, 16, 256); + auto cmsgBytes = cmsgSpace.asBytes(); + memset(cmsgBytes.begin(), 0, cmsgBytes.size()); + msg.msg_control = cmsgBytes.begin(); + msg.msg_controllen = msgBytes; + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(int) * fds.size()); + memcpy(CMSG_DATA(cmsg), fds.begin(), fds.asBytes().size()); + + KJ_NONBLOCKING_SYSCALL(n = ::sendmsg(fd, &msg, 0)) { + // Error. + + // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to + // a bug that exists in both Clang and GCC: + // http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 + // http://llvm.org/bugs/show_bug.cgi?id=12286 + goto error; + } } + if (false) { error: return kj::READY_NOW; } - // A negative result means EAGAIN, which we can treat the same as having written zero bytes. - size_t n = writeResult < 0 ? 0 : writeResult; + if (n < 0) { + // Got EAGAIN. Nothing was written. + return observer.whenBecomesWritable().then([=]() { + return writeInternal(firstPiece, morePieces, fds); + }); + } else if (n == 0) { + // Why would a sendmsg() with a non-empty message ever return 0 when writing to a stream + // socket? If there's no room in the send buffer, it should fail with EAGAIN. If the + // connection is closed, it should fail with EPIPE. Various documents and forum posts around + // the internet claim this can happen but no one seems to know when. My guess is it can only + // happen if we try to send an empty message -- which we didn't. So I think this is + // impossible. If it is possible, we need to figure out how to correctly handle it, which + // depends on what caused it. + // + // Note in particular that if 0 is a valid return here, and we sent an SCM_RIGHTS message, + // we need to know whether the message was sent or not, in order to decide whether to retry + // sending it! + KJ_FAIL_ASSERT("non-empty sendmsg() returned 0"); + } + + // Non-zero bytes were written. This also implies that *all* FDs were written. // Discard all data that was written, then issue a new write for what's left (if any). for (;;) { @@ -334,12 +859,12 @@ private: if (iovTotal == 0) { // Oops, what actually happened is that we hit the IOV_MAX limit. Don't wait. - return writeInternal(firstPiece, morePieces); + return writeInternal(firstPiece, morePieces, nullptr); } - return observer.whenBecomesWritable().then([=]() { - return writeInternal(firstPiece, morePieces); - }); + // As with read(), we cannot assume that a short write() really means the write buffer is + // full (see comments in the read path above). We have to write again. + return writeInternal(firstPiece, morePieces, nullptr); } else if (morePieces.size() == 0) { // First piece was fully-consumed and there are no more pieces, so we're done. KJ_DASSERT(n == firstPiece.size(), n); @@ -355,6 +880,10 @@ private: } }; +#if __linux__ && !__ANDROID__ +constexpr size_t AsyncStreamFd::MAX_SPLICE_LEN; +#endif // __linux__ && !__ANDROID__ + // ======================================================================================= class SocketAddress { @@ -449,7 +978,12 @@ public: return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port)); } case AF_UNIX: { - return str("unix:", addr.unixDomain.sun_path); + auto path = _::safeUnixPath(&addr.unixDomain, addrlen); + if (path.size() > 0 && path[0] == '\0') { + return str("unix-abstract:", path.slice(1, path.size())); + } else { + return str("unix:", path); + } } default: return str("(unknown address family ", addr.generic.sa_family, ")"); @@ -457,11 +991,12 @@ public: } static Promise> lookupHost( - LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint); + LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint, + _::NetworkFilter& filter); // Perform a DNS lookup. static Promise> parse( - LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) { + LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) { // TODO(someday): Allow commas in `str`. SocketAddress result; @@ -470,9 +1005,39 @@ public: StringPtr path = str.slice(strlen("unix:")); KJ_REQUIRE(path.size() < sizeof(addr.unixDomain.sun_path), "Unix domain socket address is too long.", str); + KJ_REQUIRE(path.size() == strlen(path.cStr()), + "Unix domain socket address contains NULL. Use" + " 'unix-abstract:' for the abstract namespace."); result.addr.unixDomain.sun_family = AF_UNIX; strcpy(result.addr.unixDomain.sun_path, path.cStr()); result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; + + if (!result.parseAllowedBy(filter)) { + KJ_FAIL_REQUIRE("unix sockets blocked by restrictPeers()"); + return Array(); + } + + auto array = kj::heapArrayBuilder(1); + array.add(result); + return array.finish(); + } + + if (str.startsWith("unix-abstract:")) { + StringPtr path = str.slice(strlen("unix-abstract:")); + KJ_REQUIRE(path.size() + 1 < sizeof(addr.unixDomain.sun_path), + "Unix domain socket address is too long.", str); + result.addr.unixDomain.sun_family = AF_UNIX; + result.addr.unixDomain.sun_path[0] = '\0'; + // although not strictly required by Linux, also copy the trailing + // NULL terminator so that we can safely read it back in toString + memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1); + result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; + + if (!result.parseAllowedBy(filter)) { + KJ_FAIL_REQUIRE("abstract unix sockets blocked by restrictPeers()"); + return Array(); + } + auto array = kj::heapArrayBuilder(1); array.add(result); return array.finish(); @@ -525,7 +1090,8 @@ public: port = strtoul(portText->cStr(), &endptr, 0); if (portText->size() == 0 || *endptr != '\0') { // Not a number. Maybe it's a service name. Fall back to DNS. - return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint); + return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint, + filter); } KJ_REQUIRE(port < 65536, "Port number too large."); } else { @@ -547,6 +1113,7 @@ public: result.addr.inet6.sin6_family = AF_INET6; result.addr.inet6.sin6_port = htons(port); #endif + auto array = kj::heapArrayBuilder(1); array.add(result); return array.finish(); @@ -565,26 +1132,34 @@ public: addrTarget = &result.addr.inet4.sin_addr; } - // addrPart is not necessarily NUL-terminated so we have to make a copy. :( - KJ_REQUIRE(addrPart.size() < INET6_ADDRSTRLEN - 1, "IP address too long.", addrPart); - char buffer[INET6_ADDRSTRLEN]; - memcpy(buffer, addrPart.begin(), addrPart.size()); - buffer[addrPart.size()] = '\0'; - - // OK, parse it! - switch (inet_pton(af, buffer, addrTarget)) { - case 1: { - // success. - auto array = kj::heapArrayBuilder(1); - array.add(result); - return array.finish(); + if (addrPart.size() < INET6_ADDRSTRLEN - 1) { + // addrPart is not necessarily NUL-terminated so we have to make a copy. :( + char buffer[INET6_ADDRSTRLEN]; + memcpy(buffer, addrPart.begin(), addrPart.size()); + buffer[addrPart.size()] = '\0'; + + // OK, parse it! + switch (inet_pton(af, buffer, addrTarget)) { + case 1: { + // success. + if (!result.parseAllowedBy(filter)) { + KJ_FAIL_REQUIRE("address family blocked by restrictPeers()"); + return Array(); + } + + auto array = kj::heapArrayBuilder(1); + array.add(result); + return array.finish(); + } + case 0: + // It's apparently not a simple address... fall back to DNS. + break; + default: + KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart); } - case 0: - // It's apparently not a simple address... fall back to DNS. - return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port); - default: - KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart); } + + return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter); } static SocketAddress getLocalAddress(int sockfd) { @@ -594,9 +1169,23 @@ public: return result; } + bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) { + return filter.shouldAllow(&addr.generic, addrlen); + } + + bool parseAllowedBy(_::NetworkFilter& filter) { + return filter.shouldAllowParse(&addr.generic, addrlen); + } + + kj::Own getIdentity(LowLevelAsyncIoProvider& llaiop, + LowLevelAsyncIoProvider::NetworkFilter& filter, + AsyncIoStream& stream) const; + private: - SocketAddress(): addrlen(0) { - memset(&addr, 0, sizeof(addr)); + SocketAddress() { + // We need to memset the whole object 0 otherwise Valgrind gets unhappy when we write it to a + // pipe, due to the padding bytes being uninitialized. + memset(this, 0, sizeof(*this)); } socklen_t addrlen; @@ -610,54 +1199,6 @@ private: } addr; struct LookupParams; - class LookupReader; -}; - -class SocketAddress::LookupReader { - // Reads SocketAddresses off of a pipe coming from another thread that is performing - // getaddrinfo. - -public: - LookupReader(kj::Own&& thread, kj::Own&& input) - : thread(kj::mv(thread)), input(kj::mv(input)) {} - - ~LookupReader() { - if (thread) thread->detach(); - } - - Promise> read() { - return input->tryRead(¤t, sizeof(current), sizeof(current)).then( - [this](size_t n) -> Promise> { - if (n < sizeof(current)) { - thread = nullptr; - // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check - // anyway. - KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; } - return addresses.releaseAsArray(); - } else { - // getaddrinfo() can return multiple copies of the same address for several reasons. - // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so - // it may return two copies of the same address, one for each type, unless it explicitly - // knows that the service name given is specific to one type. But we can't tell it a type, - // because we don't actually know which one the user wants, and if we specify SOCK_STREAM - // while the user specified a UDP service name then they'll get a resolution error which - // is lame. (At least, I think that's how it works.) - // - // So we instead resort to de-duping results. - if (alreadySeen.insert(current).second) { - addresses.add(current); - } - return read(); - } - }); - } - -private: - kj::Own thread; - kj::Own input; - SocketAddress current; - kj::Vector addresses; - std::set alreadySeen; }; struct SocketAddress::LookupParams { @@ -666,7 +1207,8 @@ struct SocketAddress::LookupParams { }; Promise> SocketAddress::lookupHost( - LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) { + LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint, + _::NetworkFilter& filter) { // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is // the only cross-platform DNS API and it is blocking. // @@ -674,107 +1216,164 @@ Promise> SocketAddress::lookupHost( // Maybe use the various platform-specific asynchronous DNS libraries? Please do not implement // a custom DNS resolver... - int fds[2]; -#if __linux__ && !__BIONIC__ - KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC)); -#else - KJ_SYSCALL(pipe(fds)); -#endif - - auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS); - - int outFd = fds[1]; - + auto paf = newPromiseAndCrossThreadFulfiller>(); LookupParams params = { kj::mv(host), kj::mv(service) }; - auto thread = heap(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) { - FdOutputStream output((AutoCloseFd(outFd))); - - struct addrinfo* list; - int status = getaddrinfo( - params.host == "*" ? nullptr : params.host.cStr(), - params.service == nullptr ? nullptr : params.service.cStr(), - nullptr, &list); - if (status == 0) { - KJ_DEFER(freeaddrinfo(list)); - - struct addrinfo* cur = list; - while (cur != nullptr) { - if (params.service == nullptr) { - switch (cur->ai_addr->sa_family) { - case AF_INET: - ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); - break; - case AF_INET6: - ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); - break; - default: - break; + auto thread = heap( + [fulfiller=kj::mv(paf.fulfiller),params=kj::mv(params),portHint]() mutable { + // getaddrinfo() can return multiple copies of the same address for several reasons. + // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so + // it may return two copies of the same address, one for each type, unless it explicitly + // knows that the service name given is specific to one type. But we can't tell it a type, + // because we don't actually know which one the user wants, and if we specify SOCK_STREAM + // while the user specified a UDP service name then they'll get a resolution error which + // is lame. (At least, I think that's how it works.) + // + // So we instead resort to de-duping results. + std::set result; + + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; +#if __BIONIC__ + // AI_V4MAPPED causes getaddrinfo() to fail on Bionic libc (Android). + hints.ai_flags = AI_ADDRCONFIG; +#else + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; +#endif + struct addrinfo* list; + int status = getaddrinfo( + params.host == "*" ? nullptr : params.host.cStr(), + params.service == nullptr ? nullptr : params.service.cStr(), + &hints, &list); + if (status == 0) { + KJ_DEFER(freeaddrinfo(list)); + + struct addrinfo* cur = list; + while (cur != nullptr) { + if (params.service == nullptr) { + switch (cur->ai_addr->sa_family) { + case AF_INET: + ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); + break; + case AF_INET6: + ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); + break; + default: + break; + } } - } - SocketAddress addr; - memset(&addr, 0, sizeof(addr)); // mollify valgrind - if (params.host == "*") { - // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). - addr.wildcard = true; - addr.addrlen = sizeof(addr.addr.inet6); - addr.addr.inet6.sin6_family = AF_INET6; - switch (cur->ai_addr->sa_family) { - case AF_INET: - addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; - break; - case AF_INET6: - addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; - break; - default: - addr.addr.inet6.sin6_port = portHint; - break; + SocketAddress addr; + if (params.host == "*") { + // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). + addr.wildcard = true; + addr.addrlen = sizeof(addr.addr.inet6); + addr.addr.inet6.sin6_family = AF_INET6; + switch (cur->ai_addr->sa_family) { + case AF_INET: + addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; + break; + case AF_INET6: + addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; + break; + default: + addr.addr.inet6.sin6_port = portHint; + break; + } + } else { + addr.addrlen = cur->ai_addrlen; + memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); } - } else { - addr.addrlen = cur->ai_addrlen; - memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); + result.insert(addr); + cur = cur->ai_next; + } + } else if (status == EAI_SYSTEM) { + KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) { + return; + } + } else { + KJ_FAIL_REQUIRE("DNS lookup failed.", + params.host, params.service, gai_strerror(status)) { + return; } - KJ_ASSERT_CAN_MEMCPY(SocketAddress); - output.write(&addr, sizeof(addr)); - cur = cur->ai_next; - } - } else if (status == EAI_SYSTEM) { - KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) { - return; } + })) { + fulfiller->reject(kj::mv(*exception)); } else { - KJ_FAIL_REQUIRE("DNS lookup failed.", - params.host, params.service, gai_strerror(status)) { - return; - } + fulfiller->fulfill(KJ_MAP(addr, result) { return addr; }); } - })); + }); - auto reader = heap(kj::mv(thread), kj::mv(input)); - return reader->read().attach(kj::mv(reader)); + return kj::mv(paf.promise); } // ======================================================================================= class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor { public: - FdConnectionReceiver(UnixEventPort& eventPort, int fd, uint flags) - : OwnedFileDescriptor(fd, flags), eventPort(eventPort), + FdConnectionReceiver(LowLevelAsyncIoProvider& lowLevel, + UnixEventPort& eventPort, int fd, + LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) + : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter), observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {} Promise> accept() override { + return acceptImpl(false).then([](AuthenticatedStream&& a) { return kj::mv(a.stream); }); + } + + Promise acceptAuthenticated() override { + return acceptImpl(true); + } + + Promise acceptImpl(bool authenticated) { int newFd; + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + retry: #if __linux__ && !__BIONIC__ - newFd = ::accept4(fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC); + newFd = ::accept4(fd, reinterpret_cast(&addr), &addrlen, + SOCK_NONBLOCK | SOCK_CLOEXEC); #else - newFd = ::accept(fd, nullptr, nullptr); + newFd = ::accept(fd, reinterpret_cast(&addr), &addrlen); #endif if (newFd >= 0) { - return Own(heap(eventPort, newFd, NEW_FD_FLAGS)); + kj::AutoCloseFd ownFd(newFd); + if (!filter.shouldAllow(reinterpret_cast(&addr), addrlen)) { + // Ignore disallowed address. + return acceptImpl(authenticated); + } else { + // TODO(perf): As a hack for the 0.4 release we are always setting + // TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's + // RPC protocol. Later, we should extend the interface to provide more + // control over this. Perhaps write() should have a flag which + // specifies whether to pass MSG_MORE. + int one = 1; + KJ_SYSCALL_HANDLE_ERRORS(::setsockopt( + ownFd.get(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one))) { + case EOPNOTSUPP: + case ENOPROTOOPT: // (returned for AF_UNIX in cygwin) +#if __FreeBSD__ + case EINVAL: // (returned for AF_UNIX in FreeBSD) +#endif + break; + default: + KJ_FAIL_SYSCALL("setsocketopt(IPPROTO_TCP, TCP_NODELAY)", error); + } + + AuthenticatedStream result; + result.stream = heap(eventPort, ownFd.release(), NEW_FD_FLAGS, + UnixEventPort::FdObserver::OBSERVE_READ_WRITE); + if (authenticated) { + result.peerIdentity = SocketAddress(reinterpret_cast(&addr), addrlen) + .getIdentity(lowLevel, filter, *result.stream); + } + return kj::mv(result); + } } else { int error = errno; @@ -784,8 +1383,8 @@ public: case EWOULDBLOCK: #endif // Not ready yet. - return observer.whenBecomesReadable().then([this]() { - return accept(); + return observer.whenBecomesReadable().then([this,authenticated]() { + return acceptImpl(authenticated); }); case EINTR: @@ -824,16 +1423,24 @@ public: void setsockopt(int level, int option, const void* value, uint length) override { KJ_SYSCALL(::setsockopt(fd, level, option, value, length)); } + void getsockname(struct sockaddr* addr, uint* length) override { + socklen_t socklen = *length; + KJ_SYSCALL(::getsockname(fd, addr, &socklen)); + *length = socklen; + } public: + LowLevelAsyncIoProvider& lowLevel; UnixEventPort& eventPort; + LowLevelAsyncIoProvider::NetworkFilter& filter; UnixEventPort::FdObserver observer; }; class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor { public: - DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd, uint flags) - : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), + DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd, + LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) + : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter), observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ | UnixEventPort::FdObserver::OBSERVE_WRITE) {} @@ -861,27 +1468,36 @@ public: public: LowLevelAsyncIoProvider& lowLevel; UnixEventPort& eventPort; + LowLevelAsyncIoProvider::NetworkFilter& filter; UnixEventPort::FdObserver observer; }; class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { public: LowLevelAsyncIoProviderImpl() - : eventLoop(eventPort), waitScope(eventLoop) {} + : eventPort(), eventLoop(eventPort), waitScope(eventLoop) {} inline WaitScope& getWaitScope() { return waitScope; } Own wrapInputFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ); } Own wrapOutputFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_WRITE); } Own wrapSocketFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ_WRITE); + } + Own wrapUnixSocketFd(Fd fd, uint flags = 0) override { + return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ_WRITE); } Promise> wrapConnectingSocketFd( int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override { + // It's important that we construct the AsyncStreamFd first, so that `flags` are honored, + // especially setting nonblocking mode and taking ownership. + auto result = heap(eventPort, fd, flags, + UnixEventPort::FdObserver::OBSERVE_READ_WRITE); + // Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates // non-blocking using EINPROGRESS. for (;;) { @@ -891,7 +1507,8 @@ public: // Fine. break; } else if (error != EINTR) { - KJ_FAIL_SYSCALL("connect()", error) { break; } + auto address = SocketAddress(addr, addrlen).toString(); + KJ_FAIL_SYSCALL("connect()", error, address) { break; } return Own(); } } else { @@ -900,10 +1517,8 @@ public: } } - auto result = heap(eventPort, fd, flags); - auto connected = result->waitConnected(); - return connected.then(kj::mvCapture(result, [fd](Own&& stream) { + return connected.then([fd,stream=kj::mv(result)]() mutable -> Own { int err; socklen_t errlen = sizeof(err); KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen)); @@ -911,13 +1526,15 @@ public: KJ_FAIL_SYSCALL("connect()", err) { break; } } return kj::mv(stream); - })); + }); } - Own wrapListenSocketFd(int fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + Own wrapListenSocketFd( + int fd, NetworkFilter& filter, uint flags = 0) override { + return heap(*this, eventPort, fd, filter, flags); } - Own wrapDatagramSocketFd(int fd, uint flags = 0) override { - return heap(*this, eventPort, fd, flags); + Own wrapDatagramSocketFd( + int fd, NetworkFilter& filter, uint flags = 0) override { + return heap(*this, eventPort, fd, filter, flags); } Timer& getTimer() override { return eventPort.getTimer(); } @@ -934,39 +1551,50 @@ private: class NetworkAddressImpl final: public NetworkAddress { public: - NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array addrs) - : lowLevel(lowLevel), addrs(kj::mv(addrs)) {} + NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, + LowLevelAsyncIoProvider::NetworkFilter& filter, + Array addrs) + : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {} Promise> connect() override { auto addrsCopy = heapArray(addrs.asPtr()); - auto promise = connectImpl(lowLevel, addrsCopy); + auto promise = connectImpl(lowLevel, filter, addrsCopy, false); + return promise.attach(kj::mv(addrsCopy)) + .then([](AuthenticatedStream&& a) { return kj::mv(a.stream); }); + } + + Promise connectAuthenticated() override { + auto addrsCopy = heapArray(addrs.asPtr()); + auto promise = connectImpl(lowLevel, filter, addrsCopy, true); return promise.attach(kj::mv(addrsCopy)); } Own listen() override { - if (addrs.size() > 1) { - KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will " - "be used. If this is incorrect, specify the address numerically. This may be fixed " - "in the future.", addrs[0].toString()); - } + auto makeReceiver = [&](SocketAddress& addr) { + int fd = addr.socket(SOCK_STREAM); - int fd = addrs[0].socket(SOCK_STREAM); + { + KJ_ON_SCOPE_FAILURE(close(fd)); - { - KJ_ON_SCOPE_FAILURE(close(fd)); + // We always enable SO_REUSEADDR because having to take your server down for five minutes + // before it can restart really sucks. + int optval = 1; + KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval))); - // We always enable SO_REUSEADDR because having to take your server down for five minutes - // before it can restart really sucks. - int optval = 1; - KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval))); + addr.bind(fd); - addrs[0].bind(fd); + // TODO(someday): Let queue size be specified explicitly in string addresses. + KJ_SYSCALL(::listen(fd, SOMAXCONN)); + } - // TODO(someday): Let queue size be specified explicitly in string addresses. - KJ_SYSCALL(::listen(fd, SOMAXCONN)); - } + return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS); + }; - return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS); + if (addrs.size() == 1) { + return makeReceiver(addrs[0]); + } else { + return newAggregateConnectionReceiver(KJ_MAP(addr, addrs) { return makeReceiver(addr); }); + } } Own bindDatagramPort() override { @@ -989,11 +1617,11 @@ public: addrs[0].bind(fd); } - return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS); + return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS); } Own clone() override { - return kj::heap(lowLevel, kj::heapArray(addrs.asPtr())); + return kj::heap(lowLevel, filter, kj::heapArray(addrs.asPtr())); } String toString() override { @@ -1007,26 +1635,40 @@ public: private: LowLevelAsyncIoProvider& lowLevel; + LowLevelAsyncIoProvider::NetworkFilter& filter; Array addrs; uint counter = 0; - static Promise> connectImpl( - LowLevelAsyncIoProvider& lowLevel, ArrayPtr addrs) { + static Promise connectImpl( + LowLevelAsyncIoProvider& lowLevel, + LowLevelAsyncIoProvider::NetworkFilter& filter, + ArrayPtr addrs, + bool authenticated) { KJ_ASSERT(addrs.size() > 0); - int fd = addrs[0].socket(SOCK_STREAM); - - return kj::evalNow([&]() { - return lowLevel.wrapConnectingSocketFd( - fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS); - }).then([](Own&& stream) -> Promise> { + return kj::evalNow([&]() -> Promise> { + if (!addrs[0].allowedBy(filter)) { + return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()"); + } else { + int fd = addrs[0].socket(SOCK_STREAM); + return lowLevel.wrapConnectingSocketFd( + fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS); + } + }).then([&lowLevel,&filter,addrs,authenticated](Own&& stream) + -> Promise { // Success, pass along. - return kj::mv(stream); - }, [&lowLevel,addrs](Exception&& exception) mutable -> Promise> { + AuthenticatedStream result; + result.stream = kj::mv(stream); + if (authenticated) { + result.peerIdentity = addrs[0].getIdentity(lowLevel, filter, *result.stream); + } + return kj::mv(result); + }, [&lowLevel,&filter,addrs,authenticated](Exception&& exception) mutable + -> Promise { // Connect failed. if (addrs.size() > 1) { // Try the next address instead. - return connectImpl(lowLevel, addrs.slice(1, addrs.size())); + return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()), authenticated); } else { // No more addresses to try, so propagate the exception. return kj::mv(exception); @@ -1035,28 +1677,98 @@ private: } }; +kj::Own SocketAddress::getIdentity(kj::LowLevelAsyncIoProvider& llaiop, + LowLevelAsyncIoProvider::NetworkFilter& filter, + AsyncIoStream& stream) const { + switch (addr.generic.sa_family) { + case AF_INET: + case AF_INET6: { + auto builder = kj::heapArrayBuilder(1); + builder.add(*this); + return NetworkPeerIdentity::newInstance( + kj::heap(llaiop, filter, builder.finish())); + } + case AF_UNIX: { + LocalPeerIdentity::Credentials result; + + // There is little documentation on what happens when the uid/pid can't be obtained, but I've + // seen vague references on the internet saying that a PID of 0 and a UID of uid_t(-1) are used + // as invalid values. + +// OpenBSD defines SO_PEERCRED but uses a different interface for it +// hence we're falling back to LOCAL_PEERCRED +#if defined(SO_PEERCRED) && !__OpenBSD__ + struct ucred creds; + uint length = sizeof(creds); + stream.getsockopt(SOL_SOCKET, SO_PEERCRED, &creds, &length); + if (creds.pid > 0) { + result.pid = creds.pid; + } + if (creds.uid != static_cast(-1)) { + result.uid = creds.uid; + } + +#elif defined(LOCAL_PEERCRED) + // MacOS / FreeBSD / OpenBSD + struct xucred creds; + uint length = sizeof(creds); + stream.getsockopt(SOL_LOCAL, LOCAL_PEERCRED, &creds, &length); + KJ_ASSERT(length == sizeof(creds)); + if (creds.cr_uid != static_cast(-1)) { + result.uid = creds.cr_uid; + } + +#if defined(LOCAL_PEERPID) + // MacOS only? + pid_t pid; + length = sizeof(pid); + stream.getsockopt(SOL_LOCAL, LOCAL_PEERPID, &pid, &length); + KJ_ASSERT(length == sizeof(pid)); + if (pid > 0) { + result.pid = pid; + } +#endif +#endif + + return LocalPeerIdentity::newInstance(result); + } + default: + return UnknownPeerIdentity::newInstance(); + } +} + class SocketNetwork final: public Network { public: explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {} + explicit SocketNetwork(SocketNetwork& parent, + kj::ArrayPtr allow, + kj::ArrayPtr deny) + : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {} Promise> parseAddress(StringPtr addr, uint portHint = 0) override { - auto& lowLevelCopy = lowLevel; - return evalLater(mvCapture(heapString(addr), - [&lowLevelCopy,portHint](String&& addr) { - return SocketAddress::parse(lowLevelCopy, addr, portHint); - })).then([&lowLevelCopy](Array addresses) -> Own { - return heap(lowLevelCopy, kj::mv(addresses)); + return evalNow([&]() { + return SocketAddress::parse(lowLevel, addr, portHint, filter); + }).then([this](Array addresses) -> Own { + return heap(lowLevel, filter, kj::mv(addresses)); }); } Own getSockaddr(const void* sockaddr, uint len) override { auto array = kj::heapArrayBuilder(1); array.add(SocketAddress(sockaddr, len)); - return Own(heap(lowLevel, array.finish())); + KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; } + return Own(heap(lowLevel, filter, array.finish())); + } + + Own restrictPeers( + kj::ArrayPtr allow, + kj::ArrayPtr deny = nullptr) override { + return heap(*this, allow, deny); } private: LowLevelAsyncIoProvider& lowLevel; + _::NetworkFilter filter; }; // ======================================================================================= @@ -1088,7 +1800,7 @@ Promise DatagramPortImpl::send( msg.msg_name = const_cast(implicitCast(addr.getRaw())); msg.msg_namelen = addr.getRawSize(); - const size_t iovmax = kj::miniposix::iovMax(pieces.size()); + const size_t iovmax = kj::miniposix::iovMax(); KJ_STACK_ARRAY(struct iovec, iov, kj::min(pieces.size(), iovmax), 16, 64); for (size_t i: kj::indices(pieces)) { @@ -1101,7 +1813,7 @@ Promise DatagramPortImpl::send( // Too many pieces, but we can't use multiple syscalls because they'd send separate // datagrams. We'll have to copy the trailing pieces into a temporary array. // - // TODO(perf): On Linux we could use multiple syscalls via MSG_MORE. + // TODO(perf): On Linux we could use multiple syscalls via MSG_MORE or sendmsg/sendmmsg. size_t extraSize = 0; for (size_t i = iovmax - 1; i < pieces.size(); i++) { extraSize += pieces[i].size(); @@ -1112,8 +1824,8 @@ Promise DatagramPortImpl::send( memcpy(extra.begin() + extraSize, pieces[i].begin(), pieces[i].size()); extraSize += pieces[i].size(); } - iov[iovmax - 1].iov_base = extra.begin(); - iov[iovmax - 1].iov_len = extra.size(); + iov.back().iov_base = extra.begin(); + iov.back().iov_len = extra.size(); } msg.msg_iov = iov.begin(); @@ -1167,10 +1879,16 @@ public: return receive(); }); } else { + if (!port.filter.shouldAllow(reinterpret_cast(msg.msg_name), + msg.msg_namelen)) { + // Ignore message from disallowed source. + return receive(); + } + receivedSize = n; contentTruncated = msg.msg_flags & MSG_TRUNC; - source.emplace(port.lowLevel, msg.msg_name, msg.msg_namelen); + source.emplace(port.lowLevel, port.filter, msg.msg_name, msg.msg_namelen); ancillaryList.resize(0); ancillaryTruncated = msg.msg_flags & MSG_CTRUNC; @@ -1181,6 +1899,10 @@ public: // when truncated. On other platforms (Linux) the length in cmsghdr will itself be // truncated to fit within the buffer. +#if __APPLE__ +// On MacOS, `CMSG_SPACE(0)` triggers a bogus warning. +#pragma GCC diagnostic ignored "-Wnull-pointer-arithmetic" +#endif const byte* pos = reinterpret_cast(cmsg); size_t available = ancillaryBuffer.end() - pos; if (available < CMSG_SPACE(0)) { @@ -1228,9 +1950,10 @@ private: bool ancillaryTruncated = false; struct StoredAddress { - StoredAddress(LowLevelAsyncIoProvider& lowLevel, const void* sockaddr, uint length) + StoredAddress(LowLevelAsyncIoProvider& lowLevel, LowLevelAsyncIoProvider::NetworkFilter& filter, + const void* sockaddr, uint length) : raw(sockaddr, length), - abstract(lowLevel, Array(&raw, 1, NullArrayDisposer::instance)) {} + abstract(lowLevel, filter, Array(&raw, 1, NullArrayDisposer::instance)) {} SocketAddress raw; NetworkAddressImpl abstract; @@ -1276,6 +1999,19 @@ public: } }; } + CapabilityPipe newCapabilityPipe() override { + int fds[2]; + int type = SOCK_STREAM; +#if __linux__ && !__BIONIC__ + type |= SOCK_NONBLOCK | SOCK_CLOEXEC; +#endif + KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds)); + return CapabilityPipe { { + lowLevel.wrapUnixSocketFd(fds[0], NEW_FD_FLAGS), + lowLevel.wrapUnixSocketFd(fds[1], NEW_FD_FLAGS) + } }; + } + Network& getNetwork() override { return network; } @@ -1294,13 +2030,12 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); - auto thread = heap(kj::mvCapture(startFunc, - [threadFd](Function&& startFunc) { + auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { LowLevelAsyncIoProviderImpl lowLevel; auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); startFunc(ioProvider, *stream, lowLevel.getWaitScope()); - })); + }); return { kj::mv(thread), kj::mv(pipe) }; } diff --git a/c++/src/kj/async-io-win32.c++ b/c++/src/kj/async-io-win32.c++ index e0f4266229..f7f51803b3 100644 --- a/c++/src/kj/async-io-win32.c++ +++ b/c++/src/kj/async-io-win32.c++ @@ -23,10 +23,10 @@ // For Unix implementation, see async-io-unix.c++. // Request Vista-level APIs. -#define WINVER 0x0600 -#define _WIN32_WINNT 0x0600 +#include #include "async-io.h" +#include "async-io-internal.h" #include "async-win32.h" #include "debug.h" #include "thread.h" @@ -148,10 +148,29 @@ int win32Socketpair(SOCKET socks[2]) { if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR) break; + retryAccept: socks[1] = accept(listener, NULL, NULL); if (socks[1] == -1) break; + // Verify that the client is actually us and not someone else who raced to connect first. + // (This check added by Kenton for security.) + union { + struct sockaddr_in inaddr; + struct sockaddr addr; + } b, c; + socklen_t bAddrlen = sizeof(b.inaddr); + socklen_t cAddrlen = sizeof(b.inaddr); + if (getpeername(socks[1], &b.addr, &bAddrlen) == SOCKET_ERROR) + break; + if (getsockname(socks[0], &c.addr, &cAddrlen) == SOCKET_ERROR) + break; + if (bAddrlen != cAddrlen || memcmp(&b.addr, &c.addr, bAddrlen) != 0) { + // Someone raced to connect first. Ignore. + closesocket(socks[1]); + goto retryAccept; + } + closesocket(listener); return 0; } @@ -169,17 +188,6 @@ int win32Socketpair(SOCKET socks[2]) { namespace { -bool detectWine() { - HMODULE hntdll = GetModuleHandle("ntdll.dll"); - if(hntdll == NULL) return false; - return GetProcAddress(hntdll, "wine_get_version") != nullptr; -} - -bool isWine() { - static bool result = detectWine(); - return result; -} - // ======================================================================================= static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP; @@ -291,6 +299,23 @@ public: }); } + Promise whenWriteDisconnected() override { + // Windows IOCP does not provide a direct, documented way to detect when the socket disconnects + // without actually doing a read or write. However, there is an undocoumented-but-stable + // ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is + // implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible + // to put only one socket into the list and perform the ioctl asynchronously. Here's the + // source code for select() in Windows 2000 (not sure how this became public...): + // + // https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655 + // + // And here's an interesting discussion: https://github.com/python-trio/trio/issues/52 + // + // TODO(someday): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented + // because I added this method for a Linux-only use case. + return NEVER_DONE; + } + void shutdownWrite() override { // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the // Win32AsyncIoProvider interface. @@ -327,6 +352,10 @@ public: *length = socklen; } + Maybe getWin32Handle() const { + return reinterpret_cast(fd); + } + private: Own observer; @@ -524,11 +553,12 @@ public: } static Promise> lookupHost( - LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint); + LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint, + _::NetworkFilter& filter); // Perform a DNS lookup. static Promise> parse( - LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) { + LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) { // TODO(someday): Allow commas in `str`. SocketAddress result; @@ -580,7 +610,8 @@ public: port = strtoul(portText->cStr(), &endptr, 0); if (portText->size() == 0 || *endptr != '\0') { // Not a number. Maybe it's a service name. Fall back to DNS. - return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint); + return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint, + filter); } KJ_REQUIRE(port < 65536, "Port number too large."); } else { @@ -612,35 +643,58 @@ public: addrTarget = &result.addr.inet4.sin_addr; } - // addrPart is not necessarily NUL-terminated so we have to make a copy. :( char buffer[64]; - KJ_REQUIRE(addrPart.size() < sizeof(buffer) - 1, "IP address too long.", addrPart); - memcpy(buffer, addrPart.begin(), addrPart.size()); - buffer[addrPart.size()] = '\0'; - - // OK, parse it! - switch (InetPtonA(af, buffer, addrTarget)) { - case 1: { - // success. - auto array = kj::heapArrayBuilder(1); - array.add(result); - return array.finish(); + if (addrPart.size() < sizeof(buffer) - 1) { + // addrPart is not necessarily NUL-terminated so we have to make a copy. :( + memcpy(buffer, addrPart.begin(), addrPart.size()); + buffer[addrPart.size()] = '\0'; + + // OK, parse it! + switch (InetPtonA(af, buffer, addrTarget)) { + case 1: { + // success. + if (!result.parseAllowedBy(filter)) { + KJ_FAIL_REQUIRE("address family blocked by restrictPeers()"); + return Array(); + } + + auto array = kj::heapArrayBuilder(1); + array.add(result); + return array.finish(); + } + case 0: + // It's apparently not a simple address... fall back to DNS. + break; + default: + KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart); } - case 0: - // It's apparently not a simple address... fall back to DNS. - return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port); - default: - KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart); } + + return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter); } - static SocketAddress getLocalAddress(int sockfd) { + static SocketAddress getLocalAddress(SOCKET sockfd) { SocketAddress result; result.addrlen = sizeof(addr); KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen)); return result; } + static SocketAddress getPeerAddress(SOCKET sockfd) { + SocketAddress result; + result.addrlen = sizeof(addr); + KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen)); + return result; + } + + bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) { + return filter.shouldAllow(&addr.generic, addrlen); + } + + bool parseAllowedBy(_::NetworkFilter& filter) { + return filter.shouldAllowParse(&addr.generic, addrlen); + } + static SocketAddress getWildcardForFamily(int family) { SocketAddress result; switch (family) { @@ -672,54 +726,6 @@ private: } addr; struct LookupParams; - class LookupReader; -}; - -class SocketAddress::LookupReader { - // Reads SocketAddresses off of a pipe coming from another thread that is performing - // getaddrinfo. - -public: - LookupReader(kj::Own&& thread, kj::Own&& input) - : thread(kj::mv(thread)), input(kj::mv(input)) {} - - ~LookupReader() { - if (thread) thread->detach(); - } - - Promise> read() { - return input->tryRead(¤t, sizeof(current), sizeof(current)).then( - [this](size_t n) -> Promise> { - if (n < sizeof(current)) { - thread = nullptr; - // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check - // anyway. - KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; } - return addresses.releaseAsArray(); - } else { - // getaddrinfo() can return multiple copies of the same address for several reasons. - // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so - // it may return two copies of the same address, one for each type, unless it explicitly - // knows that the service name given is specific to one type. But we can't tell it a type, - // because we don't actually know which one the user wants, and if we specify SOCK_STREAM - // while the user specified a UDP service name then they'll get a resolution error which - // is lame. (At least, I think that's how it works.) - // - // So we instead resort to de-duping results. - if (alreadySeen.insert(current).second) { - addresses.add(current); - } - return read(); - } - }); - } - -private: - kj::Own thread; - kj::Own input; - SocketAddress current; - kj::Vector addresses; - std::set alreadySeen; }; struct SocketAddress::LookupParams { @@ -728,7 +734,8 @@ struct SocketAddress::LookupParams { }; Promise> SocketAddress::lookupHost( - LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) { + LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint, + _::NetworkFilter& filter) { // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is // the only cross-platform DNS API and it is blocking. // @@ -736,98 +743,98 @@ Promise> SocketAddress::lookupHost( // - Not implemented in Wine. // - Doesn't seem compatible with I/O completion ports, in particular because it's not associated // with a handle. Could signal completion as an APC instead, but that requires the IOCP code - // to use GetQueuedCompletionStatusEx() which it doesn't right now becaues it's not available + // to use GetQueuedCompletionStatusEx() which it doesn't right now because it's not available // in Wine. // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the // docs. Never mind that DNS itself is ASCII... - SOCKET fds[2]; - KJ_WINSOCK(_::win32Socketpair(fds)); - - auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS); - - int outFd = fds[1]; - + auto paf = newPromiseAndCrossThreadFulfiller>(); LookupParams params = { kj::mv(host), kj::mv(service) }; - auto thread = heap(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) { - KJ_DEFER(closesocket(outFd)); - - struct addrinfo* list; - int status = getaddrinfo( - params.host == "*" ? nullptr : params.host.cStr(), - params.service == nullptr ? nullptr : params.service.cStr(), - nullptr, &list); - if (status == 0) { - KJ_DEFER(freeaddrinfo(list)); - - struct addrinfo* cur = list; - while (cur != nullptr) { - if (params.service == nullptr) { - switch (cur->ai_addr->sa_family) { - case AF_INET: - ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); - break; - case AF_INET6: - ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); - break; - default: - break; + auto thread = heap( + [fulfiller=kj::mv(paf.fulfiller),params=kj::mv(params),portHint]() mutable { + // getaddrinfo() can return multiple copies of the same address for several reasons. + // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so + // it may return two copies of the same address, one for each type, unless it explicitly + // knows that the service name given is specific to one type. But we can't tell it a type, + // because we don't actually know which one the user wants, and if we specify SOCK_STREAM + // while the user specified a UDP service name then they'll get a resolution error which + // is lame. (At least, I think that's how it works.) + // + // So we instead resort to de-duping results. + std::set result; + + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + addrinfo* list; + int status = getaddrinfo( + params.host == "*" ? nullptr : params.host.cStr(), + params.service == nullptr ? nullptr : params.service.cStr(), + nullptr, &list); + if (status == 0) { + KJ_DEFER(freeaddrinfo(list)); + + addrinfo* cur = list; + while (cur != nullptr) { + if (params.service == nullptr) { + switch (cur->ai_addr->sa_family) { + case AF_INET: + ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint); + break; + case AF_INET6: + ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint); + break; + default: + break; + } } - } - SocketAddress addr; - memset(&addr, 0, sizeof(addr)); // mollify valgrind - if (params.host == "*") { - // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). - addr.wildcard = true; - addr.addrlen = sizeof(addr.addr.inet6); - addr.addr.inet6.sin6_family = AF_INET6; - switch (cur->ai_addr->sa_family) { - case AF_INET: - addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; - break; - case AF_INET6: - addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; - break; - default: - addr.addr.inet6.sin6_port = portHint; - break; + SocketAddress addr; + memset(&addr, 0, sizeof(addr)); // mollify valgrind + if (params.host == "*") { + // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo(). + addr.wildcard = true; + addr.addrlen = sizeof(addr.addr.inet6); + addr.addr.inet6.sin6_family = AF_INET6; + switch (cur->ai_addr->sa_family) { + case AF_INET: + addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port; + break; + case AF_INET6: + addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port; + break; + default: + addr.addr.inet6.sin6_port = portHint; + break; + } + } else { + addr.addrlen = cur->ai_addrlen; + memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); } - } else { - addr.addrlen = cur->ai_addrlen; - memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen); + result.insert(addr); + cur = cur->ai_next; } - KJ_ASSERT_CAN_MEMCPY(SocketAddress); - - const char* data = reinterpret_cast(&addr); - size_t size = sizeof(addr); - while (size > 0) { - int n; - KJ_WINSOCK(n = send(outFd, data, size, 0)); - data += n; - size -= n; + } else { + KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) { + return; } - - cur = cur->ai_next; } + })) { + fulfiller->reject(kj::mv(*exception)); } else { - KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) { - return; - } + fulfiller->fulfill(KJ_MAP(addr, result) { return addr; }); } - })); + }); - auto reader = heap(kj::mv(thread), kj::mv(input)); - return reader->read().attach(kj::mv(reader)); + return kj::mv(paf.promise); } // ======================================================================================= class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd { public: - FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd, uint flags) - : OwnedFd(fd, flags), eventPort(eventPort), + FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd, + LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) + : OwnedFd(fd, flags), eventPort(eventPort), filter(filter), observer(eventPort.observeIo(reinterpret_cast(fd))), address(SocketAddress::getLocalAddress(fd)) { // In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have @@ -858,8 +865,10 @@ public: } } - return op->onComplete().attach(kj::mv(scratch)).then(mvCapture(result, - [this](Own stream, Win32EventPort::IoResult ioResult) { + return op->onComplete().then( + [this,newFd,stream=kj::mv(result),scratch=kj::mv(scratch)] + (Win32EventPort::IoResult ioResult) mutable + -> Promise> { if (ioResult.errorCode != ERROR_SUCCESS) { KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; } } else { @@ -867,8 +876,19 @@ public: stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&me), sizeof(me)); } - return kj::mv(stream); - })); + + // Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've + // named `scratch`). However, the format in which it writes these is undocumented, and + // doesn't even match between native Windows and WINE. Apparently it is useless. I don't know + // why they require the buffer to have space for it in the first place. We'll need to call + // getpeername() to get the address. + auto addr = SocketAddress::getPeerAddress(newFd); + if (addr.allowedBy(filter)) { + return Own(kj::mv(stream)); + } else { + return accept(); + } + }); } uint getPort() override { @@ -885,9 +905,15 @@ public: KJ_WINSOCK(::setsockopt(fd, level, option, reinterpret_cast(value), length)); } + void getsockname(struct sockaddr* addr, uint* length) override { + socklen_t socklen = *length; + KJ_WINSOCK(::getsockname(fd, addr, &socklen)); + *length = socklen; + } public: Win32EventPort& eventPort; + LowLevelAsyncIoProvider::NetworkFilter& filter; Own observer; LPFN_ACCEPTEX acceptEx = nullptr; SocketAddress address; @@ -919,12 +945,13 @@ public: SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd); auto connected = result->connect(addr, addrlen); - return connected.then(kj::mvCapture(result, [](Own&& result) { - return kj::mv(result); - })); + return connected.then([result=kj::mv(result)]() mutable -> Own { + return Own(kj::mv(result)); + }); } - Own wrapListenSocketFd(SOCKET fd, uint flags = 0) override { - return heap(eventPort, fd, flags); + Own wrapListenSocketFd( + SOCKET fd, NetworkFilter& filter, uint flags = 0) override { + return heap(eventPort, fd, filter, flags); } Timer& getTimer() override { return eventPort.getTimer(); } @@ -941,12 +968,14 @@ private: class NetworkAddressImpl final: public NetworkAddress { public: - NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array addrs) - : lowLevel(lowLevel), addrs(kj::mv(addrs)) {} + NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, + LowLevelAsyncIoProvider::NetworkFilter& filter, + Array addrs) + : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {} Promise> connect() override { auto addrsCopy = heapArray(addrs.asPtr()); - auto promise = connectImpl(lowLevel, addrsCopy); + auto promise = connectImpl(lowLevel, filter, addrsCopy); return promise.attach(kj::mv(addrsCopy)); } @@ -974,7 +1003,7 @@ public: KJ_WINSOCK(::listen(fd, SOMAXCONN)); } - return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS); + return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS); } Own bindDatagramPort() override { @@ -998,11 +1027,11 @@ public: addrs[0].bind(fd); } - return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS); + return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS); } Own clone() override { - return kj::heap(lowLevel, kj::heapArray(addrs.asPtr())); + return kj::heap(lowLevel, filter, kj::heapArray(addrs.asPtr())); } String toString() override { @@ -1016,26 +1045,34 @@ public: private: LowLevelAsyncIoProvider& lowLevel; + LowLevelAsyncIoProvider::NetworkFilter& filter; Array addrs; uint counter = 0; static Promise> connectImpl( - LowLevelAsyncIoProvider& lowLevel, ArrayPtr addrs) { + LowLevelAsyncIoProvider& lowLevel, + LowLevelAsyncIoProvider::NetworkFilter& filter, + ArrayPtr addrs) { KJ_ASSERT(addrs.size() > 0); int fd = addrs[0].socket(SOCK_STREAM); - return kj::evalNow([&]() { - return lowLevel.wrapConnectingSocketFd( - fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS); + return kj::evalNow([&]() -> Promise> { + if (!addrs[0].allowedBy(filter)) { + return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()"); + } else { + return lowLevel.wrapConnectingSocketFd( + fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS); + } }).then([](Own&& stream) -> Promise> { // Success, pass along. return kj::mv(stream); - }, [&lowLevel,KJ_CPCAP(addrs)](Exception&& exception) mutable -> Promise> { + }, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable + -> Promise> { // Connect failed. if (addrs.size() > 1) { // Try the next address instead. - return connectImpl(lowLevel, addrs.slice(1, addrs.size())); + return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size())); } else { // No more addresses to try, so propagate the exception. return kj::mv(exception); @@ -1047,25 +1084,35 @@ private: class SocketNetwork final: public Network { public: explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {} + explicit SocketNetwork(SocketNetwork& parent, + kj::ArrayPtr allow, + kj::ArrayPtr deny) + : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {} Promise> parseAddress(StringPtr addr, uint portHint = 0) override { - auto& lowLevelCopy = lowLevel; - return evalLater(mvCapture(heapString(addr), - [&lowLevelCopy,portHint](String&& addr) { - return SocketAddress::parse(lowLevelCopy, addr, portHint); - })).then([&lowLevelCopy](Array addresses) -> Own { - return heap(lowLevelCopy, kj::mv(addresses)); + return evalNow([&]() { + return SocketAddress::parse(lowLevel, addr, portHint, filter); + }).then([this](Array addresses) -> Own { + return heap(lowLevel, filter, kj::mv(addresses)); }); } Own getSockaddr(const void* sockaddr, uint len) override { auto array = kj::heapArrayBuilder(1); array.add(SocketAddress(sockaddr, len)); - return Own(heap(lowLevel, array.finish())); + KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; } + return Own(heap(lowLevel, filter, array.finish())); + } + + Own restrictPeers( + kj::ArrayPtr allow, + kj::ArrayPtr deny = nullptr) override { + return heap(*this, allow, deny); } private: LowLevelAsyncIoProvider& lowLevel; + _::NetworkFilter filter; }; // ======================================================================================= @@ -1107,13 +1154,12 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); - auto thread = heap(kj::mvCapture(startFunc, - [threadFd](Function&& startFunc) { + auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { LowLevelAsyncIoProviderImpl lowLevel; auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); startFunc(ioProvider, *stream, lowLevel.getWaitScope()); - })); + }); return { kj::mv(thread), kj::mv(pipe) }; } diff --git a/c++/src/kj/async-io.c++ b/c++/src/kj/async-io.c++ index c7a7cb50de..5fea50d386 100644 --- a/c++/src/kj/async-io.c++ +++ b/c++/src/kj/async-io.c++ @@ -19,9 +19,36 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if _WIN32 +// Request Vista-level APIs. +#include +#endif + #include "async-io.h" +#include "async-io-internal.h" #include "debug.h" #include "vector.h" +#include "io.h" +#include "one-of.h" +#include +#include + +#if _WIN32 +#include +#include +#include +#include +#define inet_pton InetPtonA +#define inet_ntop InetNtopA +#include +#define dup _dup +#else +#include +#include +#include +#include +#include +#endif namespace kj { @@ -31,23 +58,34 @@ Promise AsyncInputStream::read(void* buffer, size_t bytes) { Promise AsyncInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) { return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) { - KJ_REQUIRE(result >= minBytes, "Premature EOF") { + if (result >= minBytes) { + return result; + } else { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "stream disconnected prematurely")); // Pretend we read zeros from the input. memset(reinterpret_cast(buffer) + result, 0, minBytes - result); return minBytes; } - return result; }); } Maybe AsyncInputStream::tryGetLength() { return nullptr; } +void AsyncInputStream::registerAncillaryMessageHandler( + Function)> fn) { + KJ_UNIMPLEMENTED("registerAncillaryMsgHandler is not implemented by this AsyncInputStream"); +} + +Maybe> AsyncInputStream::tryTee(uint64_t) { + return nullptr; +} + namespace { class AsyncPump { public: - AsyncPump(AsyncInputStream& input, AsyncOutputStream& output, uint64_t limit) - : input(input), output(output), limit(limit) {} + AsyncPump(AsyncInputStream& input, AsyncOutputStream& output, uint64_t limit, uint64_t doneSoFar) + : input(input), output(output), limit(limit), doneSoFar(doneSoFar) {} Promise pump() { // TODO(perf): This could be more efficient by reading half a buffer at a time and then @@ -56,7 +94,7 @@ public: uint64_t n = kj::min(limit - doneSoFar, sizeof(buffer)); if (n == 0) return doneSoFar; - return input.tryRead(buffer, 1, sizeof(buffer)) + return input.tryRead(buffer, 1, n) .then([this](size_t amount) -> Promise { if (amount == 0) return doneSoFar; // EOF doneSoFar += amount; @@ -71,12 +109,20 @@ private: AsyncInputStream& input; AsyncOutputStream& output; uint64_t limit; - uint64_t doneSoFar = 0; + uint64_t doneSoFar; byte buffer[4096]; }; } // namespace +Promise unoptimizedPumpTo( + AsyncInputStream& input, AsyncOutputStream& output, uint64_t amount, + uint64_t completedSoFar) { + auto pump = heap(input, output, amount, completedSoFar); + auto promise = pump->pump(); + return promise.attach(kj::mv(pump)); +} + Promise AsyncInputStream::pumpTo( AsyncOutputStream& output, uint64_t amount) { // See if output wants to dispatch on us. @@ -85,9 +131,7 @@ Promise AsyncInputStream::pumpTo( } // OK, fall back to naive approach. - auto pump = heap(*this, output, amount); - auto promise = pump->pump(); - return promise.attach(kj::mv(pump)); + return unoptimizedPumpTo(*this, output, amount); } namespace { @@ -96,17 +140,17 @@ class AllReader { public: AllReader(AsyncInputStream& input): input(input) {} - Promise> readAllBytes() { - return loop().then([this](uint64_t size) { - auto out = heapArray(size); + Promise> readAllBytes(uint64_t limit) { + return loop(limit).then([this, limit](uint64_t headroom) { + auto out = heapArray(limit - headroom); copyInto(out); return out; }); } - Promise readAllText() { - return loop().then([this](uint64_t size) { - auto out = heapArray(size + 1); + Promise readAllText(uint64_t limit) { + return loop(limit).then([this, limit](uint64_t headroom) { + auto out = heapArray(limit - headroom + 1); copyInto(out.slice(0, out.size() - 1).asBytes()); out.back() = '\0'; return String(kj::mv(out)); @@ -117,17 +161,19 @@ private: AsyncInputStream& input; Vector> parts; - Promise loop(uint64_t total = 0) { - auto part = heapArray(4096); + Promise loop(uint64_t limit) { + KJ_REQUIRE(limit > 0, "Reached limit before EOF."); + + auto part = heapArray(kj::min(4096, limit)); auto partPtr = part.asPtr(); parts.add(kj::mv(part)); return input.tryRead(partPtr.begin(), partPtr.size(), partPtr.size()) - .then([this,KJ_CPCAP(partPtr),total](size_t amount) -> Promise { - uint64_t newTotal = total + amount; + .then([this,KJ_CPCAP(partPtr),limit](size_t amount) mutable -> Promise { + limit -= amount; if (amount < partPtr.size()) { - return newTotal; + return limit; } else { - return loop(newTotal); + return loop(limit); } }); } @@ -144,15 +190,15 @@ private: } // namespace -Promise> AsyncInputStream::readAllBytes() { +Promise> AsyncInputStream::readAllBytes(uint64_t limit) { auto reader = kj::heap(*this); - auto promise = reader->readAllBytes(); + auto promise = reader->readAllBytes(limit); return promise.attach(kj::mv(reader)); } -Promise AsyncInputStream::readAllText() { +Promise AsyncInputStream::readAllText(uint64_t limit) { auto reader = kj::heap(*this); - auto promise = reader->readAllText(); + auto promise = reader->readAllText(limit); return promise.attach(kj::mv(reader)); } @@ -161,35 +207,3047 @@ Maybe> AsyncOutputStream::tryPumpFrom( return nullptr; } -void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void AsyncIoStream::setsockopt(int level, int option, const void* value, uint length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void AsyncIoStream::getsockname(struct sockaddr* addr, uint* length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void AsyncIoStream::getpeername(struct sockaddr* addr, uint* length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void ConnectionReceiver::getsockopt(int level, int option, void* value, uint* length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void ConnectionReceiver::setsockopt(int level, int option, const void* value, uint length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void DatagramPort::getsockopt(int level, int option, void* value, uint* length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -void DatagramPort::setsockopt(int level, int option, const void* value, uint length) { - KJ_UNIMPLEMENTED("Not a socket."); -} -Own NetworkAddress::bindDatagramPort() { - KJ_UNIMPLEMENTED("Datagram sockets not implemented."); -} -Own LowLevelAsyncIoProvider::wrapDatagramSocketFd(Fd fd, uint flags) { - KJ_UNIMPLEMENTED("Datagram sockets not implemented."); +namespace { + +class AsyncPipe final: public AsyncCapabilityStream, public Refcounted { +public: + ~AsyncPipe() noexcept(false) { + KJ_REQUIRE(state == nullptr || ownState.get() != nullptr, + "destroying AsyncPipe with operation still in-progress; probably going to segfault") { + // Don't std::terminate(). + break; + } + } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + if (minBytes == 0) { + return constPromise(); + } else KJ_IF_MAYBE(s, state) { + return s->tryRead(buffer, minBytes, maxBytes); + } else { + return newAdaptedPromise( + *this, arrayPtr(reinterpret_cast(buffer), maxBytes), minBytes) + .then([](ReadResult r) { return r.byteCount; }); + } + } + + Promise tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + if (minBytes == 0) { + return ReadResult { 0, 0 }; + } else KJ_IF_MAYBE(s, state) { + return s->tryReadWithFds(buffer, minBytes, maxBytes, fdBuffer, maxFds); + } else { + return newAdaptedPromise( + *this, arrayPtr(reinterpret_cast(buffer), maxBytes), minBytes, + kj::arrayPtr(fdBuffer, maxFds)); + } + } + + Promise tryReadWithStreams( + void* buffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + if (minBytes == 0) { + return ReadResult { 0, 0 }; + } else KJ_IF_MAYBE(s, state) { + return s->tryReadWithStreams(buffer, minBytes, maxBytes, streamBuffer, maxStreams); + } else { + return newAdaptedPromise( + *this, arrayPtr(reinterpret_cast(buffer), maxBytes), minBytes, + kj::arrayPtr(streamBuffer, maxStreams)); + } + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + if (amount == 0) { + return constPromise(); + } else KJ_IF_MAYBE(s, state) { + return s->pumpTo(output, amount); + } else { + return newAdaptedPromise(*this, output, amount); + } + } + + void abortRead() override { + KJ_IF_MAYBE(s, state) { + s->abortRead(); + } else { + ownState = kj::heap(); + state = *ownState; + + readAborted = true; + KJ_IF_MAYBE(f, readAbortFulfiller) { + f->get()->fulfill(); + readAbortFulfiller = nullptr; + } + } + } + + Promise write(const void* buffer, size_t size) override { + if (size == 0) { + return READY_NOW; + } else KJ_IF_MAYBE(s, state) { + return s->write(buffer, size); + } else { + return newAdaptedPromise( + *this, arrayPtr(reinterpret_cast(buffer), size), nullptr); + } + } + + Promise write(ArrayPtr> pieces) override { + while (pieces.size() > 0 && pieces[0].size() == 0) { + pieces = pieces.slice(1, pieces.size()); + } + + if (pieces.size() == 0) { + return kj::READY_NOW; + } else KJ_IF_MAYBE(s, state) { + return s->write(pieces); + } else { + return newAdaptedPromise( + *this, pieces[0], pieces.slice(1, pieces.size())); + } + } + + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + while (data.size() == 0 && moreData.size() > 0) { + data = moreData.front(); + moreData = moreData.slice(1, moreData.size()); + } + + if (data.size() == 0) { + KJ_REQUIRE(fds.size() == 0, "can't attach FDs to empty message"); + return READY_NOW; + } else KJ_IF_MAYBE(s, state) { + return s->writeWithFds(data, moreData, fds); + } else { + return newAdaptedPromise(*this, data, moreData, fds); + } + } + + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + while (data.size() == 0 && moreData.size() > 0) { + data = moreData.front(); + moreData = moreData.slice(1, moreData.size()); + } + + if (data.size() == 0) { + KJ_REQUIRE(streams.size() == 0, "can't attach capabilities to empty message"); + return READY_NOW; + } else KJ_IF_MAYBE(s, state) { + return s->writeWithStreams(data, moreData, kj::mv(streams)); + } else { + return newAdaptedPromise(*this, data, moreData, kj::mv(streams)); + } + } + + Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount) override { + if (amount == 0) { + return constPromise(); + } else KJ_IF_MAYBE(s, state) { + return s->tryPumpFrom(input, amount); + } else { + return newAdaptedPromise(*this, input, amount); + } + } + + Promise whenWriteDisconnected() override { + if (readAborted) { + return kj::READY_NOW; + } else KJ_IF_MAYBE(p, readAbortPromise) { + return p->addBranch(); + } else { + auto paf = newPromiseAndFulfiller(); + readAbortFulfiller = kj::mv(paf.fulfiller); + auto fork = paf.promise.fork(); + auto result = fork.addBranch(); + readAbortPromise = kj::mv(fork); + return result; + } + } + + void shutdownWrite() override { + KJ_IF_MAYBE(s, state) { + s->shutdownWrite(); + } else { + ownState = kj::heap(); + state = *ownState; + } + } + +private: + Maybe state; + // Object-oriented state! If any method call is blocked waiting on activity from the other end, + // then `state` is non-null and method calls should be forwarded to it. If no calls are + // outstanding, `state` is null. + + kj::Own ownState; + + bool readAborted = false; + Maybe>> readAbortFulfiller = nullptr; + Maybe> readAbortPromise = nullptr; + + void endState(AsyncIoStream& obj) { + KJ_IF_MAYBE(s, state) { + if (s == &obj) { + state = nullptr; + } + } + } + + template + static auto teeExceptionVoid(F& fulfiller) { + // Returns a functor that can be passed as the second parameter to .then() to propagate the + // exception to a given fulfiller. The functor's return type is void. + return [&fulfiller](kj::Exception&& e) { + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); + }; + } + template + static auto teeExceptionSize(F& fulfiller) { + // Returns a functor that can be passed as the second parameter to .then() to propagate the + // exception to a given fulfiller. The functor's return type is size_t. + return [&fulfiller](kj::Exception&& e) -> size_t { + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); + return 0; + }; + } + template + static auto teeExceptionPromise(F& fulfiller) { + // Returns a functor that can be passed as the second parameter to .then() to propagate the + // exception to a given fulfiller. The functor's return type is Promise. + return [&fulfiller](kj::Exception&& e) -> kj::Promise { + fulfiller.reject(kj::cp(e)); + return kj::mv(e); + }; + } + + class BlockedWrite final: public AsyncCapabilityStream { + // AsyncPipe state when a write() is currently waiting for a corresponding read(). + + public: + BlockedWrite(PromiseFulfiller& fulfiller, AsyncPipe& pipe, + ArrayPtr writeBuffer, + ArrayPtr> morePieces, + kj::OneOf, Array>> capBuffer = {}) + : fulfiller(fulfiller), pipe(pipe), writeBuffer(writeBuffer), morePieces(morePieces), + capBuffer(kj::mv(capBuffer)) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + + ~BlockedWrite() noexcept(false) { + pipe.endState(*this); + } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_SWITCH_ONEOF(tryReadImpl(buffer, minBytes, maxBytes)) { + KJ_CASE_ONEOF(done, Done) { + return done.result; + } + KJ_CASE_ONEOF(retry, Retry) { + return pipe.tryRead(retry.buffer, retry.minBytes, retry.maxBytes) + .then([n = retry.alreadyRead](size_t amount) { return amount + n; }); + } + } + KJ_UNREACHABLE; + } + + Promise tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + size_t capCount = 0; + { // TODO(cleanup): Remove redundant braces when we update to C++17. + KJ_SWITCH_ONEOF(capBuffer) { + KJ_CASE_ONEOF(fds, ArrayPtr) { + capCount = kj::max(fds.size(), maxFds); + // Unfortunately, we have to dup() each FD, because the writer doesn't release ownership + // by default. + // TODO(perf): Should we add an ownership-releasing version of writeWithFds()? + for (auto i: kj::zeroTo(capCount)) { + int duped; + KJ_SYSCALL(duped = dup(fds[i])); + fdBuffer[i] = kj::AutoCloseFd(fds[i]); + } + fdBuffer += capCount; + maxFds -= capCount; + } + KJ_CASE_ONEOF(streams, Array>) { + if (streams.size() > 0 && maxFds > 0) { + // TODO(someday): We could let people pass a LowLevelAsyncIoProvider to + // newTwoWayPipe() if we wanted to auto-wrap FDs, but does anyone care? + KJ_FAIL_REQUIRE( + "async pipe message was written with streams attached, but corresponding read " + "asked for FDs, and we don't know how to convert here"); + } + } + } + } + + // Drop any unclaimed caps. This mirrors the behavior of unix sockets, where if we didn't + // provide enough buffer space for all the written FDs, the remaining ones are lost. + capBuffer = {}; + + KJ_SWITCH_ONEOF(tryReadImpl(buffer, minBytes, maxBytes)) { + KJ_CASE_ONEOF(done, Done) { + return ReadResult { done.result, capCount }; + } + KJ_CASE_ONEOF(retry, Retry) { + return pipe.tryReadWithFds( + retry.buffer, retry.minBytes, retry.maxBytes, fdBuffer, maxFds) + .then([byteCount = retry.alreadyRead, capCount](ReadResult result) { + result.byteCount += byteCount; + result.capCount += capCount; + return result; + }); + } + } + KJ_UNREACHABLE; + } + + Promise tryReadWithStreams( + void* buffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + size_t capCount = 0; + { // TODO(cleanup): Remove redundant braces when we update to C++17. + KJ_SWITCH_ONEOF(capBuffer) { + KJ_CASE_ONEOF(fds, ArrayPtr) { + if (fds.size() > 0 && maxStreams > 0) { + // TODO(someday): Use AsyncIoStream's `Maybe getFd()` method? + KJ_FAIL_REQUIRE( + "async pipe message was written with FDs attached, but corresponding read " + "asked for streams, and we don't know how to convert here"); + } + } + KJ_CASE_ONEOF(streams, Array>) { + capCount = kj::max(streams.size(), maxStreams); + for (auto i: kj::zeroTo(capCount)) { + streamBuffer[i] = kj::mv(streams[i]); + } + streamBuffer += capCount; + maxStreams -= capCount; + } + } + } + + // Drop any unclaimed caps. This mirrors the behavior of unix sockets, where if we didn't + // provide enough buffer space for all the written FDs, the remaining ones are lost. + capBuffer = {}; + + KJ_SWITCH_ONEOF(tryReadImpl(buffer, minBytes, maxBytes)) { + KJ_CASE_ONEOF(done, Done) { + return ReadResult { done.result, capCount }; + } + KJ_CASE_ONEOF(retry, Retry) { + return pipe.tryReadWithStreams( + retry.buffer, retry.minBytes, retry.maxBytes, streamBuffer, maxStreams) + .then([byteCount = retry.alreadyRead, capCount](ReadResult result) { + result.byteCount += byteCount; + result.capCount += capCount; + return result; + }); + } + } + KJ_UNREACHABLE; + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + // Note: Pumps drop all capabilities. + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + if (amount < writeBuffer.size()) { + // Consume a portion of the write buffer. + return canceler.wrap(output.write(writeBuffer.begin(), amount) + .then([this,amount]() { + writeBuffer = writeBuffer.slice(amount, writeBuffer.size()); + // We pumped the full amount, so we're done pumping. + return amount; + }, teeExceptionSize(fulfiller))); + } + + // First piece doesn't cover the whole pump. Figure out how many more pieces to add. + uint64_t actual = writeBuffer.size(); + size_t i = 0; + while (i < morePieces.size() && + amount >= actual + morePieces[i].size()) { + actual += morePieces[i++].size(); + } + + // Write the first piece. + auto promise = output.write(writeBuffer.begin(), writeBuffer.size()); + + // Write full pieces as a single gather-write. + if (i > 0) { + auto more = morePieces.slice(0, i); + promise = promise.then([&output,more]() { return output.write(more); }); + } + + if (i == morePieces.size()) { + // This will complete the write. + return canceler.wrap(promise.then([this,&output,amount,actual]() -> Promise { + canceler.release(); + fulfiller.fulfill(); + pipe.endState(*this); + + if (actual == amount) { + // Oh, we had exactly enough. + return actual; + } else { + return pipe.pumpTo(output, amount - actual) + .then([actual](uint64_t actual2) { return actual + actual2; }); + } + }, teeExceptionPromise(fulfiller))); + } else { + // Pump ends mid-piece. Write the last, partial piece. + auto n = amount - actual; + auto splitPiece = morePieces[i]; + KJ_ASSERT(n <= splitPiece.size()); + auto newWriteBuffer = splitPiece.slice(n, splitPiece.size()); + auto newMorePieces = morePieces.slice(i + 1, morePieces.size()); + auto prefix = splitPiece.slice(0, n); + if (prefix.size() > 0) { + promise = promise.then([&output,prefix]() { + return output.write(prefix.begin(), prefix.size()); + }); + } + + return canceler.wrap(promise.then([this,newWriteBuffer,newMorePieces,amount]() { + writeBuffer = newWriteBuffer; + morePieces = newMorePieces; + canceler.release(); + return amount; + }, teeExceptionSize(fulfiller))); + } + } + + void abortRead() override { + canceler.cancel("abortRead() was called"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); + pipe.endState(*this); + pipe.abortRead(); + } + + Promise write(const void* buffer, size_t size) override { + KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); + } + Promise write(ArrayPtr> pieces) override { + KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); + } + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); + } + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); + } + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + KJ_FAIL_REQUIRE("can't tryPumpFrom() again until previous write() completes"); + } + void shutdownWrite() override { + KJ_FAIL_REQUIRE("can't shutdownWrite() until previous write() completes"); + } + + Promise whenWriteDisconnected() override { + KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); + } + + private: + PromiseFulfiller& fulfiller; + AsyncPipe& pipe; + ArrayPtr writeBuffer; + ArrayPtr> morePieces; + kj::OneOf, Array>> capBuffer; + Canceler canceler; + + struct Done { size_t result; }; + struct Retry { void* buffer; size_t minBytes; size_t maxBytes; size_t alreadyRead; }; + + OneOf tryReadImpl(void* readBufferPtr, size_t minBytes, size_t maxBytes) { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + auto readBuffer = arrayPtr(reinterpret_cast(readBufferPtr), maxBytes); + + size_t totalRead = 0; + while (readBuffer.size() >= writeBuffer.size()) { + // The whole current write buffer can be copied into the read buffer. + + { + auto n = writeBuffer.size(); + memcpy(readBuffer.begin(), writeBuffer.begin(), n); + totalRead += n; + readBuffer = readBuffer.slice(n, readBuffer.size()); + } + + if (morePieces.size() == 0) { + // All done writing. + fulfiller.fulfill(); + pipe.endState(*this); + + if (totalRead >= minBytes) { + // Also all done reading. + return Done { totalRead }; + } else { + return Retry { readBuffer.begin(), minBytes - totalRead, readBuffer.size(), totalRead }; + } + } + + writeBuffer = morePieces[0]; + morePieces = morePieces.slice(1, morePieces.size()); + } + + // At this point, the read buffer is smaller than the current write buffer, so we can fill + // it completely. + { + auto n = readBuffer.size(); + memcpy(readBuffer.begin(), writeBuffer.begin(), n); + writeBuffer = writeBuffer.slice(n, writeBuffer.size()); + totalRead += n; + } + + return Done { totalRead }; + } + }; + + class BlockedPumpFrom final: public AsyncCapabilityStream { + // AsyncPipe state when a tryPumpFrom() is currently waiting for a corresponding read(). + + public: + BlockedPumpFrom(PromiseFulfiller& fulfiller, AsyncPipe& pipe, + AsyncInputStream& input, uint64_t amount) + : fulfiller(fulfiller), pipe(pipe), input(input), amount(amount) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + + ~BlockedPumpFrom() noexcept(false) { + pipe.endState(*this); + } + + Promise tryRead(void* readBuffer, size_t minBytes, size_t maxBytes) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + auto pumpLeft = amount - pumpedSoFar; + auto min = kj::min(pumpLeft, minBytes); + auto max = kj::min(pumpLeft, maxBytes); + return canceler.wrap(input.tryRead(readBuffer, min, max) + .then([this,readBuffer,minBytes,maxBytes,min](size_t actual) -> kj::Promise { + canceler.release(); + pumpedSoFar += actual; + KJ_ASSERT(pumpedSoFar <= amount); + + if (pumpedSoFar == amount || actual < min) { + // Either we pumped all we wanted or we hit EOF. + fulfiller.fulfill(kj::cp(pumpedSoFar)); + pipe.endState(*this); + } + + if (actual >= minBytes) { + return actual; + } else { + return pipe.tryRead(reinterpret_cast(readBuffer) + actual, + minBytes - actual, maxBytes - actual) + .then([actual](size_t actual2) { return actual + actual2; }); + } + }, teeExceptionPromise(fulfiller))); + } + + Promise tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + // Pumps drop all capabilities, so fall back to regular read. (We don't even know if the + // destination is an AsyncCapabilityStream...) + return tryRead(readBuffer, minBytes, maxBytes) + .then([](size_t n) { return ReadResult { n, 0 }; }); + } + + Promise tryReadWithStreams( + void* readBuffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + // Pumps drop all capabilities, so fall back to regular read. (We don't even know if the + // destination is an AsyncCapabilityStream...) + return tryRead(readBuffer, minBytes, maxBytes) + .then([](size_t n) { return ReadResult { n, 0 }; }); + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount2) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + auto n = kj::min(amount2, amount - pumpedSoFar); + return canceler.wrap(input.pumpTo(output, n) + .then([this,&output,amount2,n](uint64_t actual) -> Promise { + canceler.release(); + pumpedSoFar += actual; + KJ_ASSERT(pumpedSoFar <= amount); + if (pumpedSoFar == amount || actual < n) { + // Either we pumped all we wanted or we hit EOF. + fulfiller.fulfill(kj::cp(pumpedSoFar)); + pipe.endState(*this); + return pipe.pumpTo(output, amount2 - actual) + .then([actual](uint64_t actual2) { return actual + actual2; }); + } + + // Completed entire pumpTo amount. + KJ_ASSERT(actual == amount2); + return amount2; + }, teeExceptionSize(fulfiller))); + } + + void abortRead() override { + canceler.cancel("abortRead() was called"); + + // The input might have reached EOF, but we haven't detected it yet because we haven't tried + // to read that far. If we had not optimized tryPumpFrom() and instead used the default + // pumpTo() implementation, then the input would not have called write() again once it + // reached EOF, and therefore the abortRead() on the other end would *not* propagate an + // exception! We need the same behavior here. To that end, we need to detect if we're at EOF + // by reading one last byte. + checkEofTask = kj::evalNow([&]() { + static char junk; + return input.tryRead(&junk, 1, 1).then([this](uint64_t n) { + if (n == 0) { + fulfiller.fulfill(kj::cp(pumpedSoFar)); + } else { + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); + } + }).eagerlyEvaluate([this](kj::Exception&& e) { + fulfiller.reject(kj::mv(e)); + }); + }); + + pipe.endState(*this); + pipe.abortRead(); + } + + Promise write(const void* buffer, size_t size) override { + KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); + } + Promise write(ArrayPtr> pieces) override { + KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); + } + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); + } + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); + } + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + KJ_FAIL_REQUIRE("can't tryPumpFrom() again until previous tryPumpFrom() completes"); + } + void shutdownWrite() override { + KJ_FAIL_REQUIRE("can't shutdownWrite() until previous tryPumpFrom() completes"); + } + + Promise whenWriteDisconnected() override { + KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); + } + + private: + PromiseFulfiller& fulfiller; + AsyncPipe& pipe; + AsyncInputStream& input; + uint64_t amount; + uint64_t pumpedSoFar = 0; + Canceler canceler; + kj::Promise checkEofTask = nullptr; + }; + + class BlockedRead final: public AsyncCapabilityStream { + // AsyncPipe state when a tryRead() is currently waiting for a corresponding write(). + + public: + BlockedRead( + PromiseFulfiller& fulfiller, AsyncPipe& pipe, + ArrayPtr readBuffer, size_t minBytes, + kj::OneOf, ArrayPtr>> capBuffer = {}) + : fulfiller(fulfiller), pipe(pipe), readBuffer(readBuffer), minBytes(minBytes), + capBuffer(capBuffer) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + + ~BlockedRead() noexcept(false) { + pipe.endState(*this); + } + + Promise tryRead(void* readBuffer, size_t minBytes, size_t maxBytes) override { + KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); + } + Promise tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); + } + Promise tryReadWithStreams( + void* readBuffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); + } + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); + } + + void abortRead() override { + canceler.cancel("abortRead() was called"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); + pipe.endState(*this); + pipe.abortRead(); + } + + Promise write(const void* writeBuffer, size_t size) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + auto data = arrayPtr(reinterpret_cast(writeBuffer), size); + KJ_SWITCH_ONEOF(writeImpl(data, nullptr)) { + KJ_CASE_ONEOF(done, Done) { + return READY_NOW; + } + KJ_CASE_ONEOF(retry, Retry) { + KJ_ASSERT(retry.moreData == nullptr); + return pipe.write(retry.data.begin(), retry.data.size()); + } + } + KJ_UNREACHABLE; + } + + Promise write(ArrayPtr> pieces) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + KJ_SWITCH_ONEOF(writeImpl(pieces[0], pieces.slice(1, pieces.size()))) { + KJ_CASE_ONEOF(done, Done) { + return READY_NOW; + } + KJ_CASE_ONEOF(retry, Retry) { + if (retry.data.size() == 0) { + // We exactly finished the current piece, so just issue a write for the remaining + // pieces. + if (retry.moreData.size() == 0) { + // Nothing left. + return READY_NOW; + } else { + // Write remaining pieces. + return pipe.write(retry.moreData); + } + } else { + // Unfortunately we have to execute a separate write() for the remaining part of this + // piece, because we can't modify the pieces array. + auto promise = pipe.write(retry.data.begin(), retry.data.size()); + if (retry.moreData.size() == 0) { + // No more pieces so that's it. + return kj::mv(promise); + } else { + // Also need to write the remaining pieces. + auto& pipeRef = pipe; + return promise.then([pieces=retry.moreData,&pipeRef]() { + return pipeRef.write(pieces); + }); + } + } + } + } + KJ_UNREACHABLE; + } + + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { +#if __GNUC__ && !__clang__ && __GNUC__ >= 7 +// GCC 7 decides the open-brace below is "misleadingly indented" as if it were guarded by the `for` +// that appears in the implementation of KJ_REQUIRE(). Shut up shut up shut up. +#pragma GCC diagnostic ignored "-Wmisleading-indentation" +#endif + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + { // TODO(cleanup): Remove redundant braces when we update to C++17. + KJ_SWITCH_ONEOF(capBuffer) { + KJ_CASE_ONEOF(fdBuffer, ArrayPtr) { + size_t count = kj::max(fdBuffer.size(), fds.size()); + // Unfortunately, we have to dup() each FD, because the writer doesn't release ownership + // by default. + // TODO(perf): Should we add an ownership-releasing version of writeWithFds()? + for (auto i: kj::zeroTo(count)) { + int duped; + KJ_SYSCALL(duped = dup(fds[i])); + fdBuffer[i] = kj::AutoCloseFd(duped); + } + capBuffer = fdBuffer.slice(count, fdBuffer.size()); + readSoFar.capCount += count; + } + KJ_CASE_ONEOF(streamBuffer, ArrayPtr>) { + if (streamBuffer.size() > 0 && fds.size() > 0) { + // TODO(someday): Use AsyncIoStream's `Maybe getFd()` method? + KJ_FAIL_REQUIRE( + "async pipe message was written with FDs attached, but corresponding read " + "asked for streams, and we don't know how to convert here"); + } + } + } + } + + KJ_SWITCH_ONEOF(writeImpl(data, moreData)) { + KJ_CASE_ONEOF(done, Done) { + return READY_NOW; + } + KJ_CASE_ONEOF(retry, Retry) { + // Any leftover fds in `fds` are dropped on the floor, per contract. + // TODO(cleanup): We use another writeWithFds() call here only because it accepts `data` + // and `moreData` directly. After the stream API refactor, we should be able to avoid + // this. + return pipe.writeWithFds(retry.data, retry.moreData, nullptr); + } + } + KJ_UNREACHABLE; + } + + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + { // TODO(cleanup): Remove redundant braces when we update to C++17. + KJ_SWITCH_ONEOF(capBuffer) { + KJ_CASE_ONEOF(fdBuffer, ArrayPtr) { + if (fdBuffer.size() > 0 && streams.size() > 0) { + // TODO(someday): We could let people pass a LowLevelAsyncIoProvider to newTwoWayPipe() + // if we wanted to auto-wrap FDs, but does anyone care? + KJ_FAIL_REQUIRE( + "async pipe message was written with streams attached, but corresponding read " + "asked for FDs, and we don't know how to convert here"); + } + } + KJ_CASE_ONEOF(streamBuffer, ArrayPtr>) { + size_t count = kj::max(streamBuffer.size(), streams.size()); + for (auto i: kj::zeroTo(count)) { + streamBuffer[i] = kj::mv(streams[i]); + } + capBuffer = streamBuffer.slice(count, streamBuffer.size()); + readSoFar.capCount += count; + } + } + } + + KJ_SWITCH_ONEOF(writeImpl(data, moreData)) { + KJ_CASE_ONEOF(done, Done) { + return READY_NOW; + } + KJ_CASE_ONEOF(retry, Retry) { + // Any leftover fds in `fds` are dropped on the floor, per contract. + // TODO(cleanup): We use another writeWithStreams() call here only because it accepts + // `data` and `moreData` directly. After the stream API refactor, we should be able to + // avoid this. + return pipe.writeWithStreams(retry.data, retry.moreData, nullptr); + } + } + KJ_UNREACHABLE; + } + + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + // Note: Pumps drop all capabilities. + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + KJ_ASSERT(minBytes > readSoFar.byteCount); + auto minToRead = kj::min(amount, minBytes - readSoFar.byteCount); + auto maxToRead = kj::min(amount, readBuffer.size()); + + return canceler.wrap(input.tryRead(readBuffer.begin(), minToRead, maxToRead) + .then([this,&input,amount](size_t actual) -> Promise { + readBuffer = readBuffer.slice(actual, readBuffer.size()); + readSoFar.byteCount += actual; + + if (readSoFar.byteCount >= minBytes) { + // We've read enough to close out this read (readSoFar >= minBytes). + canceler.release(); + fulfiller.fulfill(kj::cp(readSoFar)); + pipe.endState(*this); + + if (actual < amount) { + // We didn't read as much data as the pump requested, but we did fulfill the read, so + // we don't know whether we reached EOF on the input. We need to continue the pump, + // replacing the BlockedRead state. + return input.pumpTo(pipe, amount - actual) + .then([actual](uint64_t actual2) -> uint64_t { return actual + actual2; }); + } else { + // We pumped as much data as was requested, so we can return that now. + return actual; + } + } else { + // The pump completed without fulfilling the read. This either means that the pump + // reached EOF or the `amount` requested was not enough to satisfy the read in the first + // place. Pumps do not propagate EOF, so either way we want to leave the BlockedRead in + // place waiting for more data. + return actual; + } + }, teeExceptionPromise(fulfiller))); + } + + void shutdownWrite() override { + canceler.cancel("shutdownWrite() was called"); + fulfiller.fulfill(kj::cp(readSoFar)); + pipe.endState(*this); + pipe.shutdownWrite(); + } + + Promise whenWriteDisconnected() override { + KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); + } + + private: + PromiseFulfiller& fulfiller; + AsyncPipe& pipe; + ArrayPtr readBuffer; + size_t minBytes; + kj::OneOf, ArrayPtr>> capBuffer; + ReadResult readSoFar = {0, 0}; + Canceler canceler; + + struct Done {}; + struct Retry { ArrayPtr data; ArrayPtr> moreData; }; + + OneOf writeImpl(ArrayPtr data, + ArrayPtr> moreData) { + for (;;) { + if (data.size() < readBuffer.size()) { + // First write segment consumes a portion of the read buffer but not all of it. + auto n = data.size(); + memcpy(readBuffer.begin(), data.begin(), n); + readSoFar.byteCount += n; + readBuffer = readBuffer.slice(n, readBuffer.size()); + if (moreData.size() == 0) { + // Consumed all written pieces. + if (readSoFar.byteCount >= minBytes) { + // We've read enough to close out this read. + fulfiller.fulfill(kj::cp(readSoFar)); + pipe.endState(*this); + } + return Done(); + } + data = moreData[0]; + moreData = moreData.slice(1, moreData.size()); + // loop + } else { + // First write segment consumes entire read buffer. + auto n = readBuffer.size(); + readSoFar.byteCount += n; + fulfiller.fulfill(kj::cp(readSoFar)); + pipe.endState(*this); + memcpy(readBuffer.begin(), data.begin(), n); + + data = data.slice(n, data.size()); + if (data.size() == 0 && moreData.size() == 0) { + return Done(); + } else { + // Note: Even if `data` is empty, we don't replace it with moreData[0], because the + // retry might need to use write(ArrayPtr>) which doesn't allow + // passing a separate first segment. + return Retry { data, moreData }; + } + } + } + } + }; + + class BlockedPumpTo final: public AsyncCapabilityStream { + // AsyncPipe state when a pumpTo() is currently waiting for a corresponding write(). + + public: + BlockedPumpTo(PromiseFulfiller& fulfiller, AsyncPipe& pipe, + AsyncOutputStream& output, uint64_t amount) + : fulfiller(fulfiller), pipe(pipe), output(output), amount(amount) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + + ~BlockedPumpTo() noexcept(false) { + pipe.endState(*this); + } + + Promise tryRead(void* readBuffer, size_t minBytes, size_t maxBytes) override { + KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes"); + } + Promise tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes"); + } + Promise tryReadWithStreams( + void* readBuffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes"); + } + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes"); + } + + void abortRead() override { + canceler.cancel("abortRead() was called"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); + pipe.endState(*this); + pipe.abortRead(); + } + + Promise write(const void* writeBuffer, size_t size) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + auto actual = kj::min(amount - pumpedSoFar, size); + return canceler.wrap(output.write(writeBuffer, actual) + .then([this,size,actual,writeBuffer]() -> kj::Promise { + canceler.release(); + pumpedSoFar += actual; + + KJ_ASSERT(pumpedSoFar <= amount); + KJ_ASSERT(actual <= size); + + if (pumpedSoFar == amount) { + // Done with pump. + fulfiller.fulfill(kj::cp(pumpedSoFar)); + pipe.endState(*this); + } + + if (actual == size) { + return kj::READY_NOW; + } else { + KJ_ASSERT(pumpedSoFar == amount); + return pipe.write(reinterpret_cast(writeBuffer) + actual, size - actual); + } + }, teeExceptionPromise(fulfiller))); + } + + Promise write(ArrayPtr> pieces) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + size_t size = 0; + size_t needed = amount - pumpedSoFar; + for (auto i: kj::indices(pieces)) { + if (pieces[i].size() > needed) { + // The pump ends in the middle of this write. + + auto promise = output.write(pieces.slice(0, i)); + + if (needed > 0) { + // The pump includes part of this piece, but not all. Unfortunately we need to split + // writes. + auto partial = pieces[i].slice(0, needed); + promise = promise.then([this,partial]() { + return output.write(partial.begin(), partial.size()); + }); + auto partial2 = pieces[i].slice(needed, pieces[i].size()); + promise = canceler.wrap(promise.then([this,partial2]() { + canceler.release(); + fulfiller.fulfill(kj::cp(amount)); + pipe.endState(*this); + return pipe.write(partial2.begin(), partial2.size()); + }, teeExceptionPromise(fulfiller))); + ++i; + } else { + // The pump ends exactly at the end of a piece, how nice. + promise = canceler.wrap(promise.then([this]() { + canceler.release(); + fulfiller.fulfill(kj::cp(amount)); + pipe.endState(*this); + }, teeExceptionVoid(fulfiller))); + } + + auto remainder = pieces.slice(i, pieces.size()); + if (remainder.size() > 0) { + auto& pipeRef = pipe; + promise = promise.then([&pipeRef,remainder]() { + return pipeRef.write(remainder); + }); + } + + return promise; + } else { + size += pieces[i].size(); + needed -= pieces[i].size(); + } + } + + // Turns out we can forward this whole write. + KJ_ASSERT(size <= amount - pumpedSoFar); + return canceler.wrap(output.write(pieces).then([this,size]() { + pumpedSoFar += size; + KJ_ASSERT(pumpedSoFar <= amount); + if (pumpedSoFar == amount) { + // Done pumping. + canceler.release(); + fulfiller.fulfill(kj::cp(amount)); + pipe.endState(*this); + } + }, teeExceptionVoid(fulfiller))); + } + + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + // Pumps drop all capabilities, so fall back to regular write(). + + // TODO(cleaunp): After stream API refactor, regular write() methods will take + // (data, moreData) and we can clean this up. + if (moreData.size() == 0) { + return write(data.begin(), data.size()); + } else { + auto pieces = kj::heapArrayBuilder>(moreData.size() + 1); + pieces.add(data); + pieces.addAll(moreData); + return write(pieces.finish()); + } + } + + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + // Pumps drop all capabilities, so fall back to regular write(). + + // TODO(cleaunp): After stream API refactor, regular write() methods will take + // (data, moreData) and we can clean this up. + if (moreData.size() == 0) { + return write(data.begin(), data.size()); + } else { + auto pieces = kj::heapArrayBuilder>(moreData.size() + 1); + pieces.add(data); + pieces.addAll(moreData); + return write(pieces.finish()); + } + } + + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount2) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + + auto n = kj::min(amount2, amount - pumpedSoFar); + return output.tryPumpFrom(input, n) + .map([&](Promise subPump) { + return canceler.wrap(subPump + .then([this,&input,amount2,n](uint64_t actual) -> Promise { + canceler.release(); + pumpedSoFar += actual; + KJ_ASSERT(pumpedSoFar <= amount); + if (pumpedSoFar == amount) { + fulfiller.fulfill(kj::cp(amount)); + pipe.endState(*this); + } + + KJ_ASSERT(actual <= amount2); + if (actual == amount2) { + // Completed entire tryPumpFrom amount. + return amount2; + } else if (actual < n) { + // Received less than requested, presumably because EOF. + return actual; + } else { + // We received all the bytes that were requested but it didn't complete the pump. + KJ_ASSERT(pumpedSoFar == amount); + return input.pumpTo(pipe, amount2 - actual); + } + }, teeExceptionPromise(fulfiller))); + }); + } + + void shutdownWrite() override { + canceler.cancel("shutdownWrite() was called"); + fulfiller.fulfill(kj::cp(pumpedSoFar)); + pipe.endState(*this); + pipe.shutdownWrite(); + } + + Promise whenWriteDisconnected() override { + KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); + } + + private: + PromiseFulfiller& fulfiller; + AsyncPipe& pipe; + AsyncOutputStream& output; + uint64_t amount; + size_t pumpedSoFar = 0; + Canceler canceler; + }; + + class AbortedRead final: public AsyncCapabilityStream { + // AsyncPipe state when abortRead() has been called. + + public: + Promise tryRead(void* readBufferPtr, size_t minBytes, size_t maxBytes) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Promise tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Promise tryReadWithStreams( + void* readBuffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + void abortRead() override { + // ignore repeated abort + } + + Promise write(const void* buffer, size_t size) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Promise write(ArrayPtr> pieces) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); + } + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + // There might not actually be any data in `input`, in which case a pump wouldn't actually + // write anything and wouldn't fail. + + if (input.tryGetLength().orDefault(1) == 0) { + // Yeah a pump would pump nothing. + return constPromise(); + } else { + // While we *could* just return nullptr here, it would probably then fall back to a normal + // buffered pump, which would allocate a big old buffer just to find there's nothing to + // read. Let's try reading 1 byte to avoid that allocation. + static char c; + return input.tryRead(&c, 1, 1).then([](size_t n) { + if (n == 0) { + // Yay, we're at EOF as hoped. + return uint64_t(0); + } else { + // There was data in the input. The pump would have thrown. + kj::throwRecoverableException( + KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called")); + return uint64_t(0); + } + }); + } + } + void shutdownWrite() override { + // ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped, + // which is not an error even if reads have been aborted. + } + Promise whenWriteDisconnected() override { + KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); + } + }; + + class ShutdownedWrite final: public AsyncCapabilityStream { + // AsyncPipe state when shutdownWrite() has been called. + + public: + Promise tryRead(void* readBufferPtr, size_t minBytes, size_t maxBytes) override { + return constPromise(); + } + Promise tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + return ReadResult { 0, 0 }; + } + Promise tryReadWithStreams( + void* readBuffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + return ReadResult { 0, 0 }; + } + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return constPromise(); + } + void abortRead() override { + // ignore + } + + Promise write(const void* buffer, size_t size) override { + KJ_FAIL_REQUIRE("shutdownWrite() has been called"); + } + Promise write(ArrayPtr> pieces) override { + KJ_FAIL_REQUIRE("shutdownWrite() has been called"); + } + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + KJ_FAIL_REQUIRE("shutdownWrite() has been called"); + } + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + KJ_FAIL_REQUIRE("shutdownWrite() has been called"); + } + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + KJ_FAIL_REQUIRE("shutdownWrite() has been called"); + } + void shutdownWrite() override { + // ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped, + // so it will only be called once anyhow. + } + Promise whenWriteDisconnected() override { + KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); + } + }; +}; + +class PipeReadEnd final: public AsyncInputStream { +public: + PipeReadEnd(kj::Own pipe): pipe(kj::mv(pipe)) {} + ~PipeReadEnd() noexcept(false) { + unwind.catchExceptionsIfUnwinding([&]() { + pipe->abortRead(); + }); + } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return pipe->tryRead(buffer, minBytes, maxBytes); + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return pipe->pumpTo(output, amount); + } + +private: + Own pipe; + UnwindDetector unwind; +}; + +class PipeWriteEnd final: public AsyncOutputStream { +public: + PipeWriteEnd(kj::Own pipe): pipe(kj::mv(pipe)) {} + ~PipeWriteEnd() noexcept(false) { + unwind.catchExceptionsIfUnwinding([&]() { + pipe->shutdownWrite(); + }); + } + + Promise write(const void* buffer, size_t size) override { + return pipe->write(buffer, size); + } + + Promise write(ArrayPtr> pieces) override { + return pipe->write(pieces); + } + + Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount) override { + return pipe->tryPumpFrom(input, amount); + } + + Promise whenWriteDisconnected() override { + return pipe->whenWriteDisconnected(); + } + +private: + Own pipe; + UnwindDetector unwind; +}; + +class TwoWayPipeEnd final: public AsyncCapabilityStream { +public: + TwoWayPipeEnd(kj::Own in, kj::Own out) + : in(kj::mv(in)), out(kj::mv(out)) {} + ~TwoWayPipeEnd() noexcept(false) { + unwind.catchExceptionsIfUnwinding([&]() { + out->shutdownWrite(); + in->abortRead(); + }); + } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return in->tryRead(buffer, minBytes, maxBytes); + } + Promise tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) override { + return in->tryReadWithFds(buffer, minBytes, maxBytes, fdBuffer, maxFds); + } + Promise tryReadWithStreams( + void* buffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) override { + return in->tryReadWithStreams(buffer, minBytes, maxBytes, streamBuffer, maxStreams); + } + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return in->pumpTo(output, amount); + } + void abortRead() override { + in->abortRead(); + } + + Promise write(const void* buffer, size_t size) override { + return out->write(buffer, size); + } + Promise write(ArrayPtr> pieces) override { + return out->write(pieces); + } + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) override { + return out->writeWithFds(data, moreData, fds); + } + Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) override { + return out->writeWithStreams(data, moreData, kj::mv(streams)); + } + Maybe> tryPumpFrom( + AsyncInputStream& input, uint64_t amount) override { + return out->tryPumpFrom(input, amount); + } + Promise whenWriteDisconnected() override { + return out->whenWriteDisconnected(); + } + void shutdownWrite() override { + out->shutdownWrite(); + } + +private: + kj::Own in; + kj::Own out; + UnwindDetector unwind; +}; + +class LimitedInputStream final: public AsyncInputStream { +public: + LimitedInputStream(kj::Own inner, uint64_t limit) + : inner(kj::mv(inner)), limit(limit) { + if (limit == 0) { + this->inner = nullptr; + } + } + + Maybe tryGetLength() override { + return limit; + } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + if (limit == 0) return constPromise(); + return inner->tryRead(buffer, kj::min(minBytes, limit), kj::min(maxBytes, limit)) + .then([this,minBytes](size_t actual) { + decreaseLimit(actual, minBytes); + return actual; + }); + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + if (limit == 0) return constPromise(); + auto requested = kj::min(amount, limit); + return inner->pumpTo(output, requested) + .then([this,requested](uint64_t actual) { + decreaseLimit(actual, requested); + return actual; + }); + } + +private: + Own inner; + uint64_t limit; + + void decreaseLimit(uint64_t amount, uint64_t requested) { + KJ_ASSERT(limit >= amount); + limit -= amount; + if (limit == 0) { + inner = nullptr; + } else if (amount < requested) { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, + "fixed-length pipe ended prematurely")); + } + } +}; + +} // namespace + +OneWayPipe newOneWayPipe(kj::Maybe expectedLength) { + auto impl = kj::refcounted(); + Own readEnd = kj::heap(kj::addRef(*impl)); + KJ_IF_MAYBE(l, expectedLength) { + readEnd = kj::heap(kj::mv(readEnd), *l); + } + Own writeEnd = kj::heap(kj::mv(impl)); + return { kj::mv(readEnd), kj::mv(writeEnd) }; +} + +TwoWayPipe newTwoWayPipe() { + auto pipe1 = kj::refcounted(); + auto pipe2 = kj::refcounted(); + auto end1 = kj::heap(kj::addRef(*pipe1), kj::addRef(*pipe2)); + auto end2 = kj::heap(kj::mv(pipe2), kj::mv(pipe1)); + return { { kj::mv(end1), kj::mv(end2) } }; +} + +CapabilityPipe newCapabilityPipe() { + auto pipe1 = kj::refcounted(); + auto pipe2 = kj::refcounted(); + auto end1 = kj::heap(kj::addRef(*pipe1), kj::addRef(*pipe2)); + auto end2 = kj::heap(kj::mv(pipe2), kj::mv(pipe1)); + return { { kj::mv(end1), kj::mv(end2) } }; +} + +namespace { + +class AsyncTee final: public Refcounted { + class Buffer { + public: + Buffer() = default; + + uint64_t consume(ArrayPtr& readBuffer, size_t& minBytes); + // Consume as many bytes as possible, copying them into `readBuffer`. Return the number of bytes + // consumed. + // + // `readBuffer` and `minBytes` are both assigned appropriate new values, such that after any + // call to `consume()`, `readBuffer` will point to the remaining slice of unwritten space, and + // `minBytes` will have been decremented (clamped to zero) by the amount of bytes read. That is, + // the read can be considered fulfilled if `minBytes` is zero after a call to `consume()`. + + Array> asArray(uint64_t minBytes, uint64_t& amount); + // Consume the first `minBytes` of the buffer (or the entire buffer) and return it in an Array + // of ArrayPtrs, suitable for passing to AsyncOutputStream.write(). The outer Array + // owns the underlying data. + + void produce(Array bytes); + // Enqueue a byte array to the end of the buffer list. + + bool empty() const; + uint64_t size() const; + + Buffer clone() const { + size_t size = 0; + for (const auto& buf: bufferList) { + size += buf.size(); + } + auto builder = heapArrayBuilder(size); + for (const auto& buf: bufferList) { + builder.addAll(buf); + } + std::deque> deque; + deque.emplace_back(builder.finish()); + return Buffer{mv(deque)}; + } + + private: + Buffer(std::deque>&& buffer) : bufferList(mv(buffer)) {} + + std::deque> bufferList; + }; + + class Sink; + +public: + class Branch final: public AsyncInputStream { + public: + Branch(Own teeArg): tee(mv(teeArg)) { + tee->branches.add(*this); + } + + Branch(Own teeArg, Branch& cloneFrom) + : tee(mv(teeArg)), buffer(cloneFrom.buffer.clone()) { + tee->branches.add(*this); + } + + ~Branch() noexcept(false) { + KJ_ASSERT(link.isLinked()) { + // Don't std::terminate(). + return; + } + tee->branches.remove(*this); + + KJ_REQUIRE(sink == nullptr, + "destroying tee branch with operation still in-progress; probably going to segfault") { + // Don't std::terminate(). + break; + } + } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return tee->tryRead(*this, buffer, minBytes, maxBytes); + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return tee->pumpTo(*this, output, amount); + } + + Maybe tryGetLength() override { + return tee->tryGetLength(*this); + } + + Maybe> tryTee(uint64_t limit) override { + if (tee->getBufferSizeLimit() != limit) { + // Cannot optimize this path as the limit has changed, so we need a new AsyncTee to manage + // the limit. + return nullptr; + } + + return kj::heap(addRef(*tee), *this); + } + + private: + Own tee; + ListLink link; + + Buffer buffer; + Maybe sink; + + friend class AsyncTee; + }; + + explicit AsyncTee(Own inner, uint64_t bufferSizeLimit) + : inner(mv(inner)), bufferSizeLimit(bufferSizeLimit), length(this->inner->tryGetLength()) {} + ~AsyncTee() noexcept(false) { + KJ_ASSERT(branches.size() == 0, "destroying AsyncTee with branch still alive") { + // Don't std::terminate(). + break; + } + } + + Promise tryRead(Branch& branch, void* buffer, size_t minBytes, size_t maxBytes) { + KJ_ASSERT(branch.sink == nullptr); + + // If there is excess data in the buffer for us, slurp that up. + auto readBuffer = arrayPtr(reinterpret_cast(buffer), maxBytes); + auto readSoFar = branch.buffer.consume(readBuffer, minBytes); + + if (minBytes == 0) { + return readSoFar; + } + + if (branch.buffer.empty()) { + KJ_IF_MAYBE(reason, stoppage) { + // Prefer a short read to an exception. The exception prevents the pull loop from adding any + // data to the buffer, so `readSoFar` will be zero the next time someone calls `tryRead()`, + // and the caller will see the exception. + if (reason->is() || readSoFar > 0) { + return readSoFar; + } + return cp(reason->get()); + } + } + + auto promise = newAdaptedPromise( + branch.sink, readBuffer, minBytes, readSoFar); + ensurePulling(); + return mv(promise); + } + + Maybe tryGetLength(Branch& branch) { + return length.map([&branch](uint64_t amount) { + return amount + branch.buffer.size(); + }); + } + + uint64_t getBufferSizeLimit() const { + return bufferSizeLimit; + } + + Promise pumpTo(Branch& branch, AsyncOutputStream& output, uint64_t amount) { + KJ_ASSERT(branch.sink == nullptr); + + if (amount == 0) { + return amount; + } + + if (branch.buffer.empty()) { + KJ_IF_MAYBE(reason, stoppage) { + if (reason->is()) { + return constPromise(); + } + return cp(reason->get()); + } + } + + auto promise = newAdaptedPromise(branch.sink, output, amount); + ensurePulling(); + return mv(promise); + } + +private: + struct Eof {}; + using Stoppage = OneOf; + + class Sink { + public: + struct Need { + // We use uint64_t here because: + // - pumpTo() accepts it as the `amount` parameter. + // - all practical values of tryRead()'s `maxBytes` parameter (a size_t) should also fit into + // a uint64_t, unless we're on a machine with multiple exabytes of memory ... + + uint64_t minBytes = 0; + + uint64_t maxBytes = kj::maxValue; + }; + + virtual Promise fill(Buffer& inBuffer, const Maybe& stoppage) = 0; + // Attempt to fill the sink with bytes andreturn a promise which must resolve before any inner + // read may be attempted. If a sink requires backpressure to be respected, this is how it should + // be communicated. + // + // If the sink is full, it must detach from the tee before the returned promise is resolved. + // + // The returned promise must not result in an exception. + + virtual Need need() = 0; + + virtual void reject(Exception&& exception) = 0; + // Inform this sink of a catastrophic exception and detach it. Regular read exceptions should be + // propagated through `fill()`'s stoppage parameter instead. + }; + + template + class SinkBase: public Sink { + // Registers itself with the tee as a sink on construction, detaches from the tee on + // fulfillment, rejection, or destruction. + // + // A bit of a Frankenstein, avert your eyes. For one thing, it's more of a mixin than a base... + + public: + explicit SinkBase(PromiseFulfiller& fulfiller, Maybe& sinkLink) + : fulfiller(fulfiller), sinkLink(sinkLink) { + KJ_ASSERT(sinkLink == nullptr, "sink initiated with sink already in flight"); + sinkLink = *this; + } + KJ_DISALLOW_COPY_AND_MOVE(SinkBase); + ~SinkBase() noexcept(false) { detach(); } + + void reject(Exception&& exception) override { + // The tee is allowed to reject this sink if it needs to, e.g. to propagate a non-inner read + // exception from the pull loop. Only the derived class is allowed to fulfill() directly, + // though -- the tee must keep calling fill(). + + fulfiller.reject(mv(exception)); + detach(); + } + + protected: + template + void fulfill(U value) { + fulfiller.fulfill(fwd(value)); + detach(); + } + + private: + void detach() { + KJ_IF_MAYBE(sink, sinkLink) { + if (sink == this) { + sinkLink = nullptr; + } + } + } + + PromiseFulfiller& fulfiller; + Maybe& sinkLink; + }; + + class ReadSink final: public SinkBase { + public: + explicit ReadSink(PromiseFulfiller& fulfiller, Maybe& registration, + ArrayPtr buffer, size_t minBytes, size_t readSoFar) + : SinkBase(fulfiller, registration), buffer(buffer), + minBytes(minBytes), readSoFar(readSoFar) {} + + Promise fill(Buffer& inBuffer, const Maybe& stoppage) override { + auto amount = inBuffer.consume(buffer, minBytes); + readSoFar += amount; + + if (minBytes == 0) { + // We satisfied the read request. + fulfill(readSoFar); + return READY_NOW; + } + + if (amount == 0 && inBuffer.empty()) { + // We made no progress on the read request and the buffer is tapped out. + KJ_IF_MAYBE(reason, stoppage) { + if (reason->is() || readSoFar > 0) { + // Prefer short read to exception. + fulfill(readSoFar); + } else { + reject(cp(reason->get())); + } + return READY_NOW; + } + } + + return READY_NOW; + } + + Need need() override { return Need { minBytes, buffer.size() }; } + + private: + ArrayPtr buffer; + size_t minBytes; + // Arguments to the outer tryRead() call, sliced/decremented after every buffer consumption. + + size_t readSoFar; + // End result of the outer tryRead(). + }; + + class PumpSink final: public SinkBase { + public: + explicit PumpSink(PromiseFulfiller& fulfiller, Maybe& registration, + AsyncOutputStream& output, uint64_t limit) + : SinkBase(fulfiller, registration), output(output), limit(limit) {} + + ~PumpSink() noexcept(false) { + canceler.cancel("This pump has been canceled."); + } + + Promise fill(Buffer& inBuffer, const Maybe& stoppage) override { + KJ_ASSERT(limit > 0); + + uint64_t amount = 0; + + // TODO(someday): This consumes data from the buffer, but we cannot know if the stream to + // which we're pumping will accept it until after the write() promise completes. If the + // write() promise rejects, we lose this data. We should consume the data from the buffer + // only after successful writes. + auto writeBuffer = inBuffer.asArray(limit, amount); + KJ_ASSERT(limit >= amount); + if (amount > 0) { + Promise promise = kj::evalNow([&]() { + return output.write(writeBuffer).attach(mv(writeBuffer)); + }).then([this, amount]() { + limit -= amount; + pumpedSoFar += amount; + if (limit == 0) { + fulfill(pumpedSoFar); + } + }).eagerlyEvaluate([this](Exception&& exception) { + reject(mv(exception)); + }); + + return canceler.wrap(mv(promise)).catch_([](kj::Exception&&) {}); + } else KJ_IF_MAYBE(reason, stoppage) { + if (reason->is()) { + // Unlike in the read case, it makes more sense to immediately propagate exceptions to the + // pump promise rather than show it a "short pump". + fulfill(pumpedSoFar); + } else { + reject(cp(reason->get())); + } + } + + return READY_NOW; + } + + Need need() override { return Need { 1, limit }; } + + private: + AsyncOutputStream& output; + uint64_t limit; + // Arguments to the outer pumpTo() call, decremented after every buffer consumption. + // + // Equal to zero once fulfiller has been fulfilled/rejected. + + uint64_t pumpedSoFar = 0; + // End result of the outer pumpTo(). + + Canceler canceler; + // When the pump is canceled, we also need to cancel any write operations in flight. + }; + + // ===================================================================================== + + Maybe analyzeSinks() { + // Return nullptr if there are no sinks at all. Otherwise, return the largest `minBytes` and the + // smallest `maxBytes` requested by any sink. The pull loop will use these values to calculate + // the optimal buffer size for the next inner read, so that a minimum amount of data is buffered + // at any given time. + + uint64_t minBytes = 0; + uint64_t maxBytes = kj::maxValue; + + uint nSinks = 0; + + for (auto& branch: branches) { + KJ_IF_MAYBE(sink, branch.sink) { + ++nSinks; + auto need = sink->need(); + minBytes = kj::max(minBytes, need.minBytes); + maxBytes = kj::min(maxBytes, need.maxBytes); + } + } + + if (nSinks > 0) { + KJ_ASSERT(minBytes > 0); + KJ_ASSERT(maxBytes > 0, "sink was filled but did not detach"); + + // Sinks may report non-overlapping needs. + maxBytes = kj::max(minBytes, maxBytes); + + return Sink::Need { minBytes, maxBytes }; + } + + // No active sinks. + return nullptr; + } + + void ensurePulling() { + if (!pulling) { + pulling = true; + UnwindDetector unwind; + KJ_DEFER(if (unwind.isUnwinding()) pulling = false); + pullPromise = pull(); + } + } + + Promise pull() { + return pullLoop().eagerlyEvaluate([this](Exception&& exception) { + // Exception from our loop, not from inner tryRead(). Something is broken; tell everybody! + pulling = false; + for (auto& branch: branches) { + KJ_IF_MAYBE(sink, branch.sink) { + sink->reject(KJ_EXCEPTION(FAILED, "Exception in tee loop", exception)); + } + } + }); + } + + constexpr static size_t MAX_BLOCK_SIZE = 1 << 14; // 16k + + Own inner; + const uint64_t bufferSizeLimit = kj::maxValue; + Maybe length; + List branches; + Maybe stoppage; + Promise pullPromise = READY_NOW; + bool pulling = false; + +private: + Promise pullLoop() { + // Use evalLater() so that two pump sinks added on the same turn of the event loop will not + // cause buffering. + return evalLater([this] { + // Attempt to fill any sinks that exist. + + Vector> promises; + + for (auto& branch: branches) { + KJ_IF_MAYBE(sink, branch.sink) { + promises.add(sink->fill(branch.buffer, stoppage)); + } + } + + // Respect the greatest of the sinks' backpressures. + return joinPromises(promises.releaseAsArray()); + }).then([this]() -> Promise { + // Check to see whether we need to perform an inner read. + + auto need = analyzeSinks(); + + if (need == nullptr) { + // No more sinks, stop pulling. + pulling = false; + return READY_NOW; + } + + if (stoppage != nullptr) { + // We're eof or errored, don't read, but loop so we can fill the sink(s). + return pullLoop(); + } + + auto& n = KJ_ASSERT_NONNULL(need); + + KJ_ASSERT(n.minBytes > 0); + + // We must perform an inner read. + + // We'd prefer not to explode our buffer, if that's cool. We cap `maxBytes` to the buffer size + // limit or our builtin MAX_BLOCK_SIZE, whichever is smaller. But, we make sure `maxBytes` is + // still >= `minBytes`. + n.maxBytes = kj::min(n.maxBytes, MAX_BLOCK_SIZE); + n.maxBytes = kj::min(n.maxBytes, bufferSizeLimit); + n.maxBytes = kj::max(n.minBytes, n.maxBytes); + for (auto& branch: branches) { + // TODO(perf): buffer.size() is O(n) where n = # of individual heap-allocated byte arrays. + if (branch.buffer.size() + n.maxBytes > bufferSizeLimit) { + stoppage = Stoppage(KJ_EXCEPTION(FAILED, "tee buffer size limit exceeded")); + return pullLoop(); + } + } + auto heapBuffer = heapArray(n.maxBytes); + + // gcc 4.9 quirk: If I don't hoist this into a separate variable and instead call + // + // inner->tryRead(heapBuffer.begin(), n.minBytes, heapBuffer.size()) + // + // `heapBuffer` seems to get moved into the lambda capture before the arguments to `tryRead()` + // are evaluated, meaning `inner` sees a nullptr destination. Bizarrely, `inner` sees the + // correct value for `heapBuffer.size()`... I dunno, man. + auto destination = heapBuffer.begin(); + + return kj::evalNow([&]() { return inner->tryRead(destination, n.minBytes, n.maxBytes); }) + .then([this, heapBuffer = mv(heapBuffer), minBytes = n.minBytes](size_t amount) mutable + -> Promise { + length = length.map([amount](uint64_t n) { + KJ_ASSERT(n >= amount); + return n - amount; + }); + + if (amount < heapBuffer.size()) { + heapBuffer = heapBuffer.slice(0, amount).attach(mv(heapBuffer)); + } + + KJ_ASSERT(stoppage == nullptr); + Maybe> bufferPtr = nullptr; + for (auto& branch: branches) { + // Prefer to move the buffer into the receiving branch's deque, rather than memcpy. + // + // TODO(perf): For the 2-branch case, this is fine, since the majority of the time + // only one buffer will be in use. If we generalize to the n-branch case, this would + // become memcpy-heavy. + KJ_IF_MAYBE(ptr, bufferPtr) { + branch.buffer.produce(heapArray(*ptr)); + } else { + bufferPtr = ArrayPtr(heapBuffer); + branch.buffer.produce(mv(heapBuffer)); + } + } + + if (amount < minBytes) { + // Short read, EOF. + stoppage = Stoppage(Eof()); + } + + return pullLoop(); + }, [this](Exception&& exception) { + // Exception from the inner tryRead(). Propagate. + stoppage = Stoppage(mv(exception)); + return pullLoop(); + }); + }); + } +}; + +constexpr size_t AsyncTee::MAX_BLOCK_SIZE; + +uint64_t AsyncTee::Buffer::consume(ArrayPtr& readBuffer, size_t& minBytes) { + uint64_t totalAmount = 0; + + while (readBuffer.size() > 0 && !bufferList.empty()) { + auto& bytes = bufferList.front(); + auto amount = kj::min(bytes.size(), readBuffer.size()); + memcpy(readBuffer.begin(), bytes.begin(), amount); + totalAmount += amount; + + readBuffer = readBuffer.slice(amount, readBuffer.size()); + minBytes -= kj::min(amount, minBytes); + + if (amount == bytes.size()) { + bufferList.pop_front(); + } else { + bytes = heapArray(bytes.slice(amount, bytes.size())); + return totalAmount; + } + } + + return totalAmount; +} + +void AsyncTee::Buffer::produce(Array bytes) { + bufferList.push_back(mv(bytes)); +} + +Array> AsyncTee::Buffer::asArray( + uint64_t maxBytes, uint64_t& amount) { + amount = 0; + + Vector> buffers; + Vector> ownBuffers; + + while (maxBytes > 0 && !bufferList.empty()) { + auto& bytes = bufferList.front(); + + if (bytes.size() <= maxBytes) { + amount += bytes.size(); + maxBytes -= bytes.size(); + + buffers.add(bytes); + ownBuffers.add(mv(bytes)); + + bufferList.pop_front(); + } else { + auto ownBytes = heapArray(bytes.slice(0, maxBytes)); + buffers.add(ownBytes); + ownBuffers.add(mv(ownBytes)); + + bytes = heapArray(bytes.slice(maxBytes, bytes.size())); + + amount += maxBytes; + maxBytes = 0; + } + } + + + if (buffers.size() > 0) { + return buffers.releaseAsArray().attach(mv(ownBuffers)); + } + + return {}; +} + +bool AsyncTee::Buffer::empty() const { + return bufferList.empty(); +} + +uint64_t AsyncTee::Buffer::size() const { + uint64_t result = 0; + + for (auto& bytes: bufferList) { + result += bytes.size(); + } + + return result; +} + +} // namespace + +Tee newTee(Own input, uint64_t limit) { + KJ_IF_MAYBE(t, input->tryTee(limit)) { + return { { mv(input), mv(*t) }}; + } + + auto impl = refcounted(mv(input), limit); + Own branch1 = heap(addRef(*impl)); + Own branch2 = heap(mv(impl)); + return { { mv(branch1), mv(branch2) } }; +} + +namespace { + +class PromisedAsyncIoStream final: public kj::AsyncIoStream, private kj::TaskSet::ErrorHandler { + // An AsyncIoStream which waits for a promise to resolve then forwards all calls to the promised + // stream. + +public: + PromisedAsyncIoStream(kj::Promise> promise) + : promise(promise.then([this](kj::Own result) { + stream = kj::mv(result); + }).fork()), + tasks(*this) {} + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->read(buffer, minBytes, maxBytes); + } else { + return promise.addBranch().then([this,buffer,minBytes,maxBytes]() { + return KJ_ASSERT_NONNULL(stream)->read(buffer, minBytes, maxBytes); + }); + } + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->tryRead(buffer, minBytes, maxBytes); + } else { + return promise.addBranch().then([this,buffer,minBytes,maxBytes]() { + return KJ_ASSERT_NONNULL(stream)->tryRead(buffer, minBytes, maxBytes); + }); + } + } + + kj::Maybe tryGetLength() override { + KJ_IF_MAYBE(s, stream) { + return s->get()->tryGetLength(); + } else { + return nullptr; + } + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->pumpTo(output, amount); + } else { + return promise.addBranch().then([this,&output,amount]() { + return KJ_ASSERT_NONNULL(stream)->pumpTo(output, amount); + }); + } + } + + kj::Promise write(const void* buffer, size_t size) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->write(buffer, size); + } else { + return promise.addBranch().then([this,buffer,size]() { + return KJ_ASSERT_NONNULL(stream)->write(buffer, size); + }); + } + } + kj::Promise write(kj::ArrayPtr> pieces) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->write(pieces); + } else { + return promise.addBranch().then([this,pieces]() { + return KJ_ASSERT_NONNULL(stream)->write(pieces); + }); + } + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + KJ_IF_MAYBE(s, stream) { + // Call input.pumpTo() on the resolved stream instead, so that if it does some dynamic_casts + // or whatnot to detect stream types it can retry those on the inner stream. + return input.pumpTo(**s, amount); + } else { + return promise.addBranch().then([this,&input,amount]() { + // Here we actually have no choice but to call input.pumpTo() because if we called + // tryPumpFrom(input, amount) and it returned nullptr, what would we do? It's too late for + // us to return nullptr. But the thing about dynamic_cast also applies. + return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount); + }); + } + } + + Promise whenWriteDisconnected() override { + KJ_IF_MAYBE(s, stream) { + return s->get()->whenWriteDisconnected(); + } else { + return promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected(); + }, [](kj::Exception&& e) -> kj::Promise { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return kj::READY_NOW; + } else { + return kj::mv(e); + } + }); + } + } + + void shutdownWrite() override { + KJ_IF_MAYBE(s, stream) { + return s->get()->shutdownWrite(); + } else { + tasks.add(promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(stream)->shutdownWrite(); + })); + } + } + + void abortRead() override { + KJ_IF_MAYBE(s, stream) { + return s->get()->abortRead(); + } else { + tasks.add(promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(stream)->abortRead(); + })); + } + } + + kj::Maybe getFd() const override { + KJ_IF_MAYBE(s, stream) { + return s->get()->getFd(); + } else { + return nullptr; + } + } + +private: + kj::ForkedPromise promise; + kj::Maybe> stream; + kj::TaskSet tasks; + + void taskFailed(kj::Exception&& exception) override { + KJ_LOG(ERROR, exception); + } +}; + +class PromisedAsyncOutputStream final: public kj::AsyncOutputStream { + // An AsyncOutputStream which waits for a promise to resolve then forwards all calls to the + // promised stream. + // + // TODO(cleanup): Can this share implementation with PromiseIoStream? Seems hard. + +public: + PromisedAsyncOutputStream(kj::Promise> promise) + : promise(promise.then([this](kj::Own result) { + stream = kj::mv(result); + }).fork()) {} + + kj::Promise write(const void* buffer, size_t size) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->write(buffer, size); + } else { + return promise.addBranch().then([this,buffer,size]() { + return KJ_ASSERT_NONNULL(stream)->write(buffer, size); + }); + } + } + kj::Promise write(kj::ArrayPtr> pieces) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->write(pieces); + } else { + return promise.addBranch().then([this,pieces]() { + return KJ_ASSERT_NONNULL(stream)->write(pieces); + }); + } + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + KJ_IF_MAYBE(s, stream) { + return s->get()->tryPumpFrom(input, amount); + } else { + return promise.addBranch().then([this,&input,amount]() { + // Call input.pumpTo() on the resolved stream instead. + return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount); + }); + } + } + + Promise whenWriteDisconnected() override { + KJ_IF_MAYBE(s, stream) { + return s->get()->whenWriteDisconnected(); + } else { + return promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected(); + }, [](kj::Exception&& e) -> kj::Promise { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return kj::READY_NOW; + } else { + return kj::mv(e); + } + }); + } + } + +private: + kj::ForkedPromise promise; + kj::Maybe> stream; +}; + +} // namespace + +Own newPromisedStream(Promise> promise) { + return heap(kj::mv(promise)); +} +Own newPromisedStream(Promise> promise) { + return heap(kj::mv(promise)); +} + +Promise AsyncCapabilityStream::writeWithFds( + ArrayPtr data, ArrayPtr> moreData, + ArrayPtr fds) { + // HACK: AutoCloseFd actually contains an `int` under the hood. We can reinterpret_cast to avoid + // unnecessary memory allocation. + static_assert(sizeof(AutoCloseFd) == sizeof(int), "this optimization won't work"); + auto intArray = arrayPtr(reinterpret_cast(fds.begin()), fds.size()); + + // Be extra-paranoid about aliasing rules by injecting a compiler barrier here. Probably + // not necessary but also probably doesn't hurt. +#if _MSC_VER + _ReadWriteBarrier(); +#else + __asm__ __volatile__("": : :"memory"); +#endif + + return writeWithFds(data, moreData, intArray); +} + +Promise> AsyncCapabilityStream::receiveStream() { + return tryReceiveStream() + .then([](Maybe>&& result) + -> Promise> { + KJ_IF_MAYBE(r, result) { + return kj::mv(*r); + } else { + return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability"); + } + }); +} + +kj::Promise>> AsyncCapabilityStream::tryReceiveStream() { + struct ResultHolder { + byte b; + Own stream; + }; + auto result = kj::heap(); + auto promise = tryReadWithStreams(&result->b, 1, 1, &result->stream, 1); + return promise.then([result = kj::mv(result)](ReadResult actual) mutable + -> Maybe> { + if (actual.byteCount == 0) { + return nullptr; + } + + KJ_REQUIRE(actual.capCount == 1, + "expected to receive a capability (e.g. file descriptor via SCM_RIGHTS), but didn't") { + return nullptr; + } + + return kj::mv(result->stream); + }); +} + +Promise AsyncCapabilityStream::sendStream(Own stream) { + static constexpr byte b = 0; + auto streams = kj::heapArray>(1); + streams[0] = kj::mv(stream); + return writeWithStreams(arrayPtr(&b, 1), nullptr, kj::mv(streams)); +} + +Promise AsyncCapabilityStream::receiveFd() { + return tryReceiveFd().then([](Maybe&& result) -> Promise { + KJ_IF_MAYBE(r, result) { + return kj::mv(*r); + } else { + return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability"); + } + }); +} + +kj::Promise> AsyncCapabilityStream::tryReceiveFd() { + struct ResultHolder { + byte b; + AutoCloseFd fd; + }; + auto result = kj::heap(); + auto promise = tryReadWithFds(&result->b, 1, 1, &result->fd, 1); + return promise.then([result = kj::mv(result)](ReadResult actual) mutable + -> Maybe { + if (actual.byteCount == 0) { + return nullptr; + } + + KJ_REQUIRE(actual.capCount == 1, + "expected to receive a file descriptor (e.g. via SCM_RIGHTS), but didn't") { + return nullptr; + } + + return kj::mv(result->fd); + }); +} + +Promise AsyncCapabilityStream::sendFd(int fd) { + static constexpr byte b = 0; + auto fds = kj::heapArray(1); + fds[0] = fd; + auto promise = writeWithFds(arrayPtr(&b, 1), nullptr, fds); + return promise.attach(kj::mv(fds)); +} + +void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) { + KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; } +} +void AsyncIoStream::setsockopt(int level, int option, const void* value, uint length) { + KJ_UNIMPLEMENTED("Not a socket.") { break; } +} +void AsyncIoStream::getsockname(struct sockaddr* addr, uint* length) { + KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; } +} +void AsyncIoStream::getpeername(struct sockaddr* addr, uint* length) { + KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; } +} +void ConnectionReceiver::getsockopt(int level, int option, void* value, uint* length) { + KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; } +} +void ConnectionReceiver::setsockopt(int level, int option, const void* value, uint length) { + KJ_UNIMPLEMENTED("Not a socket.") { break; } +} +void ConnectionReceiver::getsockname(struct sockaddr* addr, uint* length) { + KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; } +} +void DatagramPort::getsockopt(int level, int option, void* value, uint* length) { + KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; } +} +void DatagramPort::setsockopt(int level, int option, const void* value, uint length) { + KJ_UNIMPLEMENTED("Not a socket.") { break; } +} +Own NetworkAddress::bindDatagramPort() { + KJ_UNIMPLEMENTED("Datagram sockets not implemented."); +} +Own LowLevelAsyncIoProvider::wrapDatagramSocketFd( + Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) { + KJ_UNIMPLEMENTED("Datagram sockets not implemented."); +} +#if !_WIN32 +Own LowLevelAsyncIoProvider::wrapUnixSocketFd(Fd fd, uint flags) { + KJ_UNIMPLEMENTED("Unix socket with FD passing not implemented."); +} +#endif +CapabilityPipe AsyncIoProvider::newCapabilityPipe() { + KJ_UNIMPLEMENTED("Capability pipes not implemented."); +} + +Own LowLevelAsyncIoProvider::wrapInputFd(OwnFd&& fd, uint flags) { + return wrapInputFd(reinterpret_cast(fd.release()), flags | TAKE_OWNERSHIP); +} +Own LowLevelAsyncIoProvider::wrapOutputFd(OwnFd&& fd, uint flags) { + return wrapOutputFd(reinterpret_cast(fd.release()), flags | TAKE_OWNERSHIP); +} +Own LowLevelAsyncIoProvider::wrapSocketFd(OwnFd&& fd, uint flags) { + return wrapSocketFd(reinterpret_cast(fd.release()), flags | TAKE_OWNERSHIP); +} +#if !_WIN32 +Own LowLevelAsyncIoProvider::wrapUnixSocketFd(OwnFd&& fd, uint flags) { + return wrapUnixSocketFd(reinterpret_cast(fd.release()), flags | TAKE_OWNERSHIP); +} +#endif +Promise> LowLevelAsyncIoProvider::wrapConnectingSocketFd( + OwnFd&& fd, const struct sockaddr* addr, uint addrlen, uint flags) { + return wrapConnectingSocketFd(reinterpret_cast(fd.release()), addr, addrlen, + flags | TAKE_OWNERSHIP); +} +Own LowLevelAsyncIoProvider::wrapListenSocketFd( + OwnFd&& fd, NetworkFilter& filter, uint flags) { + return wrapListenSocketFd(reinterpret_cast(fd.release()), filter, flags | TAKE_OWNERSHIP); +} +Own LowLevelAsyncIoProvider::wrapListenSocketFd(OwnFd&& fd, uint flags) { + return wrapListenSocketFd(reinterpret_cast(fd.release()), flags | TAKE_OWNERSHIP); +} +Own LowLevelAsyncIoProvider::wrapDatagramSocketFd( + OwnFd&& fd, NetworkFilter& filter, uint flags) { + return wrapDatagramSocketFd(reinterpret_cast(fd.release()), filter, flags | TAKE_OWNERSHIP); +} +Own LowLevelAsyncIoProvider::wrapDatagramSocketFd(OwnFd&& fd, uint flags) { + return wrapDatagramSocketFd(reinterpret_cast(fd.release()), flags | TAKE_OWNERSHIP); +} + +namespace { + +class DummyNetworkFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter { +public: + bool shouldAllow(const struct sockaddr* addr, uint addrlen) override { return true; } +}; + +} // namespace + +LowLevelAsyncIoProvider::NetworkFilter& LowLevelAsyncIoProvider::NetworkFilter::getAllAllowed() { + static DummyNetworkFilter result; + return result; +} + +// ======================================================================================= +// Convenience adapters. + +Promise> CapabilityStreamConnectionReceiver::accept() { + return inner.receiveStream() + .then([](Own&& stream) -> Own { + return kj::mv(stream); + }); +} + +Promise CapabilityStreamConnectionReceiver::acceptAuthenticated() { + return accept().then([](Own&& stream) { + return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; + }); +} + +uint CapabilityStreamConnectionReceiver::getPort() { + return 0; +} + +Promise> CapabilityStreamNetworkAddress::connect() { + CapabilityPipe pipe; + KJ_IF_MAYBE(p, provider) { + pipe = p->newCapabilityPipe(); + } else { + pipe = kj::newCapabilityPipe(); + } + auto result = kj::mv(pipe.ends[0]); + return inner.sendStream(kj::mv(pipe.ends[1])) + .then([result=kj::mv(result)]() mutable { + return Own(kj::mv(result)); + }); +} +Promise CapabilityStreamNetworkAddress::connectAuthenticated() { + return connect().then([](Own&& stream) { + return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; + }); +} +Own CapabilityStreamNetworkAddress::listen() { + return kj::heap(inner); +} + +Own CapabilityStreamNetworkAddress::clone() { + KJ_UNIMPLEMENTED("can't clone CapabilityStreamNetworkAddress"); +} +String CapabilityStreamNetworkAddress::toString() { + return kj::str(""); +} + +Promise FileInputStream::tryRead(void* buffer, size_t minBytes, size_t maxBytes) { + // Note that our contract with `minBytes` is that we should only return fewer than `minBytes` on + // EOF. A file read will only produce fewer than the requested number of bytes if EOF was reached. + // `minBytes` cannot be greater than `maxBytes`. So, this read satisfies the `minBytes` + // requirement. + size_t result = file.read(offset, arrayPtr(reinterpret_cast(buffer), maxBytes)); + offset += result; + return result; +} + +Maybe FileInputStream::tryGetLength() { + uint64_t size = file.stat().size; + return offset < size ? size - offset : 0; +} + +Promise FileOutputStream::write(const void* buffer, size_t size) { + file.write(offset, arrayPtr(reinterpret_cast(buffer), size)); + offset += size; + return kj::READY_NOW; +} + +Promise FileOutputStream::write(ArrayPtr> pieces) { + // TODO(perf): Extend kj::File with an array-of-arrays write? + for (auto piece: pieces) { + file.write(offset, piece); + offset += piece.size(); + } + return kj::READY_NOW; +} + +Promise FileOutputStream::whenWriteDisconnected() { + return kj::NEVER_DONE; +} + +// ======================================================================================= + +namespace { + +class AggregateConnectionReceiver final: public ConnectionReceiver { +public: + AggregateConnectionReceiver(Array> receiversParam) + : receivers(kj::mv(receiversParam)), + acceptTasks(kj::heapArray>>(receivers.size())) {} + + Promise> accept() override { + return acceptAuthenticated().then([](AuthenticatedStream&& authenticated) { + return kj::mv(authenticated.stream); + }); + } + + Promise acceptAuthenticated() override { + // Whenever our accept() is called, we want it to resolve to the first connection accepted by + // any of our child receivers. Naively, it may seem like we should call accept() on them all + // and exclusiveJoin() the results. Unfortunately, this might not work in a certain race + // condition: if two or more of our children receive connections simultaneously, both child + // accept() calls may return, but we'll only end up taking one and dropping the other. + // + // To avoid this problem, we must instead initiate `accept()` calls on all children, and even + // after one of them returns a result, we must allow the others to keep running. If we end up + // accepting any sockets from children when there is no outstanding accept() on the aggregate, + // we must put that socket into a backlog. We only restart accept() calls on children if the + // backlog is empty, and hence the maximum length of the backlog is the number of children + // minus 1. + + if (backlog.empty()) { + auto result = kj::newAdaptedPromise(*this); + ensureAllAccepting(); + return result; + } else { + auto result = kj::mv(backlog.front()); + backlog.pop_front(); + return result; + } + } + + uint getPort() override { + return receivers[0]->getPort(); + } + void getsockopt(int level, int option, void* value, uint* length) override { + return receivers[0]->getsockopt(level, option, value, length); + } + void setsockopt(int level, int option, const void* value, uint length) override { + // Apply to all. + for (auto& r: receivers) { + r->setsockopt(level, option, value, length); + } + } + void getsockname(struct sockaddr* addr, uint* length) override { + return receivers[0]->getsockname(addr, length); + } + +private: + Array> receivers; + Array>> acceptTasks; + + struct Waiter { + Waiter(PromiseFulfiller& fulfiller, + AggregateConnectionReceiver& parent) + : fulfiller(fulfiller), parent(parent) { + parent.waiters.add(*this); + } + ~Waiter() noexcept(false) { + if (link.isLinked()) { + parent.waiters.remove(*this); + } + } + + PromiseFulfiller& fulfiller; + AggregateConnectionReceiver& parent; + ListLink link; + }; + + List waiters; + std::deque> backlog; + // At least one of `waiters` or `backlog` is always empty. + + void ensureAllAccepting() { + for (auto i: kj::indices(receivers)) { + if (acceptTasks[i] == nullptr) { + acceptTasks[i] = acceptLoop(i); + } + } + } + + Promise acceptLoop(size_t index) { + return kj::evalNow([&]() { return receivers[index]->acceptAuthenticated(); }) + .then([this](AuthenticatedStream&& as) { + if (waiters.empty()) { + backlog.push_back(kj::mv(as)); + } else { + auto& waiter = waiters.front(); + waiter.fulfiller.fulfill(kj::mv(as)); + waiters.remove(waiter); + } + }, [this](Exception&& e) { + if (waiters.empty()) { + backlog.push_back(kj::mv(e)); + } else { + auto& waiter = waiters.front(); + waiter.fulfiller.reject(kj::mv(e)); + waiters.remove(waiter); + } + }).then([this, index]() -> Promise { + if (waiters.empty()) { + // Don't keep accepting if there's no one waiting. + // HACK: We can't cancel ourselves, so detach the task so we can null out the slot. + // We know that the promise we're detaching here is exactly the promise that's currently + // executing and has no further `.then()`s on it, so no further callbacks will run in + // detached state... we're just using `detach()` as a tricky way to have the event loop + // dispose of this promise later after we've returned. + // TODO(cleanup): This pattern has come up several times, we need a better way to handle + // it. + KJ_ASSERT_NONNULL(acceptTasks[index]).detach([](auto&&) {}); + acceptTasks[index] = nullptr; + return READY_NOW; + } else { + return acceptLoop(index); + } + }); + } +}; + +} // namespace + +Own newAggregateConnectionReceiver(Array> receivers) { + return kj::heap(kj::mv(receivers)); +} + +// ----------------------------------------------------------------------------- + +namespace _ { // private + +#if !_WIN32 + +kj::ArrayPtr safeUnixPath(const struct sockaddr_un* addr, uint addrlen) { + KJ_REQUIRE(addr->sun_family == AF_UNIX, "not a unix address"); + KJ_REQUIRE(addrlen >= offsetof(sockaddr_un, sun_path), "invalid unix address"); + + size_t maxPathlen = addrlen - offsetof(sockaddr_un, sun_path); + + size_t pathlen; + if (maxPathlen > 0 && addr->sun_path[0] == '\0') { + // Linux "abstract" unix address + pathlen = strnlen(addr->sun_path + 1, maxPathlen - 1) + 1; + } else { + pathlen = strnlen(addr->sun_path, maxPathlen); + } + return kj::arrayPtr(addr->sun_path, pathlen); +} + +#endif // !_WIN32 + +ArrayPtr localCidrs() { + static const CidrRange result[] = { + // localhost + "127.0.0.0/8"_kj, + "::1/128"_kj, + + // Trying to *connect* to 0.0.0.0 on many systems is equivalent to connecting to localhost. + // (wat) + "0.0.0.0/32"_kj, + "::/128"_kj, + }; + + // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly + // casting to our return type. + return kj::arrayPtr(result, kj::size(result)); +} + +ArrayPtr privateCidrs() { + static const CidrRange result[] = { + "10.0.0.0/8"_kj, // RFC1918 reserved for internal network + "100.64.0.0/10"_kj, // RFC6598 "shared address space" for carrier-grade NAT + "169.254.0.0/16"_kj, // RFC3927 "link local" (auto-configured LAN in absence of DHCP) + "172.16.0.0/12"_kj, // RFC1918 reserved for internal network + "192.168.0.0/16"_kj, // RFC1918 reserved for internal network + + "fc00::/7"_kj, // RFC4193 unique private network + "fe80::/10"_kj, // RFC4291 "link local" (auto-configured LAN in absence of DHCP) + }; + + // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly + // casting to our return type. + return kj::arrayPtr(result, kj::size(result)); +} + +ArrayPtr reservedCidrs() { + static const CidrRange result[] = { + "192.0.0.0/24"_kj, // RFC6890 reserved for special protocols + "224.0.0.0/4"_kj, // RFC1112 multicast + "240.0.0.0/4"_kj, // RFC1112 multicast / reserved for future use + "255.255.255.255/32"_kj, // RFC0919 broadcast address + + "2001::/23"_kj, // RFC2928 reserved for special protocols + "ff00::/8"_kj, // RFC4291 multicast + }; + + // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly + // casting to our return type. + return kj::arrayPtr(result, kj::size(result)); +} + +ArrayPtr exampleAddresses() { + static const CidrRange result[] = { + "192.0.2.0/24"_kj, // RFC5737 "example address" block 1 -- like example.com for IPs + "198.51.100.0/24"_kj, // RFC5737 "example address" block 2 -- like example.com for IPs + "203.0.113.0/24"_kj, // RFC5737 "example address" block 3 -- like example.com for IPs + "2001:db8::/32"_kj, // RFC3849 "example address" block -- like example.com for IPs + }; + + // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly + // casting to our return type. + return kj::arrayPtr(result, kj::size(result)); +} + +bool matchesAny(ArrayPtr cidrs, const struct sockaddr* addr) { + for (auto& cidr: cidrs) { + if (cidr.matches(addr)) return true; + } + return false; +} + +NetworkFilter::NetworkFilter() + : allowUnix(true), allowAbstractUnix(true) { + allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); + allowCidrs.add(CidrRange::inet6({}, {}, 0)); + denyCidrs.addAll(reservedCidrs()); +} + +NetworkFilter::NetworkFilter(ArrayPtr allow, ArrayPtr deny, + NetworkFilter& next) + : allowUnix(false), allowAbstractUnix(false), next(next) { + for (auto rule: allow) { + if (rule == "local") { + allowCidrs.addAll(localCidrs()); + } else if (rule == "network") { + // Can't be represented as a simple union of CIDRs, so we handle in shouldAllow(). + allowNetwork = true; + } else if (rule == "private") { + allowCidrs.addAll(privateCidrs()); + allowCidrs.addAll(localCidrs()); + } else if (rule == "public") { + // Can't be represented as a simple union of CIDRs, so we handle in shouldAllow(). + allowPublic = true; + } else if (rule == "unix") { + allowUnix = true; + } else if (rule == "unix-abstract") { + allowAbstractUnix = true; + } else { + allowCidrs.add(CidrRange(rule)); + } + } + + for (auto rule: deny) { + if (rule == "local") { + denyCidrs.addAll(localCidrs()); + } else if (rule == "network") { + KJ_FAIL_REQUIRE("don't deny 'network', allow 'local' instead"); + } else if (rule == "private") { + denyCidrs.addAll(privateCidrs()); + } else if (rule == "public") { + // Tricky: What if we allow 'network' and deny 'public'? + KJ_FAIL_REQUIRE("don't deny 'public', allow 'private' instead"); + } else if (rule == "unix") { + allowUnix = false; + } else if (rule == "unix-abstract") { + allowAbstractUnix = false; + } else { + denyCidrs.add(CidrRange(rule)); + } + } +} + +bool NetworkFilter::shouldAllow(const struct sockaddr* addr, uint addrlen) { + KJ_REQUIRE(addrlen >= sizeof(addr->sa_family)); + +#if !_WIN32 + if (addr->sa_family == AF_UNIX) { + auto path = safeUnixPath(reinterpret_cast(addr), addrlen); + if (path.size() > 0 && path[0] == '\0') { + return allowAbstractUnix; + } else { + return allowUnix; + } + } +#endif + + bool allowed = false; + uint allowSpecificity = 0; + + if (allowPublic) { + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + !matchesAny(privateCidrs(), addr) && !matchesAny(localCidrs(), addr)) { + allowed = true; + // Don't adjust allowSpecificity as this match has an effective specificity of zero. + } + } + + if (allowNetwork) { + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + !matchesAny(localCidrs(), addr)) { + allowed = true; + // Don't adjust allowSpecificity as this match has an effective specificity of zero. + } + } + + for (auto& cidr: allowCidrs) { + if (cidr.matches(addr)) { + allowSpecificity = kj::max(allowSpecificity, cidr.getSpecificity()); + allowed = true; + } + } + if (!allowed) return false; + for (auto& cidr: denyCidrs) { + if (cidr.matches(addr)) { + if (cidr.getSpecificity() >= allowSpecificity) return false; + } + } + + KJ_IF_MAYBE(n, next) { + return n->shouldAllow(addr, addrlen); + } else { + return true; + } +} + +bool NetworkFilter::shouldAllowParse(const struct sockaddr* addr, uint addrlen) { + bool matched = false; +#if !_WIN32 + if (addr->sa_family == AF_UNIX) { + auto path = safeUnixPath(reinterpret_cast(addr), addrlen); + if (path.size() > 0 && path[0] == '\0') { + if (allowAbstractUnix) matched = true; + } else { + if (allowUnix) matched = true; + } + } else { +#endif + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + (allowPublic || allowNetwork)) { + matched = true; + } + for (auto& cidr: allowCidrs) { + if (cidr.matchesFamily(addr->sa_family)) { + matched = true; + } + } +#if !_WIN32 + } +#endif + + if (matched) { + KJ_IF_MAYBE(n, next) { + return n->shouldAllowParse(addr, addrlen); + } else { + return true; + } + } else { + // No allow rule matches this address family, so don't even allow parsing it. + return false; + } +} + +} // namespace _ (private) + +// ======================================================================================= +// PeerIdentity implementations + +namespace { + +class NetworkPeerIdentityImpl final: public NetworkPeerIdentity { +public: + NetworkPeerIdentityImpl(kj::Own addr): addr(kj::mv(addr)) {} + + kj::String toString() override { return addr->toString(); } + NetworkAddress& getAddress() override { return *addr; } + +private: + kj::Own addr; +}; + +class LocalPeerIdentityImpl final: public LocalPeerIdentity { +public: + LocalPeerIdentityImpl(Credentials creds): creds(creds) {} + + kj::String toString() override { + char pidBuffer[16]; + kj::StringPtr pidStr = nullptr; + KJ_IF_MAYBE(p, creds.pid) { + pidStr = strPreallocated(pidBuffer, " pid:", *p); + } + + char uidBuffer[16]; + kj::StringPtr uidStr = nullptr; + KJ_IF_MAYBE(u, creds.uid) { + uidStr = strPreallocated(uidBuffer, " uid:", *u); + } + + return kj::str("(local peer", pidStr, uidStr, ")"); + } + + Credentials getCredentials() override { return creds; } + +private: + Credentials creds; +}; + +class UnknownPeerIdentityImpl final: public UnknownPeerIdentity { +public: + kj::String toString() override { + return kj::str("(unknown peer)"); + } +}; + +} // namespace + +kj::Own NetworkPeerIdentity::newInstance(kj::Own addr) { + return kj::heap(kj::mv(addr)); +} + +kj::Own LocalPeerIdentity::newInstance(LocalPeerIdentity::Credentials creds) { + return kj::heap(creds); +} + +kj::Own UnknownPeerIdentity::newInstance() { + static UnknownPeerIdentityImpl instance; + return { &instance, NullDisposer::instance }; +} + +Promise ConnectionReceiver::acceptAuthenticated() { + return accept().then([](Own stream) { + return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; + }); +} + +Promise NetworkAddress::connectAuthenticated() { + return connect().then([](Own stream) { + return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; + }); } } // namespace kj diff --git a/c++/src/kj/async-io.h b/c++/src/kj/async-io.h index 2804ed7289..de3e808247 100644 --- a/c++/src/kj/async-io.h +++ b/c++/src/kj/async-io.h @@ -19,17 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ASYNC_IO_H_ -#define KJ_ASYNC_IO_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "async.h" -#include "function.h" -#include "thread.h" -#include "time.h" +#include +#include +#include + +KJ_BEGIN_HEADER struct sockaddr; @@ -37,17 +34,24 @@ namespace kj { #if _WIN32 class Win32EventPort; +class AutoCloseHandle; #else class UnixEventPort; #endif +class AutoCloseFd; class NetworkAddress; class AsyncOutputStream; +class AsyncIoStream; +class AncillaryMessage; + +class ReadableFile; +class File; // ======================================================================================= // Streaming I/O -class AsyncInputStream { +class AsyncInputStream: private AsyncObject { // Asynchronous equivalent of InputStream (from io.h). public: @@ -77,12 +81,31 @@ class AsyncInputStream { // The default implementation first tries calling output.tryPumpFrom(), but if that fails, it // performs a naive pump by allocating a buffer and reading to it / writing from it in a loop. - Promise> readAllBytes(); - Promise readAllText(); - // Read until EOF and return as one big byte array or string. + Promise> readAllBytes(uint64_t limit = kj::maxValue); + Promise readAllText(uint64_t limit = kj::maxValue); + // Read until EOF and return as one big byte array or string. Throw an exception if EOF is not + // seen before reading `limit` bytes. + // + // To prevent runaway memory allocation, consider using a more conservative value for `limit` than + // the default, particularly on untrusted data streams which may never see EOF. + + virtual void registerAncillaryMessageHandler(Function)> fn); + // Register interest in checking for ancillary messages (aka control messages) when reading. + // The provided callback will be called whenever any are encountered. The messages passed to + // the function do not live beyond when function returns. + // Only supported on Unix (the default impl throws UNIMPLEMENTED). Most apps will not use this. + + virtual Maybe> tryTee(uint64_t limit = kj::maxValue); + // Primarily intended as an optimization for the `tee` call. Returns an input stream whose state + // is independent from this one but which will return the exact same set of bytes read going + // forward. limit is a total limit on the amount of memory, in bytes, which a tee implementation + // may use to buffer stream data. An implementation must throw an exception if a read operation + // would cause the limit to be exceeded. If tryTee() can see that the new limit is impossible to + // satisfy, it should return nullptr so that the pessimized path is taken in newTee. This is + // likely to arise if tryTee() is called twice with different limits on the same stream. }; -class AsyncOutputStream { +class AsyncOutputStream: private AsyncObject { // Asynchronous equivalent of OutputStream (from io.h). public: @@ -100,6 +123,20 @@ class AsyncOutputStream { // output stream. If it finds one, it performs the pump. Otherwise, it returns null. // // The default implementation always returns null. + + virtual Promise whenWriteDisconnected() = 0; + // Returns a promise that resolves when the stream has become disconnected such that new write()s + // will fail with a DISCONNECTED exception. This is particularly useful, for example, to cancel + // work early when it is detected that no one will receive the result. + // + // Note that not all streams are able to detect this condition without actually performing a + // write(); such stream implementations may return a promise that never resolves. (In particular, + // as of this writing, whenWriteDisconnected() is not implemented on Windows. Also, for TCP + // streams, not all disconnects are detectable -- a power or network failure may lead the + // connection to hang forever, or until configured socket options lead to a timeout.) + // + // Unlike most other asynchronous stream methods, it is safe to call whenWriteDisconnected() + // multiple times without canceling the previous promises. }; class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { @@ -128,6 +165,109 @@ class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { // Note that we don't provide methods that return NetworkAddress because it usually wouldn't // be useful. You can't connect() to or listen() on these addresses, obviously, because they are // ephemeral addresses for a single connection. + + virtual kj::Maybe getFd() const { return nullptr; } + // Get the underlying Unix file descriptor, if any. Returns nullptr if this object actually + // isn't wrapping a file descriptor. + + virtual Maybe getWin32Handle() const { return nullptr; } + // Get the underlying Win32 HANDLE, if any. Returns nullptr if this object actually isn't + // wrapping a handle. +}; + +Promise unoptimizedPumpTo( + AsyncInputStream& input, AsyncOutputStream& output, uint64_t amount, + uint64_t completedSoFar = 0); +// Performs a pump using read() and write(), without calling the stream's pumpTo() nor +// tryPumpFrom() methods. This is intended to be used as a fallback by implementations of pumpTo() +// and tryPumpFrom() when they want to give up on optimization, but can't just call pumpTo() again +// because this would recursively retry the optimization. unoptimizedPumpTo() should only be called +// inside implementations of streams, never by the caller of a stream -- use the pumpTo() method +// instead. +// +// `completedSoFar` is the number of bytes out of `amount` that have already been pumped. This is +// provided for convenience for cases where the caller has already done some pumping before they +// give up. Otherwise, a `.then()` would need to be used to add the bytes to the final result. + +class AsyncCapabilityStream: public AsyncIoStream { + // An AsyncIoStream that also allows transmitting new stream objects and file descriptors + // (capabilities, in the object-capability model sense), in addition to bytes. + // + // Capabilities can be attached to bytes when they are written. On the receiving end, the read() + // that receives the first byte of such a message will also receive the capabilities. + // + // Note that AsyncIoStream's regular byte-oriented methods can be used on AsyncCapabilityStream, + // with the effect of silently dropping any capabilities attached to the respective bytes. E.g. + // using `AsyncIoStream::tryRead()` to read bytes that had been sent with `writeWithFds()` will + // silently drop the FDs (closing them if appropriate). Also note that pumping a stream with + // `pumpTo()` always drops all capabilities attached to the pumped data. (TODO(someday): Do we + // want a version of pumpTo() that preserves capabilities?) + // + // On Unix, KJ provides an implementation based on Unix domain sockets and file descriptor + // passing via SCM_RIGHTS. Due to the nature of SCM_RIGHTS, if the application accidentally + // read()s when it should have called receiveStream(), it will observe a NUL byte in the data + // and the capability will be discarded. Of course, an application should not depend on this + // behavior; it should avoid read()ing through a capability. + // + // KJ does not provide any inter-process implementation of this type on Windows, as there's no + // obvious implementation there. Handle passing on Windows requires at least one of the processes + // involved to have permission to modify the other's handle table, which is effectively full + // control. Handle passing between mutually non-trusting processes would require a trusted + // broker process to facilitate. One could possibly implement this type in terms of such a + // broker, or in terms of direct handle passing if at least one process trusts the other. + +public: + virtual Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds) = 0; + Promise writeWithFds(ArrayPtr data, + ArrayPtr> moreData, + ArrayPtr fds); + // Write some data to the stream with some file descriptors attached to it. + // + // The maximum number of FDs that can be sent at a time is usually subject to an OS-imposed + // limit. On Linux, this is 253. In practice, sending more than a handful of FDs at once is + // probably a bad idea. + + struct ReadResult { + size_t byteCount; + size_t capCount; + }; + + virtual Promise tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, + AutoCloseFd* fdBuffer, size_t maxFds) = 0; + // Read data from the stream that may have file descriptors attached. Any attached descriptors + // will be placed in `fdBuffer`. If multiple bundles of FDs are encountered in the course of + // reading the amount of data requested by minBytes/maxBytes, then they will be concatenated. If + // more FDs are received than fit in the buffer, then the excess will be discarded and closed -- + // this behavior, while ugly, is important to defend against denial-of-service attacks that may + // fill up the FD table with garbage. Applications must think carefully about how many FDs they + // really need to receive at once and set a well-defined limit. + + virtual Promise writeWithStreams(ArrayPtr data, + ArrayPtr> moreData, + Array> streams) = 0; + virtual Promise tryReadWithStreams( + void* buffer, size_t minBytes, size_t maxBytes, + Own* streamBuffer, size_t maxStreams) = 0; + // Like above, but passes AsyncCapabilityStream objects. The stream implementations must be from + // the same AsyncIoProvider. + + // --------------------------------------------------------------------------- + // Helpers for sending individual capabilities. + // + // These are equivalent to the above methods with the constraint that only one FD is + // sent/received at a time and the corresponding data is a single zero-valued byte. + + Promise> receiveStream(); + Promise>> tryReceiveStream(); + Promise sendStream(Own stream); + // Transfer a single stream. + + Promise receiveFd(); + Promise> tryReceiveFd(); + Promise sendFd(int fd); + // Transfer a single raw file descriptor. }; struct OneWayPipe { @@ -137,6 +277,13 @@ struct OneWayPipe { Own out; }; +OneWayPipe newOneWayPipe(kj::Maybe expectedLength = nullptr); +// Constructs a OneWayPipe that operates in-process. The pipe does not do any buffering -- it waits +// until both a read() and a write() call are pending, then resolves both. +// +// If `expectedLength` is non-null, then the pipe will be expected to transmit exactly that many +// bytes. The input end's `tryGetLength()` will return the number of bytes left. + struct TwoWayPipe { // A data pipe that supports sending in both directions. Each end's output sends data to the // other end's input. (Typically backed by socketpair() system call.) @@ -144,13 +291,164 @@ struct TwoWayPipe { Own ends[2]; }; -class ConnectionReceiver { +TwoWayPipe newTwoWayPipe(); +// Constructs a TwoWayPipe that operates in-process. The pipe does not do any buffering -- it waits +// until both a read() and a write() call are pending, then resolves both. + +struct CapabilityPipe { + // Like TwoWayPipe but allowing capability-passing. + + Own ends[2]; +}; + +CapabilityPipe newCapabilityPipe(); +// Like newTwoWayPipe() but creates a capability pipe. +// +// The requirement of `writeWithStreams()` that "The stream implementations must be from the same +// AsyncIoProvider." does not apply to this pipe; any kind of AsyncCapabilityStream implementation +// is supported. +// +// This implementation does not know how to convert streams to FDs or vice versa; if you write FDs +// you must read FDs, and if you write streams you must read streams. + +struct Tee { + // Two AsyncInputStreams which each read the same data from some wrapped inner AsyncInputStream. + + Own branches[2]; +}; + +Tee newTee(Own input, uint64_t limit = kj::maxValue); +// Constructs a Tee that operates in-process. The tee buffers data if any read or pump operations is +// called on one of the two input ends. If a read or pump operation is subsequently called on the +// other input end, the buffered data is consumed. +// +// `pumpTo()` operations on the input ends will proactively read from the inner stream and block +// while writing to the output stream. While one branch has an active `pumpTo()` operation, any +// `tryRead()` operation on the other branch will not be allowed to read faster than allowed by the +// pump's backpressure. (In other words, it will never cause buffering on the pump.) Similarly, if +// there are `pumpTo()` operations active on both branches, the greater of the two backpressures is +// respected -- the two pumps progress in lockstep, and there is no buffering. +// +// At no point will a branch's buffer be allowed to grow beyond `limit` bytes. If the buffer would +// grow beyond the limit, an exception is generated, which both branches see once they have +// exhausted their buffers. +// +// It is recommended that you use a more conservative value for `limit` than the default. + +Own newPromisedStream(Promise> promise); +Own newPromisedStream(Promise> promise); +// Constructs an Async*Stream which waits for a promise to resolve, then forwards all calls to the +// promised stream. + +// ======================================================================================= +// Authenticated streams + +class PeerIdentity { + // PeerIdentity provides information about a connecting client. Various subclasses exist to + // address different network types. +public: + virtual kj::String toString() = 0; + // Returns a human-readable string identifying the peer. Where possible, this string will be + // in the same format as the addresses you could pass to `kj::Network::parseAddress()`. However, + // only certain subclasses of `PeerIdentity` guarantee this property. +}; + +struct AuthenticatedStream { + // A pair of an `AsyncIoStream` and a `PeerIdentity`. This is used as the return type of + // `NetworkAddress::connectAuthenticated()` and `ConnectionReceiver::acceptAuthenticated()`. + + Own stream; + // The byte stream. + + Own peerIdentity; + // An object indicating who is at the other end of the stream. + // + // Different subclasses of `PeerIdentity` are used in different situations: + // - TCP connections will use NetworkPeerIdentity, which gives the network address of the client. + // - Local (unix) socket connections will use LocalPeerIdentity, which identifies the UID + // and PID of the process that initiated the connection. + // - TLS connections will use TlsPeerIdentity which provides details of the client certificate, + // if any was provided. + // - When no meaningful peer identity can be provided, `UnknownPeerIdentity` is returned. + // + // Implementations of `Network`, `ConnectionReceiver`, `NetworkAddress`, etc. should document the + // specific assumptions the caller can make about the type of `PeerIdentity`s used, allowing for + // identities to be statically downcast if the right conditions are met. In the absence of + // documented promises, RTTI may be needed to query the type. +}; + +class NetworkPeerIdentity: public PeerIdentity { + // PeerIdentity used for network protocols like TCP/IP. This identifies the remote peer. + // + // This is only "authenticated" to the extent that we know data written to the stream will be + // routed to the given address. This does not preclude the possibility of man-in-the-middle + // attacks by attackers who are able to manipulate traffic along the route. +public: + virtual NetworkAddress& getAddress() = 0; + // Obtain the peer's address as a NetworkAddress object. The returned reference's lifetime is the + // same as the `NetworkPeerIdentity`, but you can always call `clone()` on it to get a copy that + // lives longer. + + static kj::Own newInstance(kj::Own addr); + // Construct an instance of this interface wrapping the given address. +}; + +class LocalPeerIdentity: public PeerIdentity { + // PeerIdentity used for connections between processes on the local machine -- in particular, + // Unix sockets. + // + // (This interface probably isn't useful on Windows.) +public: + struct Credentials { + kj::Maybe pid; + kj::Maybe uid; + + // We don't cover groups at present because some systems produce a list of groups while others + // only provide the peer's main group, the latter being pretty useless. + }; + + virtual Credentials getCredentials() = 0; + // Get the PID and UID of the peer process, if possible. + // + // Either ID may be null if the peer could not be identified. Some operating systems do not + // support retrieving these credentials, or can only provide one or the other. Some situations + // (like user and PID namespaces on Linux) may also make it impossible to represent the peer's + // credentials accurately. + // + // Note the meaning here can be subtle. Multiple processes can potentially have the socket in + // their file descriptor tables. The identified process is the one who called `connect()` or + // `listen()`. + // + // On Linux this is implemented with SO_PEERCRED. + + static kj::Own newInstance(Credentials creds); + // Construct an instance of this interface wrapping the given credentials. +}; + +class UnknownPeerIdentity: public PeerIdentity { +public: + static kj::Own newInstance(); + // Get an instance of this interface. This actually always returns the same instance with no + // memory allocation. +}; + +// ======================================================================================= +// Accepting connections + +class ConnectionReceiver: private AsyncObject { // Represents a server socket listening on a port. public: virtual Promise> accept() = 0; // Accept the next incoming connection. + virtual Promise acceptAuthenticated(); + // Accept the next incoming connection, and also provide a PeerIdentity with any information + // about the client. + // + // For backwards-compatibility, the default implementation of this method calls `accept()` and + // then adds `UnknownPeerIdentity`. + virtual uint getPort() = 0; // Gets the port number, if applicable (i.e. if listening on IP). This is useful if you didn't // specify a port when constructing the NetworkAddress -- one will have been assigned @@ -158,9 +456,14 @@ class ConnectionReceiver { virtual void getsockopt(int level, int option, void* value, uint* length); virtual void setsockopt(int level, int option, const void* value, uint length); + virtual void getsockname(struct sockaddr* addr, uint* length); // Same as the methods of AsyncIoStream. }; +Own newAggregateConnectionReceiver(Array> receivers); +// Create a ConnectionReceiver that listens on several other ConnectionReceivers and returns +// sockets from any of them. + // ======================================================================================= // Datagram I/O @@ -179,14 +482,14 @@ class AncillaryMessage { // Protocol-specific message type. template - inline Maybe as(); + inline Maybe as() const; // Interpret the ancillary message as the given struct type. Most ancillary messages are some // sort of struct, so this is a convenient way to access it. Returns nullptr if the message // is smaller than the struct -- this can happen if the message was truncated due to // insufficient ancillary buffer space. template - inline ArrayPtr asArray(); + inline ArrayPtr asArray() const; // Interpret the ancillary message as an array of items. If the message size does not evenly // divide into elements of type T, the remainder is discarded -- this can happen if the message // was truncated due to insufficient ancillary buffer space. @@ -221,7 +524,7 @@ class DatagramReceiver { // Get the content of the datagram. virtual MaybeTruncated> getAncillary() = 0; - // Ancilarry messages received with the datagram. See the recvmsg() system call and the cmsghdr + // Ancillary messages received with the datagram. See the recvmsg() system call and the cmsghdr // struct. Most apps don't need this. // // If the returned value is truncated, then the last message in the array may itself be @@ -267,7 +570,7 @@ class DatagramPort { // ======================================================================================= // Networks -class NetworkAddress { +class NetworkAddress: private AsyncObject { // Represents a remote address to which the application can connect. public: @@ -276,6 +579,14 @@ class NetworkAddress { // // The address must not be a wildcard ("*"). If it is an IP address, it must have a port number. + virtual Promise connectAuthenticated(); + // Connect to the address and return both the connection and information about the peer identity. + // This is especially useful when using TLS, to get certificate details. + // + // For backwards-compatibility, the default implementation of this method calls `connect()` and + // then uses a `NetworkPeerIdentity` wrapping a clone of this `NetworkAddress` -- which is not + // particularly useful. + virtual Own listen() = 0; // Listen for incoming connections on this address. // @@ -319,6 +630,67 @@ class Network { virtual Own getSockaddr(const void* sockaddr, uint len) = 0; // Construct a network address from a legacy struct sockaddr. + + virtual Own restrictPeers( + kj::ArrayPtr allow, + kj::ArrayPtr deny = nullptr) KJ_WARN_UNUSED_RESULT = 0; + // Constructs a new Network instance wrapping this one which restricts which peer addresses are + // permitted (both for outgoing and incoming connections). + // + // Communication will be allowed only with peers whose addresses match one of the patterns + // specified in the `allow` array. If a `deny` array is specified, then any address which matches + // a pattern in `deny` and *does not* match any more-specific pattern in `allow` will also be + // denied. + // + // The syntax of address patterns depends on the network, except that three special patterns are + // defined for all networks: + // - "private": Matches network addresses that are reserved by standards for private networks, + // such as "10.0.0.0/8" or "192.168.0.0/16". This is a superset of "local". + // - "public": Opposite of "private". + // - "local": Matches network addresses that are defined by standards to only be accessible from + // the local machine, such as "127.0.0.0/8" or Unix domain addresses. + // - "network": Opposite of "local". + // + // For the standard KJ network implementation, the following patterns are also recognized: + // - Network blocks specified in CIDR notation (ipv4 and ipv6), such as "192.0.2.0/24" or + // "2001:db8::/32". + // - "unix" to match all Unix domain addresses. (In the future, we may support specifying a + // glob.) + // - "unix-abstract" to match Linux's "abstract unix domain" addresses. (In the future, we may + // support specifying a glob.) + // + // Network restrictions apply *after* DNS resolution (otherwise they'd be useless). + // + // It is legal to parseAddress() a restricted address. An exception won't be thrown until + // connect() is called. + // + // It's possible to listen() on a restricted address. However, connections will only be accepted + // from non-restricted addresses; others will be dropped. If a particular listen address has no + // valid peers (e.g. because it's a unix socket address and unix sockets are not allowed) then + // listen() may throw (or may simply never receive any connections). + // + // Examples: + // + // auto restricted = network->restrictPeers({"public"}); + // + // Allows connections only to/from public internet addresses. Use this when connecting to an + // address specified by a third party that is not trusted and is not themselves already on your + // private network. + // + // auto restricted = network->restrictPeers({"private"}); + // + // Allows connections only to/from the private network. Use this on the server side to reject + // connections from the public internet. + // + // auto restricted = network->restrictPeers({"192.0.2.0/24"}, {"192.0.2.3/32"}); + // + // Allows connections only to/from 192.0.2.*, except 192.0.2.3 which is blocked. + // + // auto restricted = network->restrictPeers({"10.0.0.0/8", "10.1.2.3/32"}, {"10.1.2.0/24"}); + // + // Allows connections to/from 10.*.*.*, with the exception of 10.1.2.* (which is denied), with an + // exception to the exception of 10.1.2.3 (which is allowed, because it is matched by an allow + // rule that is more specific than the deny rule). }; // ======================================================================================= @@ -340,6 +712,13 @@ class AsyncIoProvider { // Creates two AsyncIoStreams representing the two ends of a two-way pipe (e.g. created with // socketpair(2) system call). Data written to one end can be read from the other. + virtual CapabilityPipe newCapabilityPipe(); + // Creates two AsyncCapabilityStreams representing the two ends of a two-way capability pipe. + // + // The default implementation throws an unimplemented exception. In particular this is not + // implemented by the default AsyncIoProvider on Windows, since Windows lacks any sane way to + // pass handles over a stream. + virtual Network& getNetwork() = 0; // Creates a new `Network` instance representing the networks exposed by the operating system. // @@ -400,16 +779,11 @@ class LowLevelAsyncIoProvider { // Different implementations of this interface might work on top of different event handling // primitives, such as poll vs. epoll vs. kqueue vs. some higher-level event library. // - // On Windows, this interface can be used to import native HANDLEs into the async framework. + // On Windows, this interface can be used to import native SOCKETs into the async framework. // Different implementations of this interface might work on top of different event handling // primitives, such as I/O completion ports vs. completion routines. - // - // TODO(port): Actually implement Windows support. public: - // --------------------------------------------------------------------------- - // Unix-specific stuff - enum Flags { // Flags controlling how to wrap a file descriptor. @@ -440,11 +814,13 @@ class LowLevelAsyncIoProvider { #if _WIN32 typedef uintptr_t Fd; + typedef AutoCloseHandle OwnFd; // On Windows, the `fd` parameter to each of these methods must be a SOCKET, and must have the // flag WSA_FLAG_OVERLAPPED (which socket() uses by default, but WSASocket() wants you to specify // explicitly). #else typedef int Fd; + typedef AutoCloseFd OwnFd; // On Unix, any arbitrary file descriptor is supported. #endif @@ -463,6 +839,15 @@ class LowLevelAsyncIoProvider { // // `flags` is a bitwise-OR of the values of the `Flags` enum. +#if !_WIN32 + virtual Own wrapUnixSocketFd(Fd fd, uint flags = 0); + // Like wrapSocketFd() but also support capability passing via SCM_RIGHTS. The socket must be + // a Unix domain socket. + // + // The default implementation throws UNIMPLEMENTED, for backwards-compatibility with + // LowLevelAsyncIoProvider implementations written before this method was added. +#endif + virtual Promise> wrapConnectingSocketFd( Fd fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) = 0; // Create an AsyncIoStream wrapping a socket and initiate a connection to the given address. @@ -470,13 +855,29 @@ class LowLevelAsyncIoProvider { // // `flags` is a bitwise-OR of the values of the `Flags` enum. - virtual Own wrapListenSocketFd(Fd fd, uint flags = 0) = 0; + class NetworkFilter { + public: + virtual bool shouldAllow(const struct sockaddr* addr, uint addrlen) = 0; + // Returns true if incoming connections or datagrams from the given peer should be accepted. + // If false, they will be dropped. This is used to implement kj::Network::restrictPeers(). + + static NetworkFilter& getAllAllowed(); + }; + + virtual Own wrapListenSocketFd( + Fd fd, NetworkFilter& filter, uint flags = 0) = 0; + inline Own wrapListenSocketFd(Fd fd, uint flags = 0) { + return wrapListenSocketFd(fd, NetworkFilter::getAllAllowed(), flags); + } // Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already // have had `bind()` and `listen()` called on it, so it's ready for `accept()`. // // `flags` is a bitwise-OR of the values of the `Flags` enum. - virtual Own wrapDatagramSocketFd(Fd fd, uint flags = 0); + virtual Own wrapDatagramSocketFd(Fd fd, NetworkFilter& filter, uint flags = 0); + inline Own wrapDatagramSocketFd(Fd fd, uint flags = 0) { + return wrapDatagramSocketFd(fd, NetworkFilter::getAllAllowed(), flags); + } virtual Timer& getTimer() = 0; // Returns a `Timer` based on real time. Time does not pass while event handlers are running -- @@ -485,6 +886,22 @@ class LowLevelAsyncIoProvider { // // This timer is not affected by changes to the system date. It is unspecified whether the timer // continues to count while the system is suspended. + + Own wrapInputFd(OwnFd&& fd, uint flags = 0); + Own wrapOutputFd(OwnFd&& fd, uint flags = 0); + Own wrapSocketFd(OwnFd&& fd, uint flags = 0); +#if !_WIN32 + Own wrapUnixSocketFd(OwnFd&& fd, uint flags = 0); +#endif + Promise> wrapConnectingSocketFd( + OwnFd&& fd, const struct sockaddr* addr, uint addrlen, uint flags = 0); + Own wrapListenSocketFd( + OwnFd&& fd, NetworkFilter& filter, uint flags = 0); + Own wrapListenSocketFd(OwnFd&& fd, uint flags = 0); + Own wrapDatagramSocketFd(OwnFd&& fd, NetworkFilter& filter, uint flags = 0); + Own wrapDatagramSocketFd(OwnFd&& fd, uint flags = 0); + // Convenience wrappers which transfer ownership via AutoCloseFd (Unix) or AutoCloseHandle + // (Windows). TAKE_OWNERSHIP will be implicitly added to `flags`. }; Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel); @@ -532,6 +949,125 @@ AsyncIoContext setupAsyncIo(); // note that this means that server processes which daemonize themselves at startup must wait // until after daemonization to create an AsyncIoContext. +// ======================================================================================= +// Convenience adapters. + +class CapabilityStreamConnectionReceiver final: public ConnectionReceiver { + // Trivial wrapper which allows an AsyncCapabilityStream to act as a ConnectionReceiver. accept() + // calls receiveStream(). + +public: + CapabilityStreamConnectionReceiver(AsyncCapabilityStream& inner) + : inner(inner) {} + + Promise> accept() override; + uint getPort() override; + + Promise acceptAuthenticated() override; + // Always produces UnknownIdentity. Capability-based security patterns should not rely on + // authenticating peers; the other end of the capability stream should only be given to + // authorized parties in the first place. + +private: + AsyncCapabilityStream& inner; +}; + +class CapabilityStreamNetworkAddress final: public NetworkAddress { + // Trivial wrapper which allows an AsyncCapabilityStream to act as a NetworkAddress. + // + // connect() is implemented by calling provider.newCapabilityPipe(), sending one end over the + // original capability stream, and returning the other end. If `provider` is null, then the + // global kj::newCapabilityPipe() will be used, but this ONLY works if `inner` itself is agnostic + // to the type of streams it receives, e.g. because it was also created using + // kj::NewCapabilityPipe(). + // + // listen().accept() is implemented by receiving new streams over the original stream. + // + // Note that clone() doesn't work (due to ownership issues) and toString() returns a static + // string. + +public: + CapabilityStreamNetworkAddress(kj::Maybe provider, AsyncCapabilityStream& inner) + : provider(provider), inner(inner) {} + + Promise> connect() override; + Own listen() override; + + Own clone() override; + String toString() override; + + Promise connectAuthenticated() override; + // Always produces UnknownIdentity. Capability-based security patterns should not rely on + // authenticating peers; the other end of the capability stream should only be given to + // authorized parties in the first place. + +private: + kj::Maybe provider; + AsyncCapabilityStream& inner; +}; + +class FileInputStream: public AsyncInputStream { + // InputStream that reads from a disk file -- and enables sendfile() optimization. + // + // Reads are performed synchronously -- no actual attempt is made to use asynchronous file I/O. + // True asynchronous file I/O is complicated and is mostly unnecessary in the presence of + // caching. Only certain niche programs can expect to benefit from it. For the rest, it's better + // to use regular syrchronous disk I/O, so that's what this class does. + // + // The real purpose of this class, aside from general convenience, is to enable sendfile() + // optimization. When you use this class's pumpTo() method, and the destination is a socket, + // the system will detect this and optimize to sendfile(), so that the file data never needs to + // be read into userspace. + // + // NOTE: As of this writing, sendfile() optimization is only implemented on Linux. + +public: + FileInputStream(const ReadableFile& file, uint64_t offset = 0) + : file(file), offset(offset) {} + + const ReadableFile& getUnderlyingFile() { return file; } + uint64_t getOffset() { return offset; } + void seek(uint64_t newOffset) { offset = newOffset; } + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes); + Maybe tryGetLength(); + + // (pumpTo() is not actually overridden here, but AsyncStreamFd's tryPumpFrom() will detect when + // the source is a file.) + +private: + const ReadableFile& file; + uint64_t offset; +}; + +class FileOutputStream: public AsyncOutputStream { + // OutputStream that writes to a disk file. + // + // As with FileInputStream, calls are not actually async. Async would be even less useful here + // because writes should usually land in cache anyway. + // + // sendfile() optimization does not apply when writing to a file, but on Linux, splice() can + // be used to achieve a similar effect. + // + // NOTE: As of this writing, splice() optimization is not implemented. + +public: + FileOutputStream(const File& file, uint64_t offset = 0) + : file(file), offset(offset) {} + + const File& getUnderlyingFile() { return file; } + uint64_t getOffset() { return offset; } + void seek(uint64_t newOffset) { offset = newOffset; } + + Promise write(const void* buffer, size_t size); + Promise write(ArrayPtr> pieces); + Promise whenWriteDisconnected(); + +private: + const File& file; + uint64_t offset; +}; + // ======================================================================================= // inline implementation details @@ -543,7 +1079,7 @@ inline int AncillaryMessage::getLevel() const { return level; } inline int AncillaryMessage::getType() const { return type; } template -inline Maybe AncillaryMessage::as() { +inline Maybe AncillaryMessage::as() const { if (data.size() >= sizeof(T)) { return *reinterpret_cast(data.begin()); } else { @@ -552,10 +1088,61 @@ inline Maybe AncillaryMessage::as() { } template -inline ArrayPtr AncillaryMessage::asArray() { +inline ArrayPtr AncillaryMessage::asArray() const { return arrayPtr(reinterpret_cast(data.begin()), data.size() / sizeof(T)); } +class SecureNetworkWrapper { + // Abstract interface for a class which implements a "secure" network as a wrapper around an + // insecure one. "secure" means: + // * Connections to a server will only succeed if it can be verified that the requested hostname + // actually belongs to the responding server. + // * No man-in-the-middle attacker can potentially see the bytes sent and received. + // + // The typical implementation uses TLS. The object in this case could be configured to use cerain + // keys, certificates, etc. See kj/compat/tls.h for such an implementation. + // + // However, an implementation could use some other form of encryption, or might not need to use + // encryption at all. For example, imagine a kj::Network that exists only on a single machine, + // providing communications between various processes using unix sockets. Perhaps the "hostnames" + // are actually PIDs in this case. An implementation of such a network could verify the other + // side's identity using an `SCM_CREDENTIALS` auxiliary message, which cannot be forged. Once + // verified, there is no need to encrypt since unix sockets cannot be intercepted. + +public: + virtual kj::Promise> wrapServer(kj::Own stream) = 0; + // Act as the server side of a connection. The given stream is already connected to a client, but + // no authentication has occurred. The returned stream represents the secure transport once + // established. + + virtual kj::Promise> wrapClient( + kj::Own stream, kj::StringPtr expectedServerHostname) = 0; + // Act as the client side of a connection. The given stream is already connecetd to a server, but + // no authentication has occurred. This method will verify that the server actually is the given + // hostname, then return the stream representing a secure transport to that server. + + virtual kj::Promise wrapServer(kj::AuthenticatedStream stream) = 0; + virtual kj::Promise wrapClient( + kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname) = 0; + // Same as above, but implementing kj::AuthenticatedStream, which provides PeerIdentity objects + // with more details about the peer. The SecureNetworkWrapper will provide its own implementation + // of PeerIdentity with the specific details it is able to authenticate. + + virtual kj::Own wrapPort(kj::Own port) = 0; + // Wrap a connection listener. This is equivalent to calling wrapServer() on every connection + // received. + + virtual kj::Own wrapAddress( + kj::Own address, kj::StringPtr expectedServerHostname) = 0; + // Wrap a NetworkAddress. This is equivalent to calling `wrapClient()` on every connection + // formed by calling `connect()` on the address. + + virtual kj::Own wrapNetwork(kj::Network& network) = 0; + // Wrap a whole `kj::Network`. This automatically wraps everything constructed using the network. + // The network will only accept address strings that can be authenticated, and will automatically + // authenticate servers against those addresses when connecting to them. +}; + } // namespace kj -#endif // KJ_ASYNC_IO_H_ +KJ_END_HEADER diff --git a/c++/src/kj/async-prelude.h b/c++/src/kj/async-prelude.h index 0a5843f88a..6289bf3fa0 100644 --- a/c++/src/kj/async-prelude.h +++ b/c++/src/kj/async-prelude.h @@ -22,15 +22,32 @@ // This file contains a bunch of internal declarations that must appear before async.h can start. // We don't define these directly in async.h because it makes the file hard to read. -#ifndef KJ_ASYNC_PRELUDE_H_ -#define KJ_ASYNC_PRELUDE_H_ +#pragma once -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header +#include +#include +#include + +// Detect whether or not we should enable kj::Promise coroutine integration. +// +// TODO(someday): Support coroutines with -fno-exceptions. +#if !KJ_NO_EXCEPTIONS +#ifdef __has_include +#if (__cpp_impl_coroutine >= 201902L) && __has_include() +// C++20 Coroutines detected. +#include +#define KJ_HAS_COROUTINE 1 +#define KJ_COROUTINE_STD_NAMESPACE std +#elif (__cpp_coroutines >= 201703L) && __has_include() +// Coroutines TS detected. +#include +#define KJ_HAS_COROUTINE 1 +#define KJ_COROUTINE_STD_NAMESPACE std::experimental +#endif +#endif #endif -#include "exception.h" -#include "tuple.h" +KJ_BEGIN_HEADER namespace kj { @@ -38,22 +55,42 @@ class EventLoop; template class Promise; class WaitScope; +class TaskSet; -template -Promise> joinPromises(Array>&& promises); -Promise joinPromises(Array>&& promises); +Promise joinPromises(Array>&& promises, SourceLocation location = {}); +Promise joinPromisesFailFast(Array>&& promises, SourceLocation location = {}); +// Out-of-line specialization of template function defined in async.h. namespace _ { // private -template struct JoinPromises_ { typedef T Type; }; -template struct JoinPromises_> { typedef T Type; }; +template +Promise chainPromiseType(T*); +template +Promise chainPromiseType(Promise*); template -using JoinPromises = typename JoinPromises_::Type; -// If T is Promise, resolves to U, otherwise resolves to T. -// -// TODO(cleanup): Rename to avoid confusion with joinPromises() call which is completely -// unrelated. +using ChainPromises = decltype(chainPromiseType((T*)nullptr)); +// Constructs a promise for T, reducing double-promises. That is, if T is Promise, resolves to +// Promise, otherwise resolves to Promise. + +template +Promise reducePromiseType(T*, ...); +template +Promise reducePromiseType(Promise*, ...); +template >()))> +Reduced reducePromiseType(T*, bool); + +template +using ReducePromises = decltype(reducePromiseType((T*)nullptr, false)); +// Like ChainPromises, but also takes into account whether T has a method `reducePromise` that +// reduces Promise to something else. In particular this allows Promise> +// to reduce to capnp::RemotePromise. + +template struct UnwrapPromise_; +template struct UnwrapPromise_> { typedef T Type; }; + +template +using UnwrapPromise = typename UnwrapPromise_::Type; class PropagateException { // A functor which accepts a kj::Exception as a parameter and returns a broken promise of @@ -90,7 +127,7 @@ using ReturnType = typename ReturnType_::Type; template struct SplitTuplePromise_ { typedef Promise Type; }; template struct SplitTuplePromise_> { - typedef kj::Tuple>...> Type; + typedef kj::Tuple...> Type; }; template @@ -171,10 +208,16 @@ class PromiseNode; class ChainPromiseNode; template class ForkHub; - -class TaskSetImpl; +class FiberStack; +class FiberBase; class Event; +class XThreadEvent; +class XThreadPaf; + +class PromiseDisposer; +using OwnPromiseNode = Own; +// PromiseNode uses a static disposer. class PromiseBase { public: @@ -182,37 +225,39 @@ class PromiseBase { // Dump debug info about this promise. private: - Own node; + OwnPromiseNode node; PromiseBase() = default; - PromiseBase(Own&& node): node(kj::mv(node)) {} + PromiseBase(OwnPromiseNode&& node): node(kj::mv(node)) {} - friend class kj::EventLoop; - friend class ChainPromiseNode; template friend class kj::Promise; - friend class TaskSetImpl; - template - friend Promise> kj::joinPromises(Array>&& promises); - friend Promise kj::joinPromises(Array>&& promises); + friend class PromiseNode; }; void detach(kj::Promise&& promise); -void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope); +void waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, WaitScope& waitScope, + SourceLocation location); +bool pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); Promise yield(); -Own neverDone(); +Promise yieldHarder(); +OwnPromiseNode readyNow(); +OwnPromiseNode neverDone(); + +class ReadyNow { +public: + operator Promise() const; +}; class NeverDone { public: template - operator Promise() const { - return Promise(false, neverDone()); - } + operator Promise() const; - KJ_NORETURN(void wait(WaitScope& waitScope) const); + KJ_NORETURN(void wait(WaitScope& waitScope, SourceLocation location = {}) const); }; } // namespace _ (private) } // namespace kj -#endif // KJ_ASYNC_PRELUDE_H_ +KJ_END_HEADER diff --git a/c++/src/kj/async-queue-test.c++ b/c++/src/kj/async-queue-test.c++ new file mode 100644 index 0000000000..3d8c8dd8fd --- /dev/null +++ b/c++/src/kj/async-queue-test.c++ @@ -0,0 +1,151 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "async-queue.h" + +#include +#include +#include + +namespace kj { +namespace { + +struct QueueTest { + kj::AsyncIoContext io = setupAsyncIo(); + ProducerConsumerQueue queue; + + QueueTest() = default; + QueueTest(QueueTest&&) = delete; + QueueTest(const QueueTest&) = delete; + QueueTest& operator=(QueueTest&&) = delete; + QueueTest& operator=(const QueueTest&) = delete; + + struct Producer { + QueueTest& test; + Promise promise = kj::READY_NOW; + + Producer(QueueTest& test): test(test) {} + + void push(size_t i) { + auto push = [&, i]() -> Promise { + test.queue.push(i); + return kj::READY_NOW; + }; + promise = promise.then(kj::mv(push)); + } + }; + + struct Consumer { + QueueTest& test; + Promise promise = kj::READY_NOW; + + Consumer(QueueTest& test): test(test) {} + + void pop(Vector& bits) { + auto pop = [&]() { + return test.queue.pop(); + }; + auto checkPop = [&](size_t j) -> Promise { + bits[j] = true; + return kj::READY_NOW; + }; + promise = promise.then(kj::mv(pop)).then(kj::mv(checkPop)); + } + }; +}; + +KJ_TEST("ProducerConsumerQueue with various amounts of producers and consumers") { + QueueTest test; + + size_t constexpr kItemCount = 1000; + for (auto producerCount: { 1, 5, 10 }) { + for (auto consumerCount: { 1, 5, 10 }) { + KJ_LOG(INFO, "Testing a new set of Producers and Consumers", // + producerCount, consumerCount, kItemCount); + // Make a vector to track our entries. + auto bits = Vector(kItemCount); + for (auto i KJ_UNUSED : kj::zeroTo(kItemCount)) { + bits.add(false); + } + + // Make enough producers. + auto producers = Vector(); + for (auto i KJ_UNUSED : kj::zeroTo(producerCount)) { + producers.add(test); + } + + // Make enough consumers. + auto consumers = Vector(); + for (auto i KJ_UNUSED : kj::zeroTo(consumerCount)) { + consumers.add(test); + } + + for (auto i : kj::zeroTo(kItemCount)) { + // Use a producer and a consumer for each entry. + + auto& producer = producers[i % producerCount]; + producer.push(i); + + auto& consumer = consumers[i % consumerCount]; + consumer.pop(bits); + } + + // Confirm that all entries are produced and consumed. + auto promises = Vector>(); + for (auto& producer: producers) { + promises.add(kj::mv(producer.promise)); + } + for (auto& consumer: consumers) { + promises.add(kj::mv(consumer.promise)); + } + joinPromises(promises.releaseAsArray()).wait(test.io.waitScope); + for (auto i : kj::zeroTo(kItemCount)) { + KJ_ASSERT(bits[i], i); + } + } + } +} + +KJ_TEST("ProducerConsumerQueue with rejectAll()") { + QueueTest test; + + for (auto consumerCount: { 1, 5, 10 }) { + KJ_LOG(INFO, "Testing a new set of consumers with rejection", consumerCount); + + // Make enough consumers. + auto promises = Vector>(); + for (auto i KJ_UNUSED : kj::zeroTo(consumerCount)) { + promises.add(test.queue.pop().ignoreResult()); + } + + for (auto& promise: promises) { + KJ_EXPECT(!promise.poll(test.io.waitScope), "All of our consumers should be waiting"); + } + test.queue.rejectAll(KJ_EXCEPTION(FAILED, "Total rejection")); + + // We should have finished and swallowed the errors. + auto promise = joinPromises(promises.releaseAsArray()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("Total rejection", promise.wait(test.io.waitScope)); + } +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/async-queue.h b/c++/src/kj/async-queue.h new file mode 100644 index 0000000000..7a815faa35 --- /dev/null +++ b/c++/src/kj/async-queue.h @@ -0,0 +1,156 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "async.h" +#include +#include +#include +#include + +#include + +KJ_BEGIN_HEADER + +namespace kj { + +template +class WaiterQueue { +public: + // A WaiterQueue creates Nodes that blend newAdaptedPromise and List. + + WaiterQueue() = default; + KJ_DISALLOW_COPY_AND_MOVE(WaiterQueue); + + Promise wait() { + return newAdaptedPromise(queue); + } + + void fulfill(T&& value) { + KJ_IREQUIRE(!empty()); + + auto& node = static_cast(queue.front()); + node.fulfiller.fulfill(kj::mv(value)); + node.remove(); + } + + void reject(Exception&& exception) { + KJ_IREQUIRE(!empty()); + + auto& node = static_cast(queue.front()); + node.fulfiller.reject(kj::mv(exception)); + node.remove(); + } + + bool empty() const { + return queue.empty(); + } + +private: + struct BaseNode { + // This is a separate structure because List requires a predefined memory layout but + // newAdaptedPromise() only provides access to the Adaptor type in the ctor. + + BaseNode(PromiseFulfiller& fulfiller): fulfiller(fulfiller) {} + + PromiseFulfiller& fulfiller; + ListLink link; + }; + + using Queue = List; + + struct Node: public BaseNode { + Node(PromiseFulfiller& fulfiller, Queue& queue): BaseNode(fulfiller), queue(queue) { + queue.add(*this); + } + + ~Node() noexcept(false) { + // When the associated Promise is destructed, so is the Node thus we should leave the queue. + remove(); + } + + void remove() { + if(BaseNode::link.isLinked()){ + queue.remove(*this); + } + } + + Queue& queue; + }; + + Queue queue; +}; + +template +class ProducerConsumerQueue { + // ProducerConsumerQueue is an async FIFO queue. + +public: + void push(T v) { + // Push an existing value onto the queue. + + if (!waiters.empty()) { + // We have at least one waiter, give the value to the oldest. + KJ_IASSERT(values.empty()); + + // Fulfill the first waiter and return without store our value. + waiters.fulfill(kj::mv(v)); + } else { + // We don't have any waiters, store the value. + values.push_front(kj::mv(v)); + } + } + + void rejectAll(Exception e) { + // Reject all waiters with a given exception. + + while (!waiters.empty()) { + auto newE = Exception(e); + waiters.reject(kj::mv(newE)); + } + } + + Promise pop() { + // Eventually pop a value from the queue. + // Note that if your sinks lag your sources, the promise will always be ready. + + if (!values.empty()) { + // We have at least one value, get the oldest. + KJ_IASSERT(waiters.empty()); + + auto value = kj::mv(values.back()); + values.pop_back(); + return kj::mv(value); + } else { + // We don't have any values, add ourselves to the waiting queue. + return waiters.wait(); + } + } + +private: + std::list values; + WaiterQueue waiters; +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/async-test.c++ b/c++/src/kj/async-test.c++ index 8a0f3d31fb..3ac6024cae 100644 --- a/c++/src/kj/async-test.c++ +++ b/c++/src/kj/async-test.c++ @@ -22,6 +22,17 @@ #include "async.h" #include "debug.h" #include +#include "mutex.h" +#include "thread.h" + +#if !KJ_USE_FIBERS && !_WIN32 +#include +#endif + +#if KJ_USE_FIBERS && __linux__ +#include +#include +#endif namespace kj { namespace { @@ -33,6 +44,36 @@ TEST(Async, GetFunctorStartAddress) { } #endif +#if KJ_USE_FIBERS +bool isLibcContextHandlingKnownBroken() { + // manylinux2014-x86's libc implements getcontext() to fail with ENOSYS. This is flagrantly + // against spec: getcontext() is not a syscall and is documented as never failing. Our configure + // script cannot detect this problem because it would require actually executing code to see + // what happens, which wouldn't work when cross-compiling. It would have been so much better if + // they had removed the symbol from libc entirely. But as a work-around, we will skip the tests + // when libc is broken. +#if __linux__ + static bool result = ([]() { + ucontext_t context; + if (getcontext(&context) < 0 && errno == ENOSYS) { + KJ_LOG(WARNING, + "This platform's libc is broken. Its getcontext() errors with ENOSYS. Fibers will not " + "work, so we'll skip the tests, but libkj was still built with fiber support, which " + "is broken. Please tell your libc maitnainer to remove the getcontext() function " + "entirely rather than provide an intentionally-broken version -- that way, the " + "configure script will detect that it should build libkj without fiber support."); + return true; + } else { + return false; + } + })(); + return result; +#else + return false; +#endif +} +#endif + TEST(Async, EvalVoid) { EventLoop loop; WaitScope waitScope(loop); @@ -186,9 +227,9 @@ TEST(Async, DeepChain) { // Create a ridiculous chain of promises. for (uint i = 0; i < 1000; i++) { - promise = evalLater(mvCapture(promise, [](Promise promise) { + promise = evalLater([promise=kj::mv(promise)]() mutable { return kj::mv(promise); - })); + }); } loop.run(); @@ -223,9 +264,9 @@ TEST(Async, DeepChain2) { // Create a ridiculous chain of promises. for (uint i = 0; i < 1000; i++) { - promise = evalLater(mvCapture(promise, [](Promise promise) { + promise = evalLater([promise=kj::mv(promise)]() mutable { return kj::mv(promise); - })); + }); } promise.wait(waitScope); @@ -262,9 +303,9 @@ TEST(Async, DeepChain3) { Promise makeChain2(uint i, Promise promise) { if (i > 0) { - return evalLater(mvCapture(promise, [i](Promise&& promise) -> Promise { + return evalLater([i, promise=kj::mv(promise)]() mutable -> Promise { return makeChain2(i - 1, kj::mv(promise)); - })); + }); } else { return kj::mv(promise); } @@ -364,12 +405,30 @@ TEST(Async, SeparateFulfillerDiscarded) { EventLoop loop; WaitScope waitScope(loop); - auto pair = newPromiseAndFulfiller(); + auto pair = newPromiseAndFulfiller(); pair.fulfiller = nullptr; - EXPECT_ANY_THROW(pair.promise.wait(waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "PromiseFulfiller was destroyed without fulfilling the promise", + pair.promise.wait(waitScope)); } +#if !KJ_NO_EXCEPTIONS +TEST(Async, SeparateFulfillerDiscardedDuringUnwind) { + EventLoop loop; + WaitScope waitScope(loop); + + auto pair = newPromiseAndFulfiller(); + kj::runCatchingExceptions([&]() { + auto fulfillerToDrop = kj::mv(pair.fulfiller); + kj::throwFatalException(KJ_EXCEPTION(FAILED, "test exception")); + }); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "test exception", pair.promise.wait(waitScope)); +} +#endif + TEST(Async, SeparateFulfillerMemoryLeak) { auto paf = kj::newPromiseAndFulfiller(); paf.fulfiller->fulfill(); @@ -379,57 +438,75 @@ TEST(Async, Ordering) { EventLoop loop; WaitScope waitScope(loop); + class ErrorHandlerImpl: public TaskSet::ErrorHandler { + public: + void taskFailed(kj::Exception&& exception) override { + KJ_FAIL_EXPECT(exception); + } + }; + int counter = 0; - Promise promises[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + ErrorHandlerImpl errorHandler; + kj::TaskSet tasks(errorHandler); - promises[1] = evalLater([&]() { + tasks.add(evalLater([&]() { EXPECT_EQ(0, counter++); { // Use a promise and fulfiller so that we can fulfill the promise after waiting on it in // order to induce depth-first scheduling. auto paf = kj::newPromiseAndFulfiller(); - promises[2] = paf.promise.then([&]() { + tasks.add(paf.promise.then([&]() { EXPECT_EQ(1, counter++); - }).eagerlyEvaluate(nullptr); + })); paf.fulfiller->fulfill(); } // .then() is scheduled breadth-first if the promise has already resolved, but depth-first // if the promise resolves later. - promises[3] = Promise(READY_NOW).then([&]() { + tasks.add(Promise(READY_NOW).then([&]() { EXPECT_EQ(4, counter++); }).then([&]() { EXPECT_EQ(5, counter++); - }).eagerlyEvaluate(nullptr); + tasks.add(kj::evalLast([&]() { + EXPECT_EQ(7, counter++); + tasks.add(kj::evalLater([&]() { + EXPECT_EQ(8, counter++); + })); + })); + })); { auto paf = kj::newPromiseAndFulfiller(); - promises[4] = paf.promise.then([&]() { + tasks.add(paf.promise.then([&]() { EXPECT_EQ(2, counter++); - }).eagerlyEvaluate(nullptr); + tasks.add(kj::evalLast([&]() { + EXPECT_EQ(9, counter++); + tasks.add(kj::evalLater([&]() { + EXPECT_EQ(10, counter++); + })); + })); + })); paf.fulfiller->fulfill(); } // evalLater() is like READY_NOW.then(). - promises[5] = evalLater([&]() { + tasks.add(evalLater([&]() { EXPECT_EQ(6, counter++); - }).eagerlyEvaluate(nullptr); - }).eagerlyEvaluate(nullptr); + })); + })); - promises[0] = evalLater([&]() { + tasks.add(evalLater([&]() { EXPECT_EQ(3, counter++); - // Making this a chain should NOT cause it to preempt promises[1]. (This was a problem at one - // point.) + // Making this a chain should NOT cause it to preempt the first promise. (This was a problem + // at one point.) return Promise(READY_NOW); - }).eagerlyEvaluate(nullptr); + })); - for (auto i: indices(promises)) { - kj::mv(promises[i]).wait(waitScope); - } + tasks.onEmpty().wait(waitScope); - EXPECT_EQ(7, counter); + EXPECT_EQ(11, counter); } TEST(Async, Fork) { @@ -440,14 +517,28 @@ TEST(Async, Fork) { auto fork = promise.fork(); +#if __GNUC__ && !__clang__ && __GNUC__ >= 7 +// GCC 7 decides the open-brace below is "misleadingly indented" as if it were guarded by the `for` +// that appears in the implementation of KJ_REQUIRE(). Shut up shut up shut up. +#pragma GCC diagnostic ignored "-Wmisleading-indentation" +#endif + KJ_ASSERT(!fork.hasBranches()); + { + auto cancelBranch = fork.addBranch(); + KJ_ASSERT(fork.hasBranches()); + } + KJ_ASSERT(!fork.hasBranches()); + auto branch1 = fork.addBranch().then([](int i) { EXPECT_EQ(123, i); return 456; }); + KJ_ASSERT(fork.hasBranches()); auto branch2 = fork.addBranch().then([](int i) { EXPECT_EQ(123, i); return 789; }); + KJ_ASSERT(fork.hasBranches()); { auto releaseFork = kj::mv(fork); @@ -490,6 +581,34 @@ TEST(Async, ForkRef) { EXPECT_EQ(789, branch2.wait(waitScope)); } +TEST(Async, ForkMaybeRef) { + EventLoop loop; + WaitScope waitScope(loop); + + Promise>> promise = evalLater([&]() { + return Maybe>(refcounted(123)); + }); + + auto fork = promise.fork(); + + auto branch1 = fork.addBranch().then([](Maybe>&& i) { + EXPECT_EQ(123, KJ_REQUIRE_NONNULL(i)->i); + return 456; + }); + auto branch2 = fork.addBranch().then([](Maybe>&& i) { + EXPECT_EQ(123, KJ_REQUIRE_NONNULL(i)->i); + return 789; + }); + + { + auto releaseFork = kj::mv(fork); + } + + EXPECT_EQ(456, branch1.wait(waitScope)); + EXPECT_EQ(789, branch2.wait(waitScope)); +} + + TEST(Async, Split) { EventLoop loop; WaitScope waitScope(loop); @@ -548,36 +667,167 @@ TEST(Async, ExclusiveJoin) { } TEST(Async, ArrayJoin) { + for (auto specificJoinPromisesOverload: { + +[](kj::Array> promises) { return joinPromises(kj::mv(promises)); }, + +[](kj::Array> promises) { return joinPromisesFailFast(kj::mv(promises)); } + }) { + EventLoop loop; + WaitScope waitScope(loop); + + auto builder = heapArrayBuilder>(3); + builder.add(123); + builder.add(456); + builder.add(789); + + Promise> promise = specificJoinPromisesOverload(builder.finish()); + + auto result = promise.wait(waitScope); + + ASSERT_EQ(3u, result.size()); + EXPECT_EQ(123, result[0]); + EXPECT_EQ(456, result[1]); + EXPECT_EQ(789, result[2]); + } +} + +TEST(Async, ArrayJoinVoid) { + for (auto specificJoinPromisesOverload: { + +[](kj::Array> promises) { return joinPromises(kj::mv(promises)); }, + +[](kj::Array> promises) { return joinPromisesFailFast(kj::mv(promises)); } + }) { + EventLoop loop; + WaitScope waitScope(loop); + + auto builder = heapArrayBuilder>(3); + builder.add(READY_NOW); + builder.add(READY_NOW); + builder.add(READY_NOW); + + Promise promise = specificJoinPromisesOverload(builder.finish()); + + promise.wait(waitScope); + } +} + +struct Pafs { + kj::Array> promises; + kj::Array>> fulfillers; +}; + +Pafs makeCompletionCountingPafs(uint count, uint& tasksCompleted) { + auto promisesBuilder = heapArrayBuilder>(count); + auto fulfillersBuilder = heapArrayBuilder>>(count); + + for (auto KJ_UNUSED value: zeroTo(count)) { + auto paf = newPromiseAndFulfiller(); + promisesBuilder.add(paf.promise.then([&tasksCompleted]() { + ++tasksCompleted; + })); + fulfillersBuilder.add(kj::mv(paf.fulfiller)); + } + + return { promisesBuilder.finish(), fulfillersBuilder.finish() }; +} + +TEST(Async, ArrayJoinException) { EventLoop loop; WaitScope waitScope(loop); - auto builder = heapArrayBuilder>(3); - builder.add(123); - builder.add(456); - builder.add(789); + uint tasksCompleted = 0; + auto pafs = makeCompletionCountingPafs(5, tasksCompleted); + auto& fulfillers = pafs.fulfillers; + Promise promise = joinPromises(kj::mv(pafs.promises)); - Promise> promise = joinPromises(builder.finish()); + { + uint i = 0; + KJ_EXPECT(tasksCompleted == 0); + + // Joined tasks are not completed early. + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + // Rejected tasks do not fail-fast. + fulfillers[i++]->reject(KJ_EXCEPTION(FAILED, "Test exception")); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == 0); + + // The final fulfillment makes the promise ready. + fulfillers[i++]->fulfill(); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("Test exception", promise.wait(waitScope)); + KJ_EXPECT(tasksCompleted == 4); + } +} - auto result = promise.wait(waitScope); +TEST(Async, ArrayJoinFailFastException) { + EventLoop loop; + WaitScope waitScope(loop); - ASSERT_EQ(3u, result.size()); - EXPECT_EQ(123, result[0]); - EXPECT_EQ(456, result[1]); - EXPECT_EQ(789, result[2]); + uint tasksCompleted = 0; + auto pafs = makeCompletionCountingPafs(5, tasksCompleted); + auto& fulfillers = pafs.fulfillers; + Promise promise = joinPromisesFailFast(kj::mv(pafs.promises)); + + { + uint i = 0; + KJ_EXPECT(tasksCompleted == 0); + + // Joined tasks are completed eagerly, not waiting until the join node is awaited. + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == i); + + fulfillers[i++]->fulfill(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(tasksCompleted == i); + + fulfillers[i++]->reject(KJ_EXCEPTION(FAILED, "Test exception")); + + // The first rejection makes the promise ready. + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("Test exception", promise.wait(waitScope)); + KJ_EXPECT(tasksCompleted == i - 1); + } } -TEST(Async, ArrayJoinVoid) { +TEST(Async, Canceler) { EventLoop loop; WaitScope waitScope(loop); + Canceler canceler; - auto builder = heapArrayBuilder>(3); - builder.add(READY_NOW); - builder.add(READY_NOW); - builder.add(READY_NOW); + auto never = canceler.wrap(kj::Promise(kj::NEVER_DONE)); + auto now = canceler.wrap(kj::Promise(kj::READY_NOW)); + auto neverI = canceler.wrap(kj::Promise(kj::NEVER_DONE).then([]() { return 123u; })); + auto nowI = canceler.wrap(kj::Promise(123u)); - Promise promise = joinPromises(builder.finish()); + KJ_EXPECT(!never.poll(waitScope)); + KJ_EXPECT(now.poll(waitScope)); + KJ_EXPECT(!neverI.poll(waitScope)); + KJ_EXPECT(nowI.poll(waitScope)); - promise.wait(waitScope); + canceler.cancel("foobar"); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("foobar", never.wait(waitScope)); + now.wait(waitScope); + KJ_EXPECT_THROW_MESSAGE("foobar", neverI.wait(waitScope)); + KJ_EXPECT(nowI.wait(waitScope) == 123u); +} + +TEST(Async, CancelerDoubleWrap) { + EventLoop loop; + WaitScope waitScope(loop); + + // This used to crash. + Canceler canceler; + auto promise = canceler.wrap(canceler.wrap(kj::Promise(kj::NEVER_DONE))); + canceler.cancel("whoops"); } class ErrorHandlerImpl: public TaskSet::ErrorHandler { @@ -608,7 +858,7 @@ TEST(Async, TaskSet) { EXPECT_EQ(2, counter++); })); - (void)evalLater([&]() { + auto ignore KJ_UNUSED = evalLater([&]() { KJ_FAIL_EXPECT("Promise without waiter shouldn't execute."); }); @@ -620,6 +870,153 @@ TEST(Async, TaskSet) { EXPECT_EQ(1u, errorHandler.exceptionCount); } +#if KJ_USE_FIBERS || !_WIN32 +// This test requires either fibers or pthreads in order to limit the stack size. Currently we +// don't have a version that works on Windows without fibers, so skip the test there. + +TEST(Async, LargeTaskSetDestruction) { + static constexpr size_t stackSize = 200 * 1024; + + static auto testBody = [] { + + ErrorHandlerImpl errorHandler; + TaskSet tasks(errorHandler); + + for (int i = 0; i < stackSize / sizeof(void*); i++) { + tasks.add(kj::NEVER_DONE); + } + }; + +#if KJ_USE_FIBERS + if (isLibcContextHandlingKnownBroken()) return; + + EventLoop loop; + WaitScope waitScope(loop); + + startFiber(stackSize, + [](WaitScope&) mutable { + testBody(); + }).wait(waitScope); + +#else + pthread_attr_t attr; + KJ_REQUIRE(0 == pthread_attr_init(&attr)); + KJ_DEFER(KJ_REQUIRE(0 == pthread_attr_destroy(&attr))); + + KJ_REQUIRE(0 == pthread_attr_setstacksize(&attr, stackSize)); + pthread_t thread; + KJ_REQUIRE(0 == pthread_create(&thread, &attr, [](void*) -> void* { + EventLoop loop; + WaitScope waitScope(loop); + testBody(); + return nullptr; + }, nullptr)); + KJ_REQUIRE(0 == pthread_join(thread, nullptr)); +#endif +} + +#endif // KJ_USE_FIBERS || !_WIN32 + +TEST(Async, TaskSet) { + EventLoop loop; + WaitScope waitScope(loop); + + bool destroyed = false; + + { + ErrorHandlerImpl errorHandler; + TaskSet tasks(errorHandler); + + tasks.add(kj::Promise(kj::NEVER_DONE) + .attach(kj::defer([&]() { + // During cancellation, append another task! + // It had better be canceled too! + tasks.add(kj::Promise(kj::READY_NOW) + .then([]() { KJ_FAIL_EXPECT("shouldn't get here"); }, + [](auto) { KJ_FAIL_EXPECT("shouldn't get here"); }) + .attach(kj::defer([&]() { + destroyed = true; + }))); + }))); + } + + KJ_EXPECT(destroyed); + + // Give a chance for the "shouldn't get here" asserts to execute, if the event is still running, + // which it shouldn't be. + waitScope.poll(); +} + +TEST(Async, TaskSetOnEmpty) { + EventLoop loop; + WaitScope waitScope(loop); + ErrorHandlerImpl errorHandler; + TaskSet tasks(errorHandler); + + KJ_EXPECT(tasks.isEmpty()); + + auto paf = newPromiseAndFulfiller(); + tasks.add(kj::mv(paf.promise)); + tasks.add(evalLater([]() {})); + + KJ_EXPECT(!tasks.isEmpty()); + + auto promise = tasks.onEmpty(); + KJ_EXPECT(!promise.poll(waitScope)); + KJ_EXPECT(!tasks.isEmpty()); + + paf.fulfiller->fulfill(); + KJ_ASSERT(promise.poll(waitScope)); + KJ_EXPECT(tasks.isEmpty()); + promise.wait(waitScope); +} + +KJ_TEST("TaskSet::clear()") { + EventLoop loop; + WaitScope waitScope(loop); + + class ClearOnError: public TaskSet::ErrorHandler { + public: + TaskSet* tasks; + void taskFailed(kj::Exception&& exception) override { + KJ_EXPECT(exception.getDescription().endsWith("example TaskSet failure")); + tasks->clear(); + } + }; + + ClearOnError errorHandler; + TaskSet tasks(errorHandler); + errorHandler.tasks = &tasks; + + auto doTest = [&](auto&& causeClear) { + KJ_EXPECT(tasks.isEmpty()); + + uint count = 0; + tasks.add(kj::Promise(kj::READY_NOW).attach(kj::defer([&]() { ++count; }))); + tasks.add(kj::Promise(kj::NEVER_DONE).attach(kj::defer([&]() { ++count; }))); + tasks.add(kj::Promise(kj::NEVER_DONE).attach(kj::defer([&]() { ++count; }))); + + auto onEmpty = tasks.onEmpty(); + KJ_EXPECT(!onEmpty.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(!tasks.isEmpty()); + + causeClear(); + KJ_EXPECT(tasks.isEmpty()); + onEmpty.wait(waitScope); + KJ_EXPECT(count == 3); + }; + + // Try it where we just call clear() directly. + doTest([&]() { tasks.clear(); }); + + // Try causing clear() inside taskFailed(), ensuring that this is permitted. + doTest([&]() { + tasks.add(KJ_EXCEPTION(FAILED, "example TaskSet failure")); + waitScope.poll(); + }); +} + class DestructorDetector { public: DestructorDetector(bool& setTrue): setTrue(setTrue) {} @@ -678,7 +1075,10 @@ TEST(Async, Detach) { bool ran2 = false; bool ran3 = false; - (void)evalLater([&]() { ran1 = true; }); // let returned promise be destroyed (canceled) + { + // let returned promise be destroyed (canceled) + auto ignore KJ_UNUSED = evalLater([&]() { ran1 = true; }); + } evalLater([&]() { ran2 = true; }).detach([](kj::Exception&&) { ADD_FAILURE(); }); evalLater([]() { KJ_FAIL_ASSERT("foo"){break;} }).detach([&](kj::Exception&& e) { ran3 = true; }); @@ -748,5 +1148,600 @@ TEST(Async, SetRunnable) { } } +TEST(Async, Poll) { + EventLoop loop; + WaitScope waitScope(loop); + + auto paf = newPromiseAndFulfiller(); + KJ_ASSERT(!paf.promise.poll(waitScope)); + paf.fulfiller->fulfill(); + KJ_ASSERT(paf.promise.poll(waitScope)); + paf.promise.wait(waitScope); +} + +KJ_TEST("Maximum turn count during wait scope poll is enforced") { + EventLoop loop; + WaitScope waitScope(loop); + ErrorHandlerImpl errorHandler; + TaskSet tasks(errorHandler); + + auto evaluated1 = false; + tasks.add(evalLater([&]() { + evaluated1 = true; + })); + + auto evaluated2 = false; + tasks.add(evalLater([&]() { + evaluated2 = true; + })); + + auto evaluated3 = false; + tasks.add(evalLater([&]() { + evaluated3 = true; + })); + + uint count; + + // Check that only events up to a maximum are resolved: + count = waitScope.poll(2); + KJ_ASSERT(count == 2); + KJ_EXPECT(evaluated1); + KJ_EXPECT(evaluated2); + KJ_EXPECT(!evaluated3); + + // Get the last remaining event in the queue: + count = waitScope.poll(1); + KJ_ASSERT(count == 1); + KJ_EXPECT(evaluated3); + + // No more events: + count = waitScope.poll(1); + KJ_ASSERT(count == 0); +} + +KJ_TEST("exclusiveJoin both events complete simultaneously") { + // Previously, if both branches of an exclusiveJoin() completed simultaneously, then the parent + // event could be armed twice. This is an error, but the exact results of this error depend on + // the parent PromiseNode type. One case where it matters is ArrayJoinPromiseNode, which counts + // events and decides it is done when it has received exactly the number of events expected. + + EventLoop loop; + WaitScope waitScope(loop); + + auto builder = kj::heapArrayBuilder>(2); + builder.add(kj::Promise(123).exclusiveJoin(kj::Promise(456))); + builder.add(kj::NEVER_DONE); + auto joined = kj::joinPromises(builder.finish()); + + KJ_EXPECT(!joined.poll(waitScope)); +} + +#if KJ_USE_FIBERS +KJ_TEST("start a fiber") { + if (isLibcContextHandlingKnownBroken()) return; + + EventLoop loop; + WaitScope waitScope(loop); + + auto paf = newPromiseAndFulfiller(); + + Promise fiber = startFiber(65536, + [promise = kj::mv(paf.promise)](WaitScope& fiberScope) mutable { + int i = promise.wait(fiberScope); + KJ_EXPECT(i == 123); + return "foo"_kj; + }); + + KJ_EXPECT(!fiber.poll(waitScope)); + + paf.fulfiller->fulfill(123); + + KJ_ASSERT(fiber.poll(waitScope)); + KJ_EXPECT(fiber.wait(waitScope) == "foo"); +} + +KJ_TEST("fiber promise chaining") { + if (isLibcContextHandlingKnownBroken()) return; + + EventLoop loop; + WaitScope waitScope(loop); + + auto paf = newPromiseAndFulfiller(); + bool ran = false; + + Promise fiber = startFiber(65536, + [promise = kj::mv(paf.promise), &ran](WaitScope& fiberScope) mutable { + ran = true; + return kj::mv(promise); + }); + + KJ_EXPECT(!ran); + KJ_EXPECT(!fiber.poll(waitScope)); + KJ_EXPECT(ran); + + paf.fulfiller->fulfill(123); + + KJ_ASSERT(fiber.poll(waitScope)); + KJ_EXPECT(fiber.wait(waitScope) == 123); +} + +KJ_TEST("throw from a fiber") { + if (isLibcContextHandlingKnownBroken()) return; + + EventLoop loop; + WaitScope waitScope(loop); + + auto paf = newPromiseAndFulfiller(); + + Promise fiber = startFiber(65536, + [promise = kj::mv(paf.promise)](WaitScope& fiberScope) mutable { + promise.wait(fiberScope); + KJ_FAIL_EXPECT("wait() should have thrown"); + }); + + KJ_EXPECT(!fiber.poll(waitScope)); + + paf.fulfiller->reject(KJ_EXCEPTION(FAILED, "test exception")); + + KJ_ASSERT(fiber.poll(waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test exception", fiber.wait(waitScope)); +} + +#if !__MINGW32__ || __MINGW64__ +// This test fails on MinGW 32-bit builds due to a compiler bug with exceptions + fibers: +// https://sourceforge.net/p/mingw-w64/bugs/835/ +KJ_TEST("cancel a fiber") { + if (isLibcContextHandlingKnownBroken()) return; + + EventLoop loop; + WaitScope waitScope(loop); + + // When exceptions are disabled we can't wait() on a non-void promise that throws. + auto paf = newPromiseAndFulfiller(); + + bool exited = false; + bool canceled = false; + + { + Promise fiber = startFiber(65536, + [promise = kj::mv(paf.promise), &exited, &canceled](WaitScope& fiberScope) mutable { + KJ_DEFER(exited = true); + try { + promise.wait(fiberScope); + } catch (kj::CanceledException) { + canceled = true; + throw; + } + return "foo"_kj; + }); + + KJ_EXPECT(!fiber.poll(waitScope)); + KJ_EXPECT(!exited); + KJ_EXPECT(!canceled); + } + + KJ_EXPECT(exited); + KJ_EXPECT(canceled); +} +#endif + +KJ_TEST("fiber pool") { + if (isLibcContextHandlingKnownBroken()) return; + + EventLoop loop; + WaitScope waitScope(loop); + FiberPool pool(65536); + + int* i1_local = nullptr; + int* i2_local = nullptr; + + auto run = [&]() mutable { + auto paf1 = newPromiseAndFulfiller(); + auto paf2 = newPromiseAndFulfiller(); + + { + Promise fiber1 = pool.startFiber([&, promise = kj::mv(paf1.promise)](WaitScope& scope) mutable { + int i = promise.wait(scope); + KJ_EXPECT(i == 123); + if (i1_local == nullptr) { + i1_local = &i; + } else { +#if !KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // Verify that the stack variable is in the exact same spot as before. + // May not work under ASAN as the instrumentation to detect stack-use-after-return can + // change the address. + KJ_ASSERT(i1_local == &i); +#endif + } + return i; + }); + { + Promise fiber2 = pool.startFiber([&, promise = kj::mv(paf2.promise)](WaitScope& scope) mutable { + int i = promise.wait(scope); + KJ_EXPECT(i == 456); + if (i2_local == nullptr) { + i2_local = &i; + } else { +#if !KJ_HAS_COMPILER_FEATURE(address_sanitizer) + KJ_ASSERT(i2_local == &i); +#endif + } + return i; + }); + + KJ_EXPECT(!fiber1.poll(waitScope)); + KJ_EXPECT(!fiber2.poll(waitScope)); + + KJ_EXPECT(pool.getFreelistSize() == 0); + + paf2.fulfiller->fulfill(456); + + KJ_EXPECT(!fiber1.poll(waitScope)); + KJ_ASSERT(fiber2.poll(waitScope)); + KJ_EXPECT(fiber2.wait(waitScope) == 456); + + KJ_EXPECT(pool.getFreelistSize() == 1); + } + + paf1.fulfiller->fulfill(123); + + KJ_ASSERT(fiber1.poll(waitScope)); + KJ_EXPECT(fiber1.wait(waitScope) == 123); + + KJ_EXPECT(pool.getFreelistSize() == 2); + } + }; + run(); + KJ_ASSERT(i1_local != nullptr); + KJ_ASSERT(i2_local != nullptr); + // run the same thing and reuse the fibers + run(); +} + +bool onOurStack(char* p) { + // If p points less than 64k away from a random stack variable, then it must be on the same + // stack, since we never allocate stacks smaller than 64k. +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // The stack-use-after-return detection mechanism breaks our ability to check this, so don't. + return true; +#else + char c; + ptrdiff_t diff = p - &c; + return diff < 65536 && diff > -65536; +#endif +} + +bool notOnOurStack(char* p) { + // Opposite of onOurStack(), except returns true if the check can't be performed. +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // The stack-use-after-return detection mechanism breaks our ability to check this, so don't. + return true; +#else + return !onOurStack(p); +#endif +} + +KJ_TEST("fiber pool runSynchronously()") { + if (isLibcContextHandlingKnownBroken()) return; + + FiberPool pool(65536); + + { + char c; + KJ_EXPECT(onOurStack(&c)); // sanity check... + } + + char* ptr1 = nullptr; + char* ptr2 = nullptr; + + pool.runSynchronously([&]() { + char c; + ptr1 = &c; + }); + KJ_ASSERT(ptr1 != nullptr); + + pool.runSynchronously([&]() { + char c; + ptr2 = &c; + }); + KJ_ASSERT(ptr2 != nullptr); + +#if !KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // Should have used the same stack both times, so local var would be in the same place. + // Under ASAN, the stack-use-after-return detection correctly fires on this, so we skip the check. + KJ_EXPECT(ptr1 == ptr2); +#endif + + // Should have been on a different stack from the main stack. + KJ_EXPECT(notOnOurStack(ptr1)); + + KJ_EXPECT_THROW_MESSAGE("test exception", + pool.runSynchronously([&]() { KJ_FAIL_ASSERT("test exception"); })); +} + +KJ_TEST("fiber pool limit") { + if (isLibcContextHandlingKnownBroken()) return; + + FiberPool pool(65536); + + pool.setMaxFreelist(1); + + kj::MutexGuarded state; + + char* ptr1; + char* ptr2; + + // Run some code that uses two stacks in separate threads at the same time. + { + kj::Thread thread([&]() noexcept { + auto lock = state.lockExclusive(); + lock.wait([](uint val) { return val == 1; }); + + pool.runSynchronously([&]() { + char c; + ptr2 = &c; + + *lock = 2; + lock.wait([](uint val) { return val == 3; }); + }); + }); + + ([&]() noexcept { + auto lock = state.lockExclusive(); + + pool.runSynchronously([&]() { + char c; + ptr1 = &c; + + *lock = 1; + lock.wait([](uint val) { return val == 2; }); + }); + + *lock = 3; + })(); + } + + KJ_EXPECT(pool.getFreelistSize() == 1); + + // We expect that if we reuse a stack from the pool, it will be the last one that exited, which + // is the one from the thread. + pool.runSynchronously([&]() { + KJ_EXPECT(onOurStack(ptr2)); + KJ_EXPECT(notOnOurStack(ptr1)); + + KJ_EXPECT(pool.getFreelistSize() == 0); + }); + + KJ_EXPECT(pool.getFreelistSize() == 1); + + // Note that it would NOT work to try to allocate two stacks at the same time again and verify + // that the second stack doesn't match the previously-deleted stack, because there's a high + // likelihood that the new stack would be allocated in the same location. +} + +#if __GNUC__ >= 12 && !__clang__ +// The test below intentionally takes a pointer to a stack variable and stores it past the end +// of the function. This seems to trigger a warning in newer GCCs. +#pragma GCC diagnostic ignored "-Wdangling-pointer" +#endif + +KJ_TEST("run event loop on freelisted stacks") { + if (isLibcContextHandlingKnownBroken()) return; + + FiberPool pool(65536); + + class MockEventPort: public EventPort { + public: + bool wait() override { + char c; + waitStack = &c; + KJ_IF_MAYBE(f, fulfiller) { + f->get()->fulfill(); + fulfiller = nullptr; + } + return false; + } + bool poll() override { + char c; + pollStack = &c; + KJ_IF_MAYBE(f, fulfiller) { + f->get()->fulfill(); + fulfiller = nullptr; + } + return false; + } + + char* waitStack = nullptr; + char* pollStack = nullptr; + + kj::Maybe>> fulfiller; + }; + + MockEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + waitScope.runEventCallbacksOnStackPool(pool); + + { + auto paf = newPromiseAndFulfiller(); + port.fulfiller = kj::mv(paf.fulfiller); + + char* ptr1 = nullptr; + char* ptr2 = nullptr; + kj::evalLater([&]() { + char c; + ptr1 = &c; + return kj::mv(paf.promise); + }).then([&]() { + char c; + ptr2 = &c; + }).wait(waitScope); + + KJ_EXPECT(ptr1 != nullptr); + KJ_EXPECT(ptr2 != nullptr); + KJ_EXPECT(port.waitStack != nullptr); + KJ_EXPECT(port.pollStack == nullptr); + + // The event callbacks should have run on a different stack, but the wait should have been on + // the main stack. + KJ_EXPECT(notOnOurStack(ptr1)); + KJ_EXPECT(notOnOurStack(ptr2)); + KJ_EXPECT(onOurStack(port.waitStack)); + + pool.runSynchronously([&]() { + // This should run on the same stack where the event callbacks ran. + KJ_EXPECT(onOurStack(ptr1)); + KJ_EXPECT(onOurStack(ptr2)); + KJ_EXPECT(notOnOurStack(port.waitStack)); + }); + } + + port.waitStack = nullptr; + port.pollStack = nullptr; + + // Now try poll() instead of wait(). Note that since poll() doesn't block, we let it run on the + // event stack. + { + auto paf = newPromiseAndFulfiller(); + port.fulfiller = kj::mv(paf.fulfiller); + + char* ptr1 = nullptr; + char* ptr2 = nullptr; + auto promise = kj::evalLater([&]() { + char c; + ptr1 = &c; + return kj::mv(paf.promise); + }).then([&]() { + char c; + ptr2 = &c; + }); + + KJ_EXPECT(promise.poll(waitScope)); + + KJ_EXPECT(ptr1 != nullptr); + KJ_EXPECT(ptr2 == nullptr); // didn't run because of lazy continuation evaluation + KJ_EXPECT(port.waitStack == nullptr); + KJ_EXPECT(port.pollStack != nullptr); + + // The event callback should have run on a different stack, and poll() should have run on + // a separate stack too. + KJ_EXPECT(notOnOurStack(ptr1)); + KJ_EXPECT(notOnOurStack(port.pollStack)); + + pool.runSynchronously([&]() { + // This should run on the same stack where the event callbacks ran. + KJ_EXPECT(onOurStack(ptr1)); + KJ_EXPECT(onOurStack(port.pollStack)); + }); + } +} +#endif + +KJ_TEST("retryOnDisconnect") { + EventLoop loop; + WaitScope waitScope(loop); + + { + uint i = 0; + auto promise = retryOnDisconnect([&]() -> Promise { + i++; + return 123; + }); + KJ_EXPECT(i == 0); + KJ_EXPECT(promise.wait(waitScope) == 123); + KJ_EXPECT(i == 1); + } + + { + uint i = 0; + auto promise = retryOnDisconnect([&]() -> Promise { + if (i++ == 0) { + return KJ_EXCEPTION(DISCONNECTED, "test disconnect"); + } else { + return 123; + } + }); + KJ_EXPECT(i == 0); + KJ_EXPECT(promise.wait(waitScope) == 123); + KJ_EXPECT(i == 2); + } + + + { + uint i = 0; + auto promise = retryOnDisconnect([&]() -> Promise { + if (i++ <= 1) { + return KJ_EXCEPTION(DISCONNECTED, "test disconnect", i); + } else { + return 123; + } + }); + KJ_EXPECT(i == 0); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test disconnect; i = 2", + promise.ignoreResult().wait(waitScope)); + KJ_EXPECT(i == 2); + } + + { + // Test passing a reference to a function. + struct Func { + uint i = 0; + Promise operator()() { + if (i++ == 0) { + return KJ_EXCEPTION(DISCONNECTED, "test disconnect"); + } else { + return 123; + } + } + }; + Func func; + + auto promise = retryOnDisconnect(func); + KJ_EXPECT(func.i == 0); + KJ_EXPECT(promise.wait(waitScope) == 123); + KJ_EXPECT(func.i == 2); + } +} + +#if (__GLIBC__ == 2 && __GLIBC_MINOR__ <= 17) || (__MINGW32__ && !__MINGW64__) +// manylinux2014-x86 doesn't seem to respect `alignas(16)`. I am guessing this is a glibc issue +// but I don't really know. It uses glibc 2.17, so testing for that and skipping the test makes +// CI work. +// +// MinGW 32-bit also mysteriously fails this test but I am not going to spend time figuring out +// why. +#else +KJ_TEST("capture weird alignment in continuation") { + struct alignas(16) WeirdAlign { + ~WeirdAlign() { + KJ_EXPECT(reinterpret_cast(this) % 16 == 0); + } + int i; + }; + + EventLoop loop; + WaitScope waitScope(loop); + + kj::Promise p = kj::READY_NOW; + + WeirdAlign value = { 123 }; + WeirdAlign value2 = { 456 }; + auto p2 = p.then([value, value2]() -> WeirdAlign { + return { value.i + value2.i }; + }); + + KJ_EXPECT(p2.wait(waitScope).i == 579); +} +#endif + +KJ_TEST("constPromise") { + EventLoop loop; + WaitScope waitScope(loop); + + Promise p = constPromise(); + int i = p.wait(waitScope); + KJ_EXPECT(i == 123); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/async-unix-test.c++ b/c++/src/kj/async-unix-test.c++ index 3d78227548..64190c430f 100644 --- a/c++/src/kj/async-unix-test.c++ +++ b/c++/src/kj/async-unix-test.c++ @@ -34,6 +34,25 @@ #include #include #include +#include +#include +#include +#include +#include "mutex.h" + +#if KJ_USE_EPOLL +#include +#endif + +#if KJ_USE_KQUEUE +#include +#endif + +#if __BIONIC__ +// Android's Bionic defines SIGRTMIN but using it in sigaddset() throws EINVAL, which means we +// definitely can't actually use RT signals. +#undef SIGRTMIN +#endif namespace kj { namespace { @@ -50,17 +69,83 @@ inline void delay() { usleep(10000); } void captureSignals() { static bool captured = false; if (!captured) { - captured = true; - // We use SIGIO and SIGURG as our test signals because they're two signals that we can be // reasonably confident won't otherwise be delivered to any KJ or Cap'n Proto test. We can't // use SIGUSR1 because it is reserved by UnixEventPort and SIGUSR2 is used by Valgrind on OSX. UnixEventPort::captureSignal(SIGURG); UnixEventPort::captureSignal(SIGIO); + +#ifdef SIGRTMIN + UnixEventPort::captureSignal(SIGRTMIN); +#endif + + UnixEventPort::captureChildExit(); + + captured = true; } } +#if KJ_USE_EPOLL +bool qemuBugTestSignalHandlerRan = false; +void qemuBugTestSignalHandler(int, siginfo_t* siginfo, void*) { + qemuBugTestSignalHandlerRan = true; +} + +bool checkForQemuEpollPwaitBug() { + // Under qemu-user, when a signal is delivered during epoll_pwait(), the signal successfully + // interrupts the wait, but the correct signal handler is not run. This ruins all our tests so + // we check for it and skip tests in this case. This does imply UnixEventPort won't be able to + // handle signals correctly under qemu-user. + + sigset_t mask; + sigset_t origMask; + KJ_SYSCALL(sigemptyset(&mask)); + KJ_SYSCALL(sigaddset(&mask, SIGURG)); + KJ_SYSCALL(pthread_sigmask(SIG_BLOCK, &mask, &origMask)); + KJ_DEFER(KJ_SYSCALL(pthread_sigmask(SIG_SETMASK, &origMask, nullptr))); + + struct sigaction action; + memset(&action, 0, sizeof(action)); + action.sa_sigaction = &qemuBugTestSignalHandler; + action.sa_flags = SA_SIGINFO; + + KJ_SYSCALL(sigfillset(&action.sa_mask)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGBUS)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGFPE)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGILL)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGSEGV)); + + KJ_SYSCALL(sigaction(SIGURG, &action, nullptr)); + + int efd; + KJ_SYSCALL(efd = epoll_create1(EPOLL_CLOEXEC)); + KJ_DEFER(close(efd)); + + kill(getpid(), SIGURG); + KJ_ASSERT(!qemuBugTestSignalHandlerRan); + + struct epoll_event event; + int n = epoll_pwait(efd, &event, 1, -1, &origMask); + KJ_ASSERT(n < 0); + KJ_ASSERT(errno == EINTR); + +#if !__aarch64__ + // qemu-user should only be used to execute aarch64 binaries so we should'nt see this bug + // elsewhere! + KJ_ASSERT(qemuBugTestSignalHandlerRan); +#endif + + return !qemuBugTestSignalHandlerRan; +} + +const bool BROKEN_QEMU = checkForQemuEpollPwaitBug(); +#else +const bool BROKEN_QEMU = false; +#endif + TEST(AsyncUnixTest, Signals) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); @@ -84,7 +169,9 @@ TEST(AsyncUnixTest, SignalWithValue) { // // Also, this test fails on Linux on mipsel. si_value comes back as zero. No one with a mips // machine wants to debug the problem but they demand a patch fixing it, so we disable the test. - // Sad. https://github.com/sandstorm-io/capnproto/issues/204 + // Sad. https://github.com/capnproto/capnproto/issues/204 + + if (BROKEN_QEMU) return; captureSignals(); UnixEventPort port; @@ -94,7 +181,14 @@ TEST(AsyncUnixTest, SignalWithValue) { union sigval value; memset(&value, 0, sizeof(value)); value.sival_int = 123; - sigqueue(getpid(), SIGURG, value); + KJ_SYSCALL_HANDLE_ERRORS(sigqueue(getpid(), SIGURG, value)) { + case ENOSYS: + // sigqueue() not supported. Maybe running on WSL. + KJ_LOG(WARNING, "sigqueue() is not implemented by your system; skipping test"); + return; + default: + KJ_FAIL_SYSCALL("sigqueue(getpid(), SIGURG, value)", error); + } siginfo_t info = port.onSignal(SIGURG).wait(waitScope); EXPECT_EQ(SIGURG, info.si_signo); @@ -112,7 +206,9 @@ TEST(AsyncUnixTest, SignalWithPointerValue) { // // Also, this test fails on Linux on mipsel. si_value comes back as zero. No one with a mips // machine wants to debug the problem but they demand a patch fixing it, so we disable the test. - // Sad. https://github.com/sandstorm-io/capnproto/issues/204 + // Sad. https://github.com/capnproto/capnproto/issues/204 + + if (BROKEN_QEMU) return; captureSignals(); UnixEventPort port; @@ -122,7 +218,14 @@ TEST(AsyncUnixTest, SignalWithPointerValue) { union sigval value; memset(&value, 0, sizeof(value)); value.sival_ptr = &port; - sigqueue(getpid(), SIGURG, value); + KJ_SYSCALL_HANDLE_ERRORS(sigqueue(getpid(), SIGURG, value)) { + case ENOSYS: + // sigqueue() not supported. Maybe running on WSL. + KJ_LOG(WARNING, "sigqueue() is not implemented by your system; skipping test"); + return; + default: + KJ_FAIL_SYSCALL("sigqueue(getpid(), SIGURG, value)", error); + } siginfo_t info = port.onSignal(SIGURG).wait(waitScope); EXPECT_EQ(SIGURG, info.si_signo); @@ -132,6 +235,8 @@ TEST(AsyncUnixTest, SignalWithPointerValue) { #endif TEST(AsyncUnixTest, SignalsMultiListen) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); @@ -156,6 +261,8 @@ TEST(AsyncUnixTest, SignalsMultiListen) { // platform I'm assuming it's a Cygwin bug. TEST(AsyncUnixTest, SignalsMultiReceive) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); @@ -176,16 +283,24 @@ TEST(AsyncUnixTest, SignalsMultiReceive) { #endif // !__CYGWIN32__ TEST(AsyncUnixTest, SignalsAsync) { + if (BROKEN_QEMU) return; + captureSignals(); UnixEventPort port; EventLoop loop(port); WaitScope waitScope(loop); // Arrange for a signal to be sent from another thread. - pthread_t mainThread = pthread_self(); + pthread_t mainThread KJ_UNUSED = pthread_self(); Thread thread([&]() { delay(); +#if __APPLE__ && KJ_USE_KQUEUE + // MacOS kqueue only receives process-level signals and there's nothing much we can do about + // that. + kill(getpid(), SIGURG); +#else pthread_kill(mainThread, SIGURG); +#endif }); siginfo_t info = port.onSignal(SIGURG).wait(waitScope); @@ -336,6 +451,32 @@ TEST(AsyncUnixTest, ReadObserverMultiReceive) { promise2.wait(waitScope); } +TEST(AsyncUnixTest, ReadObserverAndSignals) { + // Get FD events while also waiting on a signal. This specifically exercises epoll_pwait() for + // FD events on Linux. + + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + auto signalPromise = port.onSignal(SIGIO); + + int pipefds[2]; + KJ_SYSCALL(pipe(pipefds)); + kj::AutoCloseFd infd(pipefds[0]), outfd(pipefds[1]); + + UnixEventPort::FdObserver observer(port, infd, UnixEventPort::FdObserver::OBSERVE_READ); + + KJ_SYSCALL(write(outfd, "foo", 3)); + + observer.whenBecomesReadable().wait(waitScope); + + KJ_EXPECT(!signalPromise.poll(waitScope)) + kill(getpid(), SIGIO); + KJ_EXPECT(signalPromise.poll(waitScope)) +} + TEST(AsyncUnixTest, ReadObserverAsync) { captureSignals(); UnixEventPort port; @@ -458,8 +599,9 @@ TEST(AsyncUnixTest, WriteObserver) { EXPECT_TRUE(writable); } -#if !__APPLE__ -// Disabled on macOS due to https://github.com/sandstorm-io/capnproto/issues/374. +#if !__APPLE__ && !(KJ_USE_KQUEUE && !defined(EVFILT_EXCEPT)) +// Disabled on macOS due to https://github.com/capnproto/capnproto/issues/374. +// Disabled on kqueue systems that lack EVFILT_EXCEPT because it doesn't work there. TEST(AsyncUnixTest, UrgentObserver) { // Verify that FdObserver correctly detects availability of out-of-band data. // Availability of out-of-band data is implementation-specific. @@ -484,6 +626,14 @@ TEST(AsyncUnixTest, UrgentObserver) { KJ_SYSCALL(getsockname(serverFd, reinterpret_cast(&saddr), &saddrLen)); KJ_SYSCALL(listen(serverFd, 1)); + // Create a pipe that we'll use to signal if MSG_OOB return EINVAL. + int failpipe[2]; + KJ_SYSCALL(pipe(failpipe)); + KJ_DEFER({ + close(failpipe[0]); + close(failpipe[1]); + }); + // Accept one connection, send in-band and OOB byte, wait for a quit message Thread thread([&]() { int tmpFd; @@ -499,7 +649,14 @@ TEST(AsyncUnixTest, UrgentObserver) { c = 'i'; KJ_SYSCALL(send(clientFd, &c, 1, 0)); c = 'o'; - KJ_SYSCALL(send(clientFd, &c, 1, MSG_OOB)); + KJ_SYSCALL_HANDLE_ERRORS(send(clientFd, &c, 1, MSG_OOB)) { + case EINVAL: + // Looks like MSG_OOB is not supported. (This is the case e.g. on WSL.) + KJ_SYSCALL(write(failpipe[1], &c, 1)); + break; + default: + KJ_FAIL_SYSCALL("send(..., MSG_OOB)", error); + } KJ_SYSCALL(recv(clientFd, &c, 1, 0)); EXPECT_EQ('q', c); @@ -512,24 +669,32 @@ TEST(AsyncUnixTest, UrgentObserver) { UnixEventPort::FdObserver observer(port, clientFd, UnixEventPort::FdObserver::OBSERVE_READ | UnixEventPort::FdObserver::OBSERVE_URGENT); + UnixEventPort::FdObserver failObserver(port, failpipe[0], + UnixEventPort::FdObserver::OBSERVE_READ | UnixEventPort::FdObserver::OBSERVE_URGENT); - observer.whenUrgentDataAvailable().wait(waitScope); + auto promise = observer.whenUrgentDataAvailable().then([]() { return true; }); + auto failPromise = failObserver.whenBecomesReadable().then([]() { return false; }); + bool oobSupported = promise.exclusiveJoin(kj::mv(failPromise)).wait(waitScope); + if (oobSupported) { #if __CYGWIN__ - // On Cygwin, reading the urgent byte first causes the subsequent regular read to block until - // such a time as the connection closes -- and then the byte is successfully returned. This - // seems to be a cygwin bug. - KJ_SYSCALL(recv(clientFd, &c, 1, 0)); - EXPECT_EQ('i', c); - KJ_SYSCALL(recv(clientFd, &c, 1, MSG_OOB)); - EXPECT_EQ('o', c); + // On Cygwin, reading the urgent byte first causes the subsequent regular read to block until + // such a time as the connection closes -- and then the byte is successfully returned. This + // seems to be a cygwin bug. + KJ_SYSCALL(recv(clientFd, &c, 1, 0)); + EXPECT_EQ('i', c); + KJ_SYSCALL(recv(clientFd, &c, 1, MSG_OOB)); + EXPECT_EQ('o', c); #else - // Attempt to read the urgent byte prior to reading the in-band byte. - KJ_SYSCALL(recv(clientFd, &c, 1, MSG_OOB)); - EXPECT_EQ('o', c); - KJ_SYSCALL(recv(clientFd, &c, 1, 0)); - EXPECT_EQ('i', c); + // Attempt to read the urgent byte prior to reading the in-band byte. + KJ_SYSCALL(recv(clientFd, &c, 1, MSG_OOB)); + EXPECT_EQ('o', c); + KJ_SYSCALL(recv(clientFd, &c, 1, 0)); + EXPECT_EQ('i', c); #endif + } else { + KJ_LOG(WARNING, "MSG_OOB doesn't seem to be supported on your platform."); + } // Allow server thread to let its clientFd go out of scope. c = 'q'; @@ -573,6 +738,50 @@ TEST(AsyncUnixTest, SteadyTimers) { } } +bool dummySignalHandlerCalled = false; +void dummySignalHandler(int) { + dummySignalHandlerCalled = true; +} + +TEST(AsyncUnixTest, InterruptedTimer) { + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + +#if __linux__ + // Linux timeslices are 1ms. + constexpr auto OS_SLOWNESS_FACTOR = 1; +#else + // OSX timeslices are 10ms, so we need longer timeouts to avoid flakiness. + // To be safe we'll assume other OS's are similar. + constexpr auto OS_SLOWNESS_FACTOR = 10; +#endif + + // Schedule a timer event in 100ms. + auto& timer = port.getTimer(); + auto start = timer.now(); + constexpr auto timeout = 100 * MILLISECONDS * OS_SLOWNESS_FACTOR; + + // Arrange SIGALRM to be delivered in 50ms, handled in an empty signal handler. This will cause + // our wait to be interrupted with EINTR. We should nevertheless continue waiting for the right + // amount of time. + dummySignalHandlerCalled = false; + if (signal(SIGALRM, &dummySignalHandler) == SIG_ERR) { + KJ_FAIL_SYSCALL("signal(SIGALRM)", errno); + } + struct itimerval itv; + memset(&itv, 0, sizeof(itv)); + itv.it_value.tv_usec = 50000 * OS_SLOWNESS_FACTOR; // signal after 50ms + setitimer(ITIMER_REAL, &itv, nullptr); + + timer.afterDelay(timeout).wait(waitScope); + + KJ_EXPECT(dummySignalHandlerCalled); + KJ_EXPECT(timer.now() - start >= timeout); + KJ_EXPECT(timer.now() - start <= timeout + (timeout / 5)); // allow 20ms error +} + TEST(AsyncUnixTest, Wake) { captureSignals(); UnixEventPort port; @@ -592,15 +801,323 @@ TEST(AsyncUnixTest, Wake) { EXPECT_FALSE(port.wait()); } - bool woken = false; - Thread thread([&]() { + // Test wake() when already wait()ing. + { + Thread thread([&]() { + delay(); + port.wake(); + }); + + EXPECT_TRUE(port.wait()); + } + + // Test wait() after wake() already happened. + { + Thread thread([&]() { + port.wake(); + }); + delay(); - woken = true; - port.wake(); - }); + EXPECT_TRUE(port.wait()); + } - EXPECT_TRUE(port.wait()); + // Test wake() during poll() busy loop. + { + Thread thread([&]() { + delay(); + port.wake(); + }); + + EXPECT_FALSE(port.poll()); + while (!port.poll()) {} + } + + // Test poll() when wake() already delivered. + { + EXPECT_FALSE(port.poll()); + + Thread thread([&]() { + port.wake(); + }); + + do { + delay(); + } while (!port.poll()); + } +} + +int exitCodeForSignal = 0; +[[noreturn]] void exitSignalHandler(int) { + _exit(exitCodeForSignal); +} + +struct TestChild { + kj::Maybe pid; + kj::Promise promise = nullptr; + + TestChild(UnixEventPort& port, int exitCode) { + pid_t p; + KJ_SYSCALL(p = fork()); + if (p == 0) { + // Arrange for SIGTERM to cause the process to exit normally. + exitCodeForSignal = exitCode; + signal(SIGTERM, &exitSignalHandler); + sigset_t sigs; + sigemptyset(&sigs); + sigaddset(&sigs, SIGTERM); + pthread_sigmask(SIG_UNBLOCK, &sigs, nullptr); + + for (;;) pause(); + } + pid = p; + promise = port.onChildExit(pid); + } + + ~TestChild() noexcept(false) { + KJ_IF_MAYBE(p, pid) { + KJ_SYSCALL(::kill(*p, SIGKILL)) { return; } + int status; + KJ_SYSCALL(waitpid(*p, &status, 0)) { return; } + } + } + + void kill(int signo) { + KJ_SYSCALL(::kill(KJ_REQUIRE_NONNULL(pid), signo)); + } + + KJ_DISALLOW_COPY_AND_MOVE(TestChild); +}; + +TEST(AsyncUnixTest, ChildProcess) { + if (BROKEN_QEMU) return; + + captureSignals(); + + // Block SIGTERM so that we can carefully un-block it in children. + sigset_t sigs, oldsigs; + KJ_SYSCALL(sigemptyset(&sigs)); + KJ_SYSCALL(sigaddset(&sigs, SIGTERM)); + KJ_SYSCALL(pthread_sigmask(SIG_BLOCK, &sigs, &oldsigs)); + KJ_DEFER(KJ_SYSCALL(pthread_sigmask(SIG_SETMASK, &oldsigs, nullptr)) { break; }); + + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + TestChild child1(port, 123); + KJ_EXPECT(!child1.promise.poll(waitScope)); + + child1.kill(SIGTERM); + + { + int status = child1.promise.wait(waitScope); + KJ_EXPECT(WIFEXITED(status)); + KJ_EXPECT(WEXITSTATUS(status) == 123); + } + + TestChild child2(port, 234); + TestChild child3(port, 345); + + KJ_EXPECT(!child2.promise.poll(waitScope)); + KJ_EXPECT(!child3.promise.poll(waitScope)); + + child2.kill(SIGKILL); + + { + int status = child2.promise.wait(waitScope); + KJ_EXPECT(!WIFEXITED(status)); + KJ_EXPECT(WIFSIGNALED(status)); + KJ_EXPECT(WTERMSIG(status) == SIGKILL); + } + + KJ_EXPECT(!child3.promise.poll(waitScope)); + + // child3 will be killed and synchronously waited on the way out. +} + +#if !__CYGWIN__ +// TODO(someday): Figure out why whenWriteDisconnected() never resolves on Cygwin. + +KJ_TEST("UnixEventPort whenWriteDisconnected()") { + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + int fds_[2]; + KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds_)); + kj::AutoCloseFd fds[2] = { kj::AutoCloseFd(fds_[0]), kj::AutoCloseFd(fds_[1]) }; + + UnixEventPort::FdObserver observer(port, fds[0], UnixEventPort::FdObserver::OBSERVE_READ); + + // At one point, the poll()-based version of UnixEventPort had a bug where if some other event + // had completed previously, whenWriteDisconnected() would stop being watched for. So we watch + // for readability as well and check that that goes away first. + auto readablePromise = observer.whenBecomesReadable(); + auto hupPromise = observer.whenWriteDisconnected(); + + KJ_EXPECT(!readablePromise.poll(waitScope)); + KJ_EXPECT(!hupPromise.poll(waitScope)); + + KJ_SYSCALL(write(fds[1], "foo", 3)); + + KJ_ASSERT(readablePromise.poll(waitScope)); + readablePromise.wait(waitScope); + + { + char junk[16]; + ssize_t n; + KJ_SYSCALL(n = read(fds[0], junk, 16)); + KJ_EXPECT(n == 3); + } + + KJ_EXPECT(!hupPromise.poll(waitScope)); + + fds[1] = nullptr; + KJ_ASSERT(hupPromise.poll(waitScope)); + hupPromise.wait(waitScope); +} + +KJ_TEST("UnixEventPort FdObserver(..., flags=0)::whenWriteDisconnected()") { + // Verifies that given `0' as a `flags' argument, + // FdObserver still observes whenWriteDisconnected(). + // + // This can be useful to watch disconnection on a blocking file descriptor. + // See discussion: https://github.com/capnproto/capnproto/issues/924 + + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + int pipefds[2]; + KJ_SYSCALL(pipe(pipefds)); + kj::AutoCloseFd infd(pipefds[0]), outfd(pipefds[1]); + + UnixEventPort::FdObserver observer(port, outfd, 0); + + auto hupPromise = observer.whenWriteDisconnected(); + + KJ_EXPECT(!hupPromise.poll(waitScope)); + + infd = nullptr; + KJ_ASSERT(hupPromise.poll(waitScope)); + hupPromise.wait(waitScope); +} + +#endif + +KJ_TEST("UnixEventPort poll for signals") { + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + auto promise1 = port.onSignal(SIGURG); + auto promise2 = port.onSignal(SIGIO); + + KJ_EXPECT(!promise1.poll(waitScope)); + KJ_EXPECT(!promise2.poll(waitScope)); + + KJ_SYSCALL(kill(getpid(), SIGURG)); + KJ_SYSCALL(kill(getpid(), SIGIO)); + port.wake(); + + KJ_EXPECT(port.poll()); + KJ_EXPECT(promise1.poll(waitScope)); + KJ_EXPECT(promise2.poll(waitScope)); + + promise1.wait(waitScope); + promise2.wait(waitScope); +} + +#if defined(SIGRTMIN) && !__CYGWIN__ && !__aarch64__ +// TODO(someday): Figure out why RT signals don't seem to work correctly on Cygwin. It looks like +// only the first signal is delivered, like how non-RT signals work. Is it possible Cygwin +// advertites RT signal support but doesn't actually implement them correctly? I can't find any +// information on the internet about this and TBH I don't care about Cygwin enough to dig in. +// TODO(someday): Figure out why RT signals don't work under qemu-user emulating aarch64 on +// Debian Buster. + +void testRtSignals(UnixEventPort& port, WaitScope& waitScope, bool doPoll) { + union sigval value; + memset(&value, 0, sizeof(value)); + + // Queue three copies of the signal upfront. + for (uint i = 0; i < 3; i++) { + value.sival_int = 123 + i; + KJ_SYSCALL(sigqueue(getpid(), SIGRTMIN, value)); + } + + // Now wait for them. + for (uint i = 0; i < 3; i++) { + auto promise = port.onSignal(SIGRTMIN); + if (doPoll) { + KJ_ASSERT(promise.poll(waitScope)); + } + auto info = promise.wait(waitScope); + KJ_EXPECT(info.si_value.sival_int == 123 + i); + } + + KJ_EXPECT(!port.onSignal(SIGRTMIN).poll(waitScope)); +} + +KJ_TEST("UnixEventPort can receive multiple queued instances of an RT signal") { + captureSignals(); + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + testRtSignals(port, waitScope, true); + + // Test again, but don't poll() the promises. This may test a different code path, if poll() and + // wait() are very different in how they read signals. (For the poll(2)-based implementation of + // UnixEventPort, they are indeed pretty different.) + testRtSignals(port, waitScope, false); } +#endif + +#if !(__APPLE__ && KJ_USE_KQUEUE) +KJ_TEST("UnixEventPort thread-specific signals") { + // Verify a signal directed to a thread is only received on the intended thread. + // + // MacOS kqueue only receives process-level signals and there's nothing much we can do about + // that, so this test won't work there. + + if (BROKEN_QEMU) return; + + captureSignals(); + + Vector> threads; + std::atomic readyCount(0); + std::atomic doneCount(0); + for (auto i KJ_UNUSED: kj::zeroTo(16)) { + threads.add(kj::heap([&]() noexcept { + UnixEventPort port; + EventLoop loop(port); + WaitScope waitScope(loop); + + readyCount.fetch_add(1, std::memory_order_relaxed); + port.onSignal(SIGIO).wait(waitScope); + doneCount.fetch_add(1, std::memory_order_relaxed); + })); + } + + do { + usleep(1000); + } while (readyCount.load(std::memory_order_relaxed) < 16); + + KJ_ASSERT(doneCount.load(std::memory_order_relaxed) == 0); + + uint count = 0; + for (uint i: {5, 14, 4, 6, 7, 11, 1, 3, 8, 0, 12, 9, 10, 15, 2, 13}) { + threads[i]->sendSignal(SIGIO); + threads[i] = nullptr; // wait for that one thread to exit + usleep(1000); + KJ_ASSERT(doneCount.load(std::memory_order_relaxed) == ++count); + } +} +#endif } // namespace } // namespace kj diff --git a/c++/src/kj/threadlocal-pthread-test.c++ b/c++/src/kj/async-unix-xthread-test.c++ similarity index 81% rename from c++/src/kj/threadlocal-pthread-test.c++ rename to c++/src/kj/async-unix-xthread-test.c++ index d4c270ea29..e57a8d84a0 100644 --- a/c++/src/kj/threadlocal-pthread-test.c++ +++ b/c++/src/kj/async-unix-xthread-test.c++ @@ -1,4 +1,4 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Copyright (c) 2019 Cloudflare, Inc. and contributors // Licensed under the MIT License: // // Permission is hereby granted, free of charge, to any person obtaining a copy @@ -20,6 +20,13 @@ // THE SOFTWARE. #if !_WIN32 -#define KJ_USE_PTHREAD_TLS 1 -#include "threadlocal-test.c++" -#endif + +#include "async-unix.h" + +#define KJ_XTHREAD_TEST_SETUP_LOOP \ + UnixEventPort port; \ + EventLoop loop(port); \ + WaitScope waitScope(loop) +#include "async-xthread-test.c++" + +#endif // !_WIN32 diff --git a/c++/src/kj/async-unix.c++ b/c++/src/kj/async-unix.c++ index 42fd11d50b..a8179ea5ae 100644 --- a/c++/src/kj/async-unix.c++ +++ b/c++/src/kj/async-unix.c++ @@ -28,28 +28,29 @@ #include #include #include -#include #include +#include +#include +#include #if KJ_USE_EPOLL -#include #include -#include #include +#elif KJ_USE_KQUEUE +#include +#include +#include +#if !__APPLE__ && !__OpenBSD__ +// MacOS and OpenBSD are missing this, which means we have to do ugly hacks instead on those. +#define KJ_HAS_SIGTIMEDWAIT 1 +#endif #else #include +#include #endif namespace kj { -// ======================================================================================= -// Timer code common to multiple implementations - -TimePoint UnixEventPort::readClock() { - return origin() + std::chrono::duration_cast( - std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS; -} - // ======================================================================================= // Signal code common to multiple implementations @@ -57,57 +58,255 @@ namespace { int reservedSignal = SIGUSR1; bool tooLateToSetReserved = false; +bool capturedChildExit = false; + +#if !KJ_USE_KQUEUE +bool threadClaimedChildExits = false; +#endif + +} // namespace + +#if KJ_USE_EPOLL + +namespace { + +KJ_THREADLOCAL_PTR(UnixEventPort) threadEventPort = nullptr; +// This is set to the current UnixEventPort just before epoll_pwait(), then back to null after it +// returns. + +} // namespace + +void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { + // Since this signal handler is *only* called during `epoll_pwait()`, we aren't subject to the + // usual signal-safety concerns. We can treat this more like a callback. So, we can just call + // gotSignal() directly, no biggy. + + // Note that, if somehow the signal hanlder is invoked when *not* running `epoll_pwait()`, then + // `threadEventPort` will be null. We silently ignore the signal in this case. This should never + // happen in normal execution, so you might argue we should assert-fail instead. However: + // - We obviously can't throw from here, so we'd have to crash instead. + // - The Cloudflare Workers runtime relies on this no-op behavior for a certain hack. The hack + // in question involves unblocking a signal from the signal mask and relying on it to interrupt + // certain blocking syscalls, causing them to fail with EINTR. The hack does not need the + // handler to do anything except return in this case. The hacky code makes sure to restore the + // signal mask before returning to the event loop. + + UnixEventPort* current = threadEventPort; + if (current != nullptr) { + current->gotSignal(*siginfo); + } +} + +#elif KJ_USE_KQUEUE + +#if !KJ_HAS_SIGTIMEDWAIT +KJ_THREADLOCAL_PTR(siginfo_t) threadCapture = nullptr; +#endif + +void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { +#if KJ_HAS_SIGTIMEDWAIT + // This is never called because we use sigtimedwait() to dequeue the signal while it is still + // blocked, without running the signal handler. However, if we don't register a handler at all, + // and the default behavior is SIG_IGN, then the signal will be discarded before sigtimedwait() + // can receive it. +#else + // When sigtimedwait() isn't available, we use sigsuspend() and wait for the siginfo_t to be + // delivered to the signal handler. + siginfo_t* capture = threadCapture; + if (capture != nullptr) { + *capture = *siginfo; + } +#endif +} + +#else + +namespace { struct SignalCapture { sigjmp_buf jumpTo; siginfo_t siginfo; + +#if __APPLE__ + sigset_t originalMask; + // The signal mask to be restored when jumping out of the signal handler. + // + // "But wait!" you say, "Isn't the whole point of siglongjmp() that it does this for you?" Well, + // yes, that is supposed to be the point. However, Apple implemented in wrong. On macOS, + // siglongjmp() uses sigprocmask() -- not pthread_sigmask() -- to restore the signal mask. + // Unfortunately, sigprocmask() on macOS affects threads other than the current thread. Arguably + // this is conformant: sigprocmask() is documented as having unspecified behavior in the presence + // of threads, and pthread_sigmask() must be used instead. However, this means siglongjmp() + // cannot be used in the presence of threads. + // + // We'll just have to restore the signal mask ourselves, rather than rely on siglongjmp()... + // + // ... but we ONLY do that on Apple systems, because it turns out, ironically, on Android, this + // hack breaks signal delivery. pthread_sigmask() vs. sigprocmask() is not the issue; we + // apparently MUST let siglongjmp() itself deal with the signal mask, otherwise various tests in + // async-unix-test.c++ end up hanging (I haven't gotten to the bottom of why). Note that on stock + // Linux, _either_ strategy works fine; this appears to be a problem with Android's Bionic libc. + // Since letting siglongjmp() do the work _seeems_ more "correct", we'll make it the default and + // only do something different on Apple platforms. +#define KJ_BROKEN_SIGLONGJMP 1 +#endif }; -#if !KJ_USE_EPOLL // on Linux we'll use signalfd KJ_THREADLOCAL_PTR(SignalCapture) threadCapture = nullptr; -void signalHandler(int, siginfo_t* siginfo, void*) { +} // namespace + +void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { SignalCapture* capture = threadCapture; if (capture != nullptr) { capture->siginfo = *siginfo; - siglongjmp(capture->jumpTo, 1); + +#if KJ_BROKEN_SIGLONGJMP + // See comments on SignalCapture::originalMask, above: We can't rely on siglongjmp() to restore + // the signal mask; we must do it ourselves using pthread_sigmask(). We pass false as the + // second parameter to siglongjmp() so that it skips changing the signal mask. This makes it + // equivalent to `longjmp()` on Linux or `_longjmp()` on BSD/macOS. See comments on + // SignalCapture::originalMask for explanation. + pthread_sigmask(SIG_SETMASK, &capture->originalMask, nullptr); + siglongjmp(capture->jumpTo, false); +#else + siglongjmp(capture->jumpTo, true); +#endif } } -#endif -void registerSignalHandler(int signum) { +#endif // !KJ_USE_EPOLL && !KJ_USE_KQUEUE + +void UnixEventPort::registerSignalHandler(int signum) { + KJ_REQUIRE(signum != SIGBUS && signum != SIGFPE && signum != SIGILL && signum != SIGSEGV, + "this signal is raised by erroneous code execution; you cannot capture it into the event " + "loop"); + tooLateToSetReserved = true; + // Block the signal from being delivered most of the time. We'll explicitly unblock it when we + // want to receive it. sigset_t mask; KJ_SYSCALL(sigemptyset(&mask)); KJ_SYSCALL(sigaddset(&mask, signum)); - KJ_SYSCALL(sigprocmask(SIG_BLOCK, &mask, nullptr)); + KJ_SYSCALL(pthread_sigmask(SIG_BLOCK, &mask, nullptr)); -#if !KJ_USE_EPOLL // on Linux we'll use signalfd + // Register the signal handler which should be invoked when we explicitly unblock the signal. struct sigaction action; memset(&action, 0, sizeof(action)); action.sa_sigaction = &signalHandler; - KJ_SYSCALL(sigfillset(&action.sa_mask)); action.sa_flags = SA_SIGINFO; + + // Set up the signal mask applied while the signal handler runs. We want to block all other + // signals from being raised during the handler, with the exception of the four "crash" signals, + // which realistically can't be blocked. + KJ_SYSCALL(sigfillset(&action.sa_mask)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGBUS)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGFPE)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGILL)); + KJ_SYSCALL(sigdelset(&action.sa_mask, SIGSEGV)); + KJ_SYSCALL(sigaction(signum, &action, nullptr)); -#endif } -void registerReservedSignal() { +#if !KJ_USE_EPOLL && !KJ_USE_KQUEUE && !KJ_USE_PIPE_FOR_WAKEUP +void UnixEventPort::registerReservedSignal() { registerSignalHandler(reservedSignal); +} +#endif - // We also disable SIGPIPE because users of UnixEventPort almost certainly don't want it. - while (signal(SIGPIPE, SIG_IGN) == SIG_ERR) { - int error = errno; - if (error != EINTR) { - KJ_FAIL_SYSCALL("signal(SIGPIPE, SIG_IGN)", error); +void UnixEventPort::ignoreSigpipe() { + // We disable SIGPIPE because users of UnixEventPort almost certainly don't want it. + // + // We've observed that when starting many threads at the same time, this can cause some + // contention on the kernel's signal handler table lock, so we try to run it only once. + static bool once KJ_UNUSED = []() { + while (signal(SIGPIPE, SIG_IGN) == SIG_ERR) { + int error = errno; + if (error != EINTR) { + KJ_FAIL_SYSCALL("signal(SIGPIPE, SIG_IGN)", error); + } + } + return true; + }(); +} + +#if !KJ_USE_KQUEUE // kqueue systems handle child processes differently + +struct UnixEventPort::ChildSet { + std::map waiters; + + void checkExits(); +}; + +class UnixEventPort::ChildExitPromiseAdapter { +public: + inline ChildExitPromiseAdapter(PromiseFulfiller& fulfiller, + ChildSet& childSet, Maybe& pidRef) + : childSet(childSet), + pid(KJ_REQUIRE_NONNULL(pidRef, + "`pid` must be non-null at the time `onChildExit()` is called")), + pidRef(pidRef), fulfiller(fulfiller) { + KJ_REQUIRE(childSet.waiters.insert(std::make_pair(pid, this)).second, + "already called onChildExit() for this pid"); + } + + ~ChildExitPromiseAdapter() noexcept(false) { + childSet.waiters.erase(pid); + } + + ChildSet& childSet; + pid_t pid; + Maybe& pidRef; + PromiseFulfiller& fulfiller; +}; + +void UnixEventPort::ChildSet::checkExits() { + for (;;) { + int status; + pid_t pid; + KJ_SYSCALL_HANDLE_ERRORS(pid = waitpid(-1, &status, WNOHANG)) { + case ECHILD: + return; + default: + KJ_FAIL_SYSCALL("waitpid()", error); + } + if (pid == 0) break; + + auto iter = waiters.find(pid); + if (iter != waiters.end()) { + iter->second->pidRef = nullptr; + iter->second->fulfiller.fulfill(kj::cp(status)); } } } -pthread_once_t registerReservedSignalOnce = PTHREAD_ONCE_INIT; +Promise UnixEventPort::onChildExit(Maybe& pid) { + KJ_REQUIRE(capturedChildExit, + "must call UnixEventPort::captureChildExit() to use onChildExit()."); -} // namespace + ChildSet* cs; + KJ_IF_MAYBE(c, childSet) { + cs = *c; + } else { + // In theory we should do an atomic compare-and-swap on threadClaimedChildExits, but this is + // for debug purposes only so it's not a big deal. + KJ_REQUIRE(!threadClaimedChildExits, + "only one UnixEvertPort per process may listen for child exits"); + threadClaimedChildExits = true; + + auto newChildSet = kj::heap(); + cs = newChildSet; + childSet = kj::mv(newChildSet); + } + + return kj::newAdaptedPromise(*cs, pid); +} + +void UnixEventPort::captureChildExit() { + captureSignal(SIGCHLD); + capturedChildExit = true; +} class UnixEventPort::SignalPromiseAdapter { public: @@ -151,9 +350,13 @@ public: }; Promise UnixEventPort::onSignal(int signum) { + KJ_REQUIRE(signum != SIGCHLD || !capturedChildExit, + "can't call onSigal(SIGCHLD) when kj::UnixEventPort::captureChildExit() has been called"); return newAdaptedPromise(*this, signum); } +#endif // !KJ_USE_KQUEUE + void UnixEventPort::captureSignal(int signum) { if (reservedSignal == SIGUSR1) { KJ_REQUIRE(signum != SIGUSR1, @@ -177,7 +380,17 @@ void UnixEventPort::setReservedSignal(int signum) { reservedSignal = signum; } +#if !KJ_USE_KQUEUE + void UnixEventPort::gotSignal(const siginfo_t& siginfo) { + // If onChildExit() has been called and this is SIGCHLD, check for child exits. + KJ_IF_MAYBE(cs, childSet) { + if (siginfo.si_signo == SIGCHLD) { + cs->get()->checkExits(); + return; + } + } + // Fire any events waiting on this signal. auto ptr = signalHead; while (ptr != nullptr) { @@ -190,39 +403,43 @@ void UnixEventPort::gotSignal(const siginfo_t& siginfo) { } } +#endif // !KJ_USE_KQUEUE + #if KJ_USE_EPOLL // ======================================================================================= // epoll FdObserver implementation UnixEventPort::UnixEventPort() - : timerImpl(readClock()), - epollFd(-1), - signalFd(-1), - eventFd(-1) { - pthread_once(®isterReservedSignalOnce, ®isterReservedSignal); + : clock(systemPreciseMonotonicClock()), + timerImpl(clock.now()) { + ignoreSigpipe(); int fd; KJ_SYSCALL(fd = epoll_create1(EPOLL_CLOEXEC)); epollFd = AutoCloseFd(fd); - KJ_SYSCALL(sigemptyset(&signalFdSigset)); - KJ_SYSCALL(fd = signalfd(-1, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC)); - signalFd = AutoCloseFd(fd); - KJ_SYSCALL(fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); eventFd = AutoCloseFd(fd); - struct epoll_event event; memset(&event, 0, sizeof(event)); event.events = EPOLLIN; event.data.u64 = 0; - KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, signalFd, &event)); - event.data.u64 = 1; KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event)); + + // Get the current signal mask, from which we'll compute the appropriate mask to pass to + // epoll_pwait() on each loop. (We explicitly memset to 0 first to make sure we can compare + // this against another mask with memcmp() for debug purposes.) + memset(&originalMask, 0, sizeof(originalMask)); + KJ_SYSCALL(sigprocmask(0, nullptr, &originalMask)); } -UnixEventPort::~UnixEventPort() noexcept(false) {} +UnixEventPort::~UnixEventPort() noexcept(false) { + if (childSet != nullptr) { + // We had claimed the exclusive right to call onChildExit(). Release that right. + threadClaimedChildExits = false; + } +} UnixEventPort::FdObserver::FdObserver(UnixEventPort& eventPort, int fd, uint flags) : eventPort(eventPort), fd(fd), flags(flags) { @@ -271,6 +488,13 @@ void UnixEventPort::FdObserver::fire(short events) { } } + if (events & (EPOLLHUP | EPOLLERR)) { + KJ_IF_MAYBE(f, hupFulfiller) { + f->get()->fulfill(); + hupFulfiller = nullptr; + } + } + if (events & EPOLLPRI) { KJ_IF_MAYBE(f, urgentFulfiller) { f->get()->fulfill(); @@ -304,15 +528,10 @@ Promise UnixEventPort::FdObserver::whenUrgentDataAvailable() { return kj::mv(paf.promise); } -bool UnixEventPort::wait() { - return doEpollWait( - timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue)) - .map([](uint64_t t) -> int { return t; }) - .orDefault(-1)); -} - -bool UnixEventPort::poll() { - return doEpollWait(0); +Promise UnixEventPort::FdObserver::whenWriteDisconnected() { + auto paf = newPromiseAndFulfiller(); + hupFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); } void UnixEventPort::wake() const { @@ -322,172 +541,611 @@ void UnixEventPort::wake() const { KJ_ASSERT(n < 0 || n == sizeof(one)); } -static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) { - // Unfortunately, siginfo_t is mostly a big union and the correct set of fields to fill in - // depends on the type of signal. OTOH, signalfd_siginfo is a flat struct that expands all - // siginfo_t's union fields out to be non-overlapping. We can't just copy all the fields over - // because of the unions; we have to carefully figure out which fields are appropriate to fill - // in for this signal. Ick. +bool UnixEventPort::wait() { +#ifdef KJ_DEBUG + // In debug mode, verify the current signal mask matches the original. + { + sigset_t currentMask; + memset(¤tMask, 0, sizeof(currentMask)); + KJ_SYSCALL(sigprocmask(0, nullptr, ¤tMask)); + if (memcmp(¤tMask, &originalMask, sizeof(currentMask)) != 0) { + kj::Vector changes; + for (int i = 0; i <= SIGRTMAX; i++) { + if (sigismember(¤tMask, i) && !sigismember(&originalMask, i)) { + changes.add(kj::str("signal #", i, " (", strsignal(i), ") was added")); + } else if (!sigismember(¤tMask, i) && sigismember(&originalMask, i)) { + changes.add(kj::str("signal #", i, " (", strsignal(i), ") was removed")); + } + } + + KJ_FAIL_REQUIRE( + "Signal mask has changed since UnixEventPort was constructed. You are required to " + "ensure that whenever control returns to the event loop, the signal mask is the same " + "as it was when UnixEventPort was created. In non-debug builds, this check is skipped, " + "and this situation may instead lead to unexpected results. In particular, while the " + "system is waiting for I/O events, the signal mask may be reverted to what it was at " + "construction time, ignoring your subsequent changes.", changes); + } + } +#endif - siginfo_t result; - memset(&result, 0, sizeof(result)); + int timeout = timerImpl.timeoutToNextEvent(clock.now(), MILLISECONDS, int(maxValue)) + .map([](uint64_t t) -> int { return t; }) + .orDefault(-1); - result.si_signo = siginfo.ssi_signo; - result.si_errno = siginfo.ssi_errno; - result.si_code = siginfo.ssi_code; + struct epoll_event events[16]; + int n; + if (signalHead != nullptr || childSet != nullptr) { + // We are interested in some signals. Use epoll_pwait(). + // + // Note: Once upon a time, we used signalfd for this. However, this turned out to be more + // trouble than it was worth. Some problems with signalfd: + // - It required opening an additional file descriptor per thread. + // - If the set of interesting signals changed, the signalfd would have to be updated before + // calling epoll_wait(), which was an extra syscall. + // - When a signal arrives, it requires extra syscalls to read the signal info from the + // signalfd, as well as code to translate from signalfd_siginfo to siginfo_t, which are + // different for some reason. + // - signalfd suffers from surprising lock contention during epoll_wait or when the signalfd's + // mask is updated in programs with many threads. Because the lock is a spinlock, this + // could consume exorbitant CPU. + // - When a signalfd is in an epoll, it will be flagged readable based on signals which are + // pending in the process/thread which called epoll_ctl_add() to register the signalfd. + // This is mostly fine for our usage, except that it breaks one useful case that otherwise + // works: many servers are designed to "daemonize" themselves by fork()ing and then having + // the parent process exit while the child thread lives on. In this case, if a UnixEventPort + // had been created before daemonizing, signal handling would be forever broken in the child. + + sigset_t waitMask = originalMask; + + // Unblock the signals we care about. + { + auto ptr = signalHead; + while (ptr != nullptr) { + KJ_SYSCALL(sigdelset(&waitMask, ptr->signum)); + ptr = ptr->next; + } + if (childSet != nullptr) { + KJ_SYSCALL(sigdelset(&waitMask, SIGCHLD)); + } + } - if (siginfo.ssi_code > 0) { - // Signal originated from the kernel. The structure of the siginfo depends primarily on the - // signal number. + threadEventPort = this; + n = epoll_pwait(epollFd, events, kj::size(events), timeout, &waitMask); + threadEventPort = nullptr; + } else { + // Not waiting on any signals. Regular epoll_wait() will be fine. + n = epoll_wait(epollFd, events, kj::size(events), timeout); + } - switch (siginfo.ssi_signo) { - case SIGCHLD: - result.si_pid = siginfo.ssi_pid; - result.si_uid = siginfo.ssi_uid; - result.si_status = siginfo.ssi_status; - result.si_utime = siginfo.ssi_utime; - result.si_stime = siginfo.ssi_stime; - break; + if (n < 0) { + int error = errno; + if (error == EINTR) { + // We received a singal. The signal handler may have queued an event to the event loop. Even + // if it didn't, we can't simply restart the epoll call because we need to recompute the + // timeout. Instead, we pretend epoll_wait() returned zero events. This will cause the event + // loop to spin once, decide it has nothing to do, recompute timeouts, then return to waiting. + n = 0; + } else { + KJ_FAIL_SYSCALL("epoll_pwait()", error); + } + } + + return processEpollEvents(events, n); +} + +bool UnixEventPort::processEpollEvents(struct epoll_event events[], int n) { + bool woken = false; + + for (int i = 0; i < n; i++) { + if (events[i].data.u64 == 0) { + // Someone called wake() from another thread. Consume the event. + uint64_t value; + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value))); + KJ_ASSERT(n < 0 || n == sizeof(value)); + + // We were woken. Need to return true. + woken = true; + } else { + FdObserver* observer = reinterpret_cast(events[i].data.ptr); + observer->fire(events[i].events); + } + } + + timerImpl.advanceTo(clock.now()); + + return woken; +} - case SIGILL: - case SIGFPE: - case SIGSEGV: - case SIGBUS: - case SIGTRAP: - result.si_addr = reinterpret_cast(static_cast(siginfo.ssi_addr)); -#ifdef si_trapno - result.si_trapno = siginfo.ssi_trapno; +bool UnixEventPort::poll() { + // Unfortunately, epoll_pwait() with a timeout of zero will never deliver actually deliver any + // pending signals. Therefore, we need a completely different approach to poll for signals. We + // might as well use regular epoll_wait() in this case, too, to save the kernel some effort. + + if (signalHead != nullptr || childSet != nullptr) { + // Use sigtimedwait() to poll for signals. + + // Construct a sigset of all signals we are interested in. + sigset_t sigset; + KJ_SYSCALL(sigemptyset(&sigset)); + uint count = 0; + + { + auto ptr = signalHead; + while (ptr != nullptr) { + KJ_SYSCALL(sigaddset(&sigset, ptr->signum)); + ++count; + ptr = ptr->next; + } + if (childSet != nullptr) { + KJ_SYSCALL(sigaddset(&sigset, SIGCHLD)); + ++count; + } + } + + // While that set is non-empty, poll for signals. + while (count > 0) { + struct timespec timeout; + timeout.tv_sec = 0; + timeout.tv_nsec = 0; + + siginfo_t siginfo; + int n; + KJ_NONBLOCKING_SYSCALL(n = sigtimedwait(&sigset, &siginfo, &timeout)); + if (n < 0) break; // EAGAIN: no signals in set are raised + + KJ_ASSERT(n == siginfo.si_signo); + gotSignal(siginfo); + + // Remove that signal from the set so we don't receive it again, but keep checking for others + // if there are any. + KJ_SYSCALL(sigdelset(&sigset, n)); + --count; + } + } + + struct epoll_event events[16]; + int n; + KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), 0)); + + return processEpollEvents(events, n); +} + +#elif KJ_USE_KQUEUE +// ======================================================================================= +// kqueue FdObserver implementation + +UnixEventPort::UnixEventPort() + : clock(systemPreciseMonotonicClock()), + timerImpl(clock.now()) { + ignoreSigpipe(); + + int fd; + KJ_SYSCALL(fd = kqueue()); + kqueueFd = AutoCloseFd(fd); + + // NetBSD has kqueue1() which can set CLOEXEC atomically, but FreeBSD, MacOS, and others don't + // have this... oh well. + KJ_SYSCALL(fcntl(kqueueFd, F_SETFD, FD_CLOEXEC)); + + // Register the EVFILT_USER event used by wake(). + struct kevent event; + EV_SET(&event, 0, EVFILT_USER, EV_ADD | EV_CLEAR, 0, 0, nullptr); + KJ_SYSCALL(kevent(kqueueFd, &event, 1, nullptr, 0, nullptr)); +} + +UnixEventPort::~UnixEventPort() noexcept(false) {} + +UnixEventPort::FdObserver::FdObserver(UnixEventPort& eventPort, int fd, uint flags) + : eventPort(eventPort), fd(fd), flags(flags) { + struct kevent events[3]; + int nevents = 0; + + if (flags & OBSERVE_URGENT) { +#ifdef EVFILT_EXCEPT + EV_SET(&events[nevents++], fd, EVFILT_EXCEPT, EV_ADD | EV_CLEAR, NOTE_OOB, 0, this); +#else + // TODO(someday): Can we support this without reverting to poll()? + // Related: https://sandstorm.io/news/2015-04-08-osx-security-bug + KJ_FAIL_ASSERT("kqueue() on this system doesn't support EVFILT_EXCEPT (for OBSERVE_URGENT). " + "If you really need to observe OOB events, compile KJ (and your application) with " + "-DKJ_USE_KQUEUE=0 to disable use of kqueue()."); #endif -#ifdef si_addr_lsb - // ssi_addr_lsb is defined as coming immediately after ssi_addr in the kernel headers but - // apparently the userspace headers were never updated. So we do a pointer hack. :( - result.si_addr_lsb = *reinterpret_cast(&siginfo.ssi_addr + 1); + } + if (flags & OBSERVE_READ) { + EV_SET(&events[nevents++], fd, EVFILT_READ, EV_ADD | EV_CLEAR, 0, 0, this); + } + if (flags & OBSERVE_WRITE) { + EV_SET(&events[nevents++], fd, EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, this); + } + + KJ_SYSCALL(kevent(eventPort.kqueueFd, events, nevents, nullptr, 0, nullptr)); +} + +UnixEventPort::FdObserver::~FdObserver() noexcept(false) { + struct kevent events[3]; + int nevents = 0; + + if (flags & OBSERVE_URGENT) { +#ifdef EVFILT_EXCEPT + EV_SET(&events[nevents++], fd, EVFILT_EXCEPT, EV_DELETE, 0, 0, nullptr); #endif + } + if (flags & OBSERVE_READ) { + EV_SET(&events[nevents++], fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + } + if ((flags & OBSERVE_WRITE) || hupFulfiller != nullptr) { + EV_SET(&events[nevents++], fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + } + + // TODO(perf): Should we delay unregistration of events until the next time kqueue() is invoked? + // We can't delay registrations since it could lead to missed events, but we could delay + // unregistration safely. However, we'd have to be very careful about the possibility that + // the same FD is re-registered later. + KJ_SYSCALL_HANDLE_ERRORS(kevent(eventPort.kqueueFd, events, nevents, nullptr, 0, nullptr)) { + case ENOENT: + // In the specific case of unnamed pipes, when read end of the pipe is destroyed, FreeBSD + // seems to unregister the events on the write end automatically. Subsequently trying to + // remove them then produces ENOENT. Let's ignore this. break; + default: + KJ_FAIL_SYSCALL("kevent(remove events)", error); + } +} - case SIGIO: - static_assert(SIGIO == SIGPOLL, "SIGIO != SIGPOLL?"); +void UnixEventPort::FdObserver::fire(struct kevent event) { + switch (event.filter) { + case EVFILT_READ: + if (event.flags & EV_EOF) { + atEnd = true; + } else { + atEnd = false; + } - // Note: Technically, code can arrange for SIGIO signals to be delivered with a signal number - // other than SIGIO. AFAICT there is no way for us to detect this in the siginfo. Luckily - // SIGIO is totally obsoleted by epoll so it shouldn't come up. + KJ_IF_MAYBE(f, readFulfiller) { + f->get()->fulfill(); + readFulfiller = nullptr; + } + break; - result.si_band = siginfo.ssi_band; - result.si_fd = siginfo.ssi_fd; + case EVFILT_WRITE: + if (event.flags & EV_EOF) { + // EOF on write indicates disconnect. + KJ_IF_MAYBE(f, hupFulfiller) { + f->get()->fulfill(); + hupFulfiller = nullptr; + if (!(flags & OBSERVE_WRITE)) { + // We were only observing writes to get the disconnect event. Stop observing now. + struct kevent rmEvent; + EV_SET(&rmEvent, fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + KJ_SYSCALL_HANDLE_ERRORS(kevent(eventPort.kqueueFd, &rmEvent, 1, nullptr, 0, nullptr)) { + case ENOENT: + // In the specific case of unnamed pipes, when read end of the pipe is destroyed, + // FreeBSD seems to unregister the events on the write end automatically. + // Subsequently trying to remove them then produces ENOENT. Let's ignore this. + break; + default: + KJ_FAIL_SYSCALL("kevent(remove events)", error); + } + } + } + } + + KJ_IF_MAYBE(f, writeFulfiller) { + f->get()->fulfill(); + writeFulfiller = nullptr; + } break; - case SIGSYS: - // Apparently SIGSYS's fields are not available in signalfd_siginfo? +#ifdef EVFILT_EXCEPT + case EVFILT_EXCEPT: + KJ_IF_MAYBE(f, urgentFulfiller) { + f->get()->fulfill(); + urgentFulfiller = nullptr; + } break; - } +#endif + } +} - } else { - // Signal originated from userspace. The sender could specify whatever signal number they - // wanted. The structure of the signal is determined by the API they used, which is identified - // by SI_CODE. - - switch (siginfo.ssi_code) { - case SI_USER: - case SI_TKILL: - // kill(), tkill(), or tgkill(). - result.si_pid = siginfo.ssi_pid; - result.si_uid = siginfo.ssi_uid; - break; +Promise UnixEventPort::FdObserver::whenBecomesReadable() { + KJ_REQUIRE(flags & OBSERVE_READ, "FdObserver was not set to observe reads."); - case SI_QUEUE: - case SI_MESGQ: - case SI_ASYNCIO: - default: - result.si_pid = siginfo.ssi_pid; - result.si_uid = siginfo.ssi_uid; - - // This is awkward. In siginfo_t, si_ptr and si_int are in a union together. In - // signalfd_siginfo, they are not. We don't really know whether the app intended to send - // an int or a pointer. Presumably since the pointer is always larger than the int, if - // we write the pointer, we'll end up with the right value for the int? Presumably the - // two fields of signalfd_siginfo are actually extracted from one of these unions - // originally, so actually contain redundant data? Better write some tests... - // - // Making matters even stranger, siginfo.ssi_ptr is 64-bit even on 32-bit systems, and - // it appears that instead of doing the obvious thing by casting the pointer value to - // 64 bits, the kernel actually memcpy()s the 32-bit value into the 64-bit space. As - // a result, on big-endian 32-bit systems, the original pointer value ends up in the - // *upper* 32 bits of siginfo.ssi_ptr, which is totally weird. We play along and use - // a memcpy() on our end too, to get the right result on all platforms. - memcpy(&result.si_ptr, &siginfo.ssi_ptr, sizeof(result.si_ptr)); - break; + auto paf = newPromiseAndFulfiller(); + readFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} - case SI_TIMER: - result.si_timerid = siginfo.ssi_tid; - result.si_overrun = siginfo.ssi_overrun; +Promise UnixEventPort::FdObserver::whenBecomesWritable() { + KJ_REQUIRE(flags & OBSERVE_WRITE, "FdObserver was not set to observe writes."); - // Again with this weirdness... - result.si_ptr = reinterpret_cast(static_cast(siginfo.ssi_ptr)); - break; + auto paf = newPromiseAndFulfiller(); + writeFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} + +Promise UnixEventPort::FdObserver::whenUrgentDataAvailable() { + KJ_REQUIRE(flags & OBSERVE_URGENT, + "FdObserver was not set to observe availability of urgent data."); + + auto paf = newPromiseAndFulfiller(); + urgentFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} + +Promise UnixEventPort::FdObserver::whenWriteDisconnected() { + if (!(flags & OBSERVE_WRITE) && hupFulfiller == nullptr) { + // We aren't observing writes, but we need to if we want to detect disconnects. + struct kevent event; + EV_SET(&event, fd, EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, this); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + } + + auto paf = newPromiseAndFulfiller(); + hupFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} + +class UnixEventPort::SignalPromiseAdapter { +public: + inline SignalPromiseAdapter(PromiseFulfiller& fulfiller, + UnixEventPort& eventPort, int signum) + : eventPort(eventPort), signum(signum), fulfiller(fulfiller) { + struct kevent event; + EV_SET(&event, signum, EVFILT_SIGNAL, EV_ADD | EV_CLEAR, 0, 0, this); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + + // We must check for the signal now in case it was delivered previously and is currently in + // the blocked set. See comment in tryConsumeSignal(). (To avoid the race condition, we must + // check *after* having registered the kevent!) + tryConsumeSignal(); + } + + ~SignalPromiseAdapter() noexcept(false) { + // Unregister the event. This is important because it contains a pointer to this object which + // we don't want to see again. + struct kevent event; + EV_SET(&event, signum, EVFILT_SIGNAL, EV_DELETE, 0, 0, nullptr); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + } + + void tryConsumeSignal() { + // Unfortunately KJ's signal semantics are not a great fit for kqueue. In particular, KJ + // assumes that if no threads are waiting for a signal, it'll remain blocked until some + // thread actually calls `onSignal()` to receive it. kqueue, however, doesn't care if a signal + // is blocked -- the kqueue event will still be delivered. So, when `onSignal()` is called + // we will need to check if the signal is already queued; it's too late to ask kqueue() to + // tell us this. + // + // Alternatively we could maybe fix this by having every thread's kqueue wait on all captured + // signals all the time, but this would result in a thundering herd on any signal even if only + // one thread has actually registered interest. + // + // Another problem is per-thread signals, delivered with pthread_kill(). On FreeBSD, it appears + // a pthread_kill will wake up all kqueues in the process waiting on the particular signal, + // even if they are not associated with the target thread (kqueues don't really have any + // association with threads anyway). Worse, though, on MacOS, pthread_kill() doesn't wake + // kqueues at all. In fact, it appears they made it this way in 10.14, which broke stuff: + // https://github.com/libevent/libevent/issues/765 + // + // So, we have to: + // - Block signals normally. + // - Poll for a specific signal using sigtimedwait() or similar. + // - Use kqueue only as a hint to tell us when polling might be a good idea. + // - On MacOS, live with per-thread signals being broken I guess? + + // Anyway, this method here tries to have the signal delivered to this thread. + + if (fulfiller.isWaiting()) { +#if KJ_HAS_SIGTIMEDWAIT + sigset_t mask; + KJ_SYSCALL(sigemptyset(&mask)); + KJ_SYSCALL(sigaddset(&mask, signum)); + siginfo_t result; + struct timespec timeout; + memset(&timeout, 0, sizeof(timeout)); + + KJ_SYSCALL_HANDLE_ERRORS(sigtimedwait(&mask, &result, &timeout)) { + case EAGAIN: + // Signal was not queued. + return; + default: + KJ_FAIL_SYSCALL("sigtimedwait", error); + } + + fulfiller.fulfill(kj::mv(result)); +#else + // This platform doesn't appear to have sigtimedwait(). Ugh! We are forced to do two separate + // syscalls to see if the signal is pending, and then, if so, wait for it. There is an + // inherent race condition since the signal could be dequeued in another thread concurrently. + // We will try to work around that by locking a global mutex, so at least this code doesn't + // race against itself. + static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; + pthread_mutex_lock(&mut); + KJ_DEFER(pthread_mutex_unlock(&mut)); + + sigset_t mask; + KJ_SYSCALL(sigpending(&mask)); + int isset; + KJ_SYSCALL(isset = sigismember(&mask, signum)); + if (isset) { + KJ_SYSCALL(sigfillset(&mask)); + KJ_SYSCALL(sigdelset(&mask, signum)); + siginfo_t info; + memset(&info, 0, sizeof(info)); + threadCapture = &info; + KJ_DEFER(threadCapture = nullptr); + int result = sigsuspend(&mask); + KJ_ASSERT(result < 0 && errno == EINTR, "sigsuspend() didn't EINTR?", result, errno); + KJ_ASSERT(info.si_signo == signum); + fulfiller.fulfill(kj::mv(info)); + } +#endif } } - return result; + UnixEventPort& eventPort; + int signum; + PromiseFulfiller& fulfiller; +}; + +Promise UnixEventPort::onSignal(int signum) { + KJ_REQUIRE(signum != SIGCHLD || !capturedChildExit, + "can't call onSigal(SIGCHLD) when kj::UnixEventPort::captureChildExit() has been called"); + + return newAdaptedPromise(*this, signum); } -bool UnixEventPort::doEpollWait(int timeout) { - sigset_t newMask; - sigemptyset(&newMask); +class UnixEventPort::ChildExitPromiseAdapter { +public: + inline ChildExitPromiseAdapter(PromiseFulfiller& fulfiller, + UnixEventPort& eventPort, Maybe& pid) + : eventPort(eventPort), pid(pid), fulfiller(fulfiller) { + pid_t p = KJ_ASSERT_NONNULL(pid); - { - auto ptr = signalHead; - while (ptr != nullptr) { - sigaddset(&newMask, ptr->signum); - ptr = ptr->next; + struct kevent event; + EV_SET(&event, p, EVFILT_PROC, EV_ADD | EV_CLEAR, NOTE_EXIT, 0, this); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + + // Check for race where child had already exited before the event was waiting. + tryConsumeChild(); + } + + ~ChildExitPromiseAdapter() noexcept(false) { + KJ_IF_MAYBE(p, pid) { + // The process has not been reaped. The promise must have been canceled. So, we're still + // registered with the kqueue. We'd better unregister because the kevent points back to this + // object. + struct kevent event; + EV_SET(&event, *p, EVFILT_PROC, EV_DELETE, 0, 0, nullptr); + KJ_SYSCALL(kevent(eventPort.kqueueFd, &event, 1, nullptr, 0, nullptr)); + + // We leak the zombie process here. The caller is responsible for doing its own waitpid(). } } - if (memcmp(&newMask, &signalFdSigset, sizeof(newMask)) != 0) { - // Apparently we're not waiting on the same signals as last time. Need to update the signal - // FD's mask. - signalFdSigset = newMask; - KJ_SYSCALL(signalfd(signalFd, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC)); + void tryConsumeChild() { + // Even though kqueue delivers the exit status to us, we still need to wait on the pid to + // clear the zombie. We can't set SIGCHLD to SIG_IGN to ignore this because it creates a race + // condition. + + KJ_IF_MAYBE(p, pid) { + int status; + pid_t result; + KJ_SYSCALL(result = waitpid(*p, &status, WNOHANG)); + if (result != 0) { + KJ_ASSERT(result == *p); + + // NOTE: The proc is automatically unregsitered from the kqueue on exit, so we should NOT + // attempt to unregister it here. + + pid = nullptr; + fulfiller.fulfill(kj::mv(status)); + } + } } - struct epoll_event events[16]; - int n; - KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), timeout)); + UnixEventPort& eventPort; + Maybe& pid; + PromiseFulfiller& fulfiller; +}; + +Promise UnixEventPort::onChildExit(Maybe& pid) { + KJ_REQUIRE(capturedChildExit, + "must call UnixEventPort::captureChildExit() to use onChildExit()."); + + return kj::newAdaptedPromise(*this, pid); +} + +void UnixEventPort::captureChildExit() { + capturedChildExit = true; +} + +void UnixEventPort::wake() const { + // Trigger our user event. + struct kevent event; + EV_SET(&event, 0, EVFILT_USER, 0, NOTE_TRIGGER, 0, nullptr); + KJ_SYSCALL(kevent(kqueueFd, &event, 1, nullptr, 0, nullptr)); +} + +bool UnixEventPort::doKqueueWait(struct timespec* timeout) { + struct kevent events[16]; + int n = kevent(kqueueFd, nullptr, 0, events, kj::size(events), timeout); + + if (n < 0) { + int error = errno; + if (error == EINTR) { + // We received a singal. The signal handler may have queued an event to the event loop. Even + // if it didn't, we can't simply restart the kevent call because we need to recompute the + // timeout. Instead, we pretend kevent() returned zero events. This will cause the event + // loop to spin once, decide it has nothing to do, recompute timeouts, then return to waiting. + n = 0; + } else { + KJ_FAIL_SYSCALL("kevent()", error); + } + } bool woken = false; for (int i = 0; i < n; i++) { - if (events[i].data.u64 == 0) { - for (;;) { - struct signalfd_siginfo siginfo; - ssize_t n; - KJ_NONBLOCKING_SYSCALL(n = read(signalFd, &siginfo, sizeof(siginfo))); - if (n < 0) break; // no more signals + switch (events[i].filter) { +#ifdef EVFILT_EXCEPT + case EVFILT_EXCEPT: +#endif + case EVFILT_READ: + case EVFILT_WRITE: { + FdObserver* observer = reinterpret_cast(events[i].udata); + observer->fire(events[i]); + break; + } - KJ_ASSERT(n == sizeof(siginfo)); + case EVFILT_SIGNAL: { + SignalPromiseAdapter* observer = reinterpret_cast(events[i].udata); + observer->tryConsumeSignal(); + break; + } - gotSignal(toRegularSiginfo(siginfo)); + case EVFILT_PROC: { + ChildExitPromiseAdapter* observer = + reinterpret_cast(events[i].udata); + observer->tryConsumeChild(); + break; } - } else if (events[i].data.u64 == 1) { - // Someone called wake() from another thread. Consume the event. - uint64_t value; - ssize_t n; - KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value))); - KJ_ASSERT(n < 0 || n == sizeof(value)); - // We were woken. Need to return true. - woken = true; - } else { - FdObserver* observer = reinterpret_cast(events[i].data.ptr); - observer->fire(events[i].events); + case EVFILT_USER: + // Someone called wake() from another thread. + woken = true; + break; + + default: + KJ_FAIL_ASSERT("unexpected EVFILT", events[i].filter); } } - timerImpl.advanceTo(readClock()); + timerImpl.advanceTo(clock.now()); return woken; } -#else // KJ_USE_EPOLL +bool UnixEventPort::wait() { + KJ_IF_MAYBE(t, timerImpl.timeoutToNextEvent(clock.now(), NANOSECONDS, int(maxValue))) { + struct timespec timeout; + timeout.tv_sec = *t / 1'000'000'000; + timeout.tv_nsec = *t % 1'000'000'000; + return doKqueueWait(&timeout); + } else { + return doKqueueWait(nullptr); + } +} + +bool UnixEventPort::poll() { + struct timespec timeout; + memset(&timeout, 0, sizeof(timeout)); + return doKqueueWait(&timeout); +} + +#else // KJ_USE_EPOLL, else KJ_USE_KQUEUE // ======================================================================================= // Traditional poll() FdObserver implementation. @@ -496,12 +1154,33 @@ bool UnixEventPort::doEpollWait(int timeout) { #endif UnixEventPort::UnixEventPort() - : timerImpl(readClock()) { + : clock(systemPreciseMonotonicClock()), + timerImpl(clock.now()) { +#if KJ_USE_PIPE_FOR_WAKEUP + // Allocate a pipe to which we'll write a byte in order to wake this thread. + int fds[2]; + KJ_SYSCALL(pipe(fds)); + wakePipeIn = kj::AutoCloseFd(fds[0]); + wakePipeOut = kj::AutoCloseFd(fds[1]); + KJ_SYSCALL(fcntl(wakePipeIn, F_SETFD, FD_CLOEXEC)); + KJ_SYSCALL(fcntl(wakePipeOut, F_SETFD, FD_CLOEXEC)); +#else static_assert(sizeof(threadId) >= sizeof(pthread_t), "pthread_t is larger than a long long on your platform. Please port."); *reinterpret_cast(&threadId) = pthread_self(); - pthread_once(®isterReservedSignalOnce, ®isterReservedSignal); + // Note: We used to use a pthread_once to call registerReservedSignal() only once per process. + // This didn't work correctly because registerReservedSignal() not only registers the + // (process-wide) signal handler, but also sets the (per-thread) signal mask to block the + // signal. Thus, if threads were spawned before the first UnixEventPort was created, and then + // multiple threads created UnixEventPorts, only one of them would have the signal properly + // blocked. We could have changed things so that only the handler registration was protected + // by the pthread_once and the mask update happened in every thread, but registering a signal + // handler is not an expensive operation, so whatever... we'll do it in every thread. + registerReservedSignal(); +#endif + + ignoreSigpipe(); } UnixEventPort::~UnixEventPort() noexcept(false) {} @@ -545,6 +1224,13 @@ void UnixEventPort::FdObserver::fire(short events) { } } + if (events & (POLLHUP | POLLERR | POLLNVAL)) { + KJ_IF_MAYBE(f, hupFulfiller) { + f->get()->fulfill(); + hupFulfiller = nullptr; + } + } + if (events & POLLPRI) { KJ_IF_MAYBE(f, urgentFulfiller) { f->get()->fulfill(); @@ -552,7 +1238,8 @@ void UnixEventPort::FdObserver::fire(short events) { } } - if (readFulfiller == nullptr && writeFulfiller == nullptr && urgentFulfiller == nullptr) { + if (readFulfiller == nullptr && writeFulfiller == nullptr && urgentFulfiller == nullptr && + hupFulfiller == nullptr) { // Remove from list. if (next == nullptr) { eventPort.observersTail = prev; @@ -568,7 +1255,16 @@ void UnixEventPort::FdObserver::fire(short events) { short UnixEventPort::FdObserver::getEventMask() { return (readFulfiller == nullptr ? 0 : (POLLIN | POLLRDHUP)) | (writeFulfiller == nullptr ? 0 : POLLOUT) | - (urgentFulfiller == nullptr ? 0 : POLLPRI); + (urgentFulfiller == nullptr ? 0 : POLLPRI) | + // The POSIX standard says POLLHUP and POLLERR will be reported even if not requested. + // But on MacOS, if `events` is 0, then POLLHUP apparently will not be reported: + // https://openradar.appspot.com/37537852 + // It seems that by settingc any non-zero value -- even one documented as ignored -- we + // cause POLLHUP to be reported. Both POLLHUP and POLLERR are documented as being ignored. + // So, we'll go ahead and set them. This has no effect on non-broken OSs, causes MacOS to + // do the right thing, and sort of looks as if we're explicitly requesting notification of + // these two conditions, which we do after all want to know about. + POLLHUP | POLLERR; } Promise UnixEventPort::FdObserver::whenBecomesReadable() { @@ -617,43 +1313,85 @@ Promise UnixEventPort::FdObserver::whenUrgentDataAvailable() { return kj::mv(paf.promise); } +Promise UnixEventPort::FdObserver::whenWriteDisconnected() { + if (prev == nullptr) { + KJ_DASSERT(next == nullptr); + prev = eventPort.observersTail; + *prev = this; + eventPort.observersTail = &next; + } + + auto paf = newPromiseAndFulfiller(); + hupFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); +} + class UnixEventPort::PollContext { public: - PollContext(FdObserver* ptr) { - while (ptr != nullptr) { + PollContext(UnixEventPort& port) { + for (FdObserver* ptr = port.observersHead; ptr != nullptr; ptr = ptr->next) { struct pollfd pollfd; memset(&pollfd, 0, sizeof(pollfd)); pollfd.fd = ptr->fd; pollfd.events = ptr->getEventMask(); pollfds.add(pollfd); pollEvents.add(ptr); - ptr = ptr->next; } + +#if KJ_USE_PIPE_FOR_WAKEUP + { + struct pollfd pollfd; + memset(&pollfd, 0, sizeof(pollfd)); + pollfd.fd = port.wakePipeIn; + pollfd.events = POLLIN; + pollfds.add(pollfd); + } +#endif } void run(int timeout) { - do { - pollResult = ::poll(pollfds.begin(), pollfds.size(), timeout); - pollError = pollResult < 0 ? errno : 0; - - // EINTR should only happen if we received a signal *other than* the ones registered via - // the UnixEventPort, so we don't care about that case. - } while (pollError == EINTR); + pollResult = ::poll(pollfds.begin(), pollfds.size(), timeout); + pollError = pollResult < 0 ? errno : 0; + + if (pollError == EINTR) { + // We can't simply restart the poll call because we need to recompute the timeout. Instead, + // we pretend poll() returned zero events. This will cause the event loop to spin once, + // decide it has nothing to do, recompute timeouts, then return to waiting. + pollResult = 0; + pollError = 0; + } } - void processResults() { + bool processResults() { if (pollResult < 0) { KJ_FAIL_SYSCALL("poll()", pollError); } + bool woken = false; for (auto i: indices(pollfds)) { if (pollfds[i].revents != 0) { - pollEvents[i]->fire(pollfds[i].revents); +#if KJ_USE_PIPE_FOR_WAKEUP + if (i == pollEvents.size()) { + // The last pollfd is our cross-thread wake pipe. + woken = true; + // Discard junk in the wake pipe. + char junk[256]; + ssize_t n; + do { + KJ_NONBLOCKING_SYSCALL(n = read(pollfds[i].fd, junk, sizeof(junk))); + } while (n >= 256); + } else { +#endif + pollEvents[i]->fire(pollfds[i].revents); +#if KJ_USE_PIPE_FOR_WAKEUP + } +#endif if (--pollResult <= 0) { break; } } } + return woken; } private: @@ -666,7 +1404,10 @@ private: bool UnixEventPort::wait() { sigset_t newMask; sigemptyset(&newMask); + +#if !KJ_USE_PIPE_FOR_WAKEUP sigaddset(&newMask, reservedSignal); +#endif { auto ptr = signalHead; @@ -674,43 +1415,58 @@ bool UnixEventPort::wait() { sigaddset(&newMask, ptr->signum); ptr = ptr->next; } + if (childSet != nullptr) { + sigaddset(&newMask, SIGCHLD); + } } - PollContext pollContext(observersHead); + PollContext pollContext(*this); // Capture signals. SignalCapture capture; +#if KJ_BROKEN_SIGLONGJMP + if (sigsetjmp(capture.jumpTo, false)) { +#else if (sigsetjmp(capture.jumpTo, true)) { +#endif // We received a signal and longjmp'd back out of the signal handler. threadCapture = nullptr; +#if !KJ_USE_PIPE_FOR_WAKEUP if (capture.siginfo.si_signo == reservedSignal) { return true; } else { +#endif gotSignal(capture.siginfo); return false; +#if !KJ_USE_PIPE_FOR_WAKEUP } +#endif } // Enable signals, run the poll, then mask them again. - sigset_t origMask; +#if KJ_BROKEN_SIGLONGJMP + auto& originalMask = capture.originalMask; +#else + sigset_t originalMask; +#endif threadCapture = &capture; - sigprocmask(SIG_UNBLOCK, &newMask, &origMask); + pthread_sigmask(SIG_UNBLOCK, &newMask, &originalMask); pollContext.run( - timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue)) + timerImpl.timeoutToNextEvent(clock.now(), MILLISECONDS, int(maxValue)) .map([](uint64_t t) -> int { return t; }) .orDefault(-1)); - sigprocmask(SIG_SETMASK, &origMask, nullptr); + pthread_sigmask(SIG_SETMASK, &originalMask, nullptr); threadCapture = nullptr; // Queue events. - pollContext.processResults(); - timerImpl.advanceTo(readClock()); + bool result = pollContext.processResults(); + timerImpl.advanceTo(clock.now()); - return false; + return result; } bool UnixEventPort::poll() { @@ -726,11 +1482,13 @@ bool UnixEventPort::poll() { KJ_SYSCALL(sigpending(&pending)); uint signalCount = 0; +#if !KJ_USE_PIPE_FOR_WAKEUP if (sigismember(&pending, reservedSignal)) { ++signalCount; sigdelset(&pending, reservedSignal); sigdelset(&waitMask, reservedSignal); } +#endif { auto ptr = signalHead; @@ -746,43 +1504,84 @@ bool UnixEventPort::poll() { // Wait for each pending signal. It would be nice to use sigtimedwait() here but it is not // available on OSX. :( Instead, we call sigsuspend() once per expected signal. - while (signalCount-- > 0) { + { SignalCapture capture; +#if KJ_BROKEN_SIGLONGJMP + pthread_sigmask(SIG_SETMASK, nullptr, &capture.originalMask); +#endif threadCapture = &capture; - if (sigsetjmp(capture.jumpTo, true)) { - // We received a signal and longjmp'd back out of the signal handler. - sigdelset(&waitMask, capture.siginfo.si_signo); - if (capture.siginfo.si_signo == reservedSignal) { - woken = true; + KJ_DEFER(threadCapture = nullptr); + while (signalCount-- > 0) { +#if KJ_BROKEN_SIGLONGJMP + if (sigsetjmp(capture.jumpTo, false)) { +#else + if (sigsetjmp(capture.jumpTo, true)) { +#endif + // We received a signal and longjmp'd back out of the signal handler. + sigdelset(&waitMask, capture.siginfo.si_signo); +#if !KJ_USE_PIPE_FOR_WAKEUP + if (capture.siginfo.si_signo == reservedSignal) { + woken = true; + } else { +#endif + gotSignal(capture.siginfo); +#if !KJ_USE_PIPE_FOR_WAKEUP + } +#endif } else { - gotSignal(capture.siginfo); +#if __CYGWIN__ + // Cygwin's sigpending() incorrectly reports signals pending for any thread, not just our + // own thread. As a work-around, instead of using sigsuspend() (which would block forever + // if the signal is not pending on *this* thread), we un-mask the signals and immediately + // mask them again. If any signals are pending, they *should* be delivered before the first + // sigprocmask() returns, and the handler will then longjmp() to the block above. If it + // turns out no signal is pending, we'll block the signals again and break out of the + // loop. + // + // Bug reported here: https://cygwin.com/ml/cygwin/2019-07/msg00051.html + sigset_t origMask; + sigprocmask(SIG_SETMASK, &waitMask, &origMask); + sigprocmask(SIG_SETMASK, &origMask, nullptr); + break; +#else + sigsuspend(&waitMask); + KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should " + "have siglongjmp()ed."); +#endif } - } else { - sigsuspend(&waitMask); - KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should " - "have siglongjmp()ed."); } - threadCapture = nullptr; } { - PollContext pollContext(observersHead); + PollContext pollContext(*this); pollContext.run(0); - pollContext.processResults(); + if (pollContext.processResults()) { + woken = true; + } } - timerImpl.advanceTo(readClock()); + timerImpl.advanceTo(clock.now()); return woken; } void UnixEventPort::wake() const { +#if KJ_USE_PIPE_FOR_WAKEUP + // We're going to write() a single byte to our wake pipe in order to cause poll() to complete in + // the target thread. + // + // If this write() fails with EWOULDBLOCK, we don't care, because the target thread is already + // scheduled to wake up. + char c = 0; + KJ_NONBLOCKING_SYSCALL(write(wakePipeOut, &c, 1)); +#else int error = pthread_kill(*reinterpret_cast(&threadId), reservedSignal); if (error != 0) { KJ_FAIL_SYSCALL("pthread_kill", error); } +#endif } -#endif // KJ_USE_EPOLL, else +#endif // KJ_USE_EPOLL, else KJ_USE_KQUEUE, else } // namespace kj diff --git a/c++/src/kj/async-unix.h b/c++/src/kj/async-unix.h index 34068d6724..665305ea70 100644 --- a/c++/src/kj/async-unix.h +++ b/c++/src/kj/async-unix.h @@ -19,26 +19,47 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ASYNC_UNIX_H_ -#define KJ_ASYNC_UNIX_H_ +#pragma once #if _WIN32 #error "This file is Unix-specific. On Windows, include async-win32.h instead." #endif -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif - #include "async.h" -#include "time.h" -#include "vector.h" -#include "io.h" +#include "timer.h" +#include +#include #include -#if (__EMSCRIPTEN__ || (__linux__ && !__BIONIC__)) && !defined(KJ_USE_EPOLL) -// Default to epoll on Linux, except on Bionic (Android) which doesn't have signalfd.h. +KJ_BEGIN_HEADER + +#if !defined(KJ_USE_EPOLL) && !defined(KJ_USE_KQUEUE) +#if __linux__ +// Default to epoll on Linux. #define KJ_USE_EPOLL 1 +#elif __APPLE__ || __FreeBSD__ || __OpenBSD__ || __NetBSD__ || __DragonFly__ +// MacOS and BSDs prefer kqueue() for event notification. +#define KJ_USE_KQUEUE 1 +#endif +#endif + +#if KJ_USE_EPOLL && KJ_USE_KQUEUE +#error "Both KJ_USE_EPOLL and KJ_USE_KQUEUE are set. Please choose only one of these." +#endif + +#if __CYGWIN__ && !defined(KJ_USE_PIPE_FOR_WAKEUP) +// Cygwin has serious issues with the intersection of signals and threads, reported here: +// https://cygwin.com/ml/cygwin/2019-07/msg00052.html +// On Cygwin, therefore, we do not use signals to wake threads. Instead, each thread allocates a +// pipe, and we write a byte to the pipe to wake the thread... ick. +#define KJ_USE_PIPE_FOR_WAKEUP 1 +#endif + +#if KJ_USE_EPOLL +struct epoll_event; +#elif KJ_USE_KQUEUE +struct kevent; +struct timespec; #endif namespace kj { @@ -50,20 +71,26 @@ class UnixEventPort: public EventPort { // The implementation uses `poll()` or possibly a platform-specific API (e.g. epoll, kqueue). // To also wait on signals without race conditions, the implementation may block signals until // just before `poll()` while using a signal handler which `siglongjmp()`s back to just before - // the signal was unblocked, or it may use a nicer platform-specific API like signalfd. + // the signal was unblocked, or it may use a nicer platform-specific API. // // The implementation reserves a signal for internal use. By default, it uses SIGUSR1. If you // need to use SIGUSR1 for something else, you must offer a different signal by calling - // setReservedSignal() at startup. + // setReservedSignal() at startup. (On Linux, no signal is reserved; eventfd is used instead.) // // WARNING: A UnixEventPort can only be used in the thread and process that created it. In // particular, note that after a fork(), a UnixEventPort created in the parent process will // not work correctly in the child, even if the parent ceases to use its copy. In particular // note that this means that server processes which daemonize themselves at startup must wait // until after daemonization to create a UnixEventPort. + // + // TODO(cleanup): The above warning is no longer accurate -- daemonizing after creating a + // UnixEventPort should now work since we no longer use signalfd. But do we want to commit to + // keeping it that way? Note it's still unsafe to fork() and then use UnixEventPort from both + // processes! public: UnixEventPort(); + ~UnixEventPort() noexcept(false); class FdObserver; @@ -79,6 +106,15 @@ class UnixEventPort: public EventPort { // process-wide signal by only calling `onSignal()` on that thread's event loop. // // The result of waiting on the same signal twice at once is undefined. + // + // WARNING: On MacOS and iOS, `onSignal()` will only see process-level signals, NOT + // thread-specific signals (i.e. not those sent with pthread_kill()). This is a limitation of + // Apple's implemnetation of kqueue() introduced in MacOS 10.14 which Apple says is not a bug. + // See: https://github.com/libevent/libevent/issues/765 Consider using kj::Executor or + // kj::newPromiseAndCrossThreadFulfiller() for cross-thread communications instead of signals. + // If you must have signals, build KJ and your app with `-DKJ_USE_KQUEUE=0`, which will cause + // KJ to fall back to a generic poll()-based implementation that is less efficient but handles + // thread-specific signals. static void captureSignal(int signum); // Arranges for the given signal to be captured and handled via UnixEventPort, so that you may @@ -99,48 +135,108 @@ class UnixEventPort: public EventPort { Timer& getTimer() { return timerImpl; } + Promise onChildExit(Maybe& pid); + // When the given child process exits, resolves to its wait status, as returned by wait(2). You + // will need to use the WIFEXITED() etc. macros to interpret the status code. + // + // You must call onChildExit() immediately after the child is created, before returning to the + // event loop. Otherwise, you may miss the child exit event. + // + // `pid` is a reference to a Maybe which must be non-null at the time of the call. When + // wait() is invoked (and indicates this pid has finished), `pid` will be nulled out. This is + // necessary to avoid a race condition: as soon as the child has been wait()ed, the PID table + // entry is freed and can then be reused. So, if you ever want safely to call `kill()` on the + // PID, it's necessary to know whether it has been wait()ed already. Since the promise's + // .then() continuation may not run immediately, we need a more precise way, hence we null out + // the Maybe. + // + // The caller must NOT null out `pid` on its own unless it cancels the Promise first. If the + // caller decides to cancel the Promise, and `pid` is still non-null after this cancellation, + // then the caller is expected to `waitpid()` on it BEFORE returning to the event loop again. + // Probably, the caller should kill() the child before waiting to avoid a hang. If the caller + // fails to do its own waitpid() before returning to the event loop, the child may become a + // zombie, or may be reaped automatically, depending on the platform -- since the caller does not + // know, the caller cannot try to reap the zombie later. + // + // You must call `kj::UnixEventPort::captureChildExit()` early in your program if you want to use + // `onChildExit()`. + // + // WARNING: Only one UnixEventPort per process is allowed to use onChildExit(). This is because + // child exit is signaled to the process via SIGCHLD, and Unix does not allow the program to + // control which thread receives the signal. (We may fix this in the future by automatically + // coordinating between threads when multiple threads are expecting child exits.) + // WARNING 2: If any UnixEventPort in the process is currently waiting for onChildExit(), then + // *only* that port's thread can safely wait on child processes, even synchronously. This is + // because the thread which used onChildExit() uses wait() to reap children, without specifying + // which child, and therefore it may inadvertently reap children created by other threads. + + static void captureChildExit(); + // Arranges for child process exit to be captured and handled via UnixEventPort, so that you may + // call `onChildExit()`. Much like `captureSignal()`, this static method must be called early on + // in program startup. + // + // This method may capture the `SIGCHLD` signal. You must not use `captureSignal(SIGCHLD)` nor + // `onSignal(SIGCHLD)` in your own code if you use `captureChildExit()`. + // implements EventPort ------------------------------------------------------ bool wait() override; bool poll() override; void wake() const override; private: - struct TimerSet; // Defined in source file to avoid STL include. - class TimerPromiseAdapter; class SignalPromiseAdapter; + class ChildExitPromiseAdapter; + const MonotonicClock& clock; TimerImpl timerImpl; +#if !KJ_USE_KQUEUE SignalPromiseAdapter* signalHead = nullptr; SignalPromiseAdapter** signalTail = &signalHead; - TimePoint readClock(); void gotSignal(const siginfo_t& siginfo); +#endif friend class TimerPromiseAdapter; #if KJ_USE_EPOLL + sigset_t originalMask; AutoCloseFd epollFd; - AutoCloseFd signalFd; AutoCloseFd eventFd; // Used for cross-thread wakeups. - sigset_t signalFdSigset; - // Signal mask as currently set on the signalFd. Tracked so we can detect whether or not it - // needs updating. - - bool doEpollWait(int timeout); + bool processEpollEvents(struct epoll_event events[], int n); +#elif KJ_USE_KQUEUE + AutoCloseFd kqueueFd; + bool doKqueueWait(struct timespec* timeout); #else class PollContext; FdObserver* observersHead = nullptr; FdObserver** observersTail = &observersHead; +#if KJ_USE_PIPE_FOR_WAKEUP + AutoCloseFd wakePipeIn; + AutoCloseFd wakePipeOut; +#else unsigned long long threadId; // actually pthread_t #endif +#endif + +#if !KJ_USE_KQUEUE + struct ChildSet; + Maybe> childSet; +#endif + + static void signalHandler(int, siginfo_t* siginfo, void*) noexcept; + static void registerSignalHandler(int signum); +#if !KJ_USE_EPOLL && !KJ_USE_KQUEUE && !KJ_USE_PIPE_FOR_WAKEUP + static void registerReservedSignal(); +#endif + static void ignoreSigpipe(); }; -class UnixEventPort::FdObserver { +class UnixEventPort::FdObserver: private AsyncObject { // Object which watches a file descriptor to determine when it is readable or writable. // // For listen sockets, "readable" means that there is a connection to accept(). For everything @@ -169,7 +265,7 @@ class UnixEventPort::FdObserver { ~FdObserver() noexcept(false); - KJ_DISALLOW_COPY(FdObserver); + KJ_DISALLOW_COPY_AND_MOVE(FdObserver); Promise whenBecomesReadable(); // Resolves the next time the file descriptor transitions from having no data to read to having @@ -240,7 +336,10 @@ class UnixEventPort::FdObserver { // has not yet resolved. If you do this, the previous promise may throw an exception. // // WARNING: This has some known weird behavior on macOS. See - // https://github.com/sandstorm-io/capnproto/issues/374. + // https://github.com/capnproto/capnproto/issues/374. + + Promise whenWriteDisconnected(); + // Resolves when poll() on the file descriptor reports POLLHUP or POLLERR. private: UnixEventPort& eventPort; @@ -250,12 +349,17 @@ class UnixEventPort::FdObserver { kj::Maybe>> readFulfiller; kj::Maybe>> writeFulfiller; kj::Maybe>> urgentFulfiller; + kj::Maybe>> hupFulfiller; // Replaced each time `whenBecomesReadable()` or `whenBecomesWritable()` is called. Reverted to // null every time an event is fired. Maybe atEnd; +#if KJ_USE_KQUEUE + void fire(struct kevent event); +#else void fire(short events); +#endif #if !KJ_USE_EPOLL FdObserver* next; @@ -271,4 +375,4 @@ class UnixEventPort::FdObserver { } // namespace kj -#endif // KJ_ASYNC_UNIX_H_ +KJ_END_HEADER diff --git a/c++/src/kj/async-win32-test.c++ b/c++/src/kj/async-win32-test.c++ index 6866bda9b7..3dd6e5bc98 100644 --- a/c++/src/kj/async-win32-test.c++ +++ b/c++/src/kj/async-win32-test.c++ @@ -24,6 +24,7 @@ #include "async-win32.h" #include "thread.h" #include "test.h" +#include "mutex.h" namespace kj { namespace { diff --git a/c++/src/kj/async-win32-xthread-test.c++ b/c++/src/kj/async-win32-xthread-test.c++ new file mode 100644 index 0000000000..c93be7fe99 --- /dev/null +++ b/c++/src/kj/async-win32-xthread-test.c++ @@ -0,0 +1,32 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if _WIN32 + +#include "async-win32.h" + +#define KJ_XTHREAD_TEST_SETUP_LOOP \ + Win32IocpEventPort port; \ + EventLoop loop(port); \ + WaitScope waitScope(loop) +#include "async-xthread-test.c++" + +#endif // _WIN32 diff --git a/c++/src/kj/async-win32.c++ b/c++/src/kj/async-win32.c++ index 6a5d470fd1..ed82e55206 100644 --- a/c++/src/kj/async-win32.c++ +++ b/c++/src/kj/async-win32.c++ @@ -22,13 +22,12 @@ #if _WIN32 // Request Vista-level APIs. -#define WINVER 0x0600 -#define _WIN32_WINNT 0x0600 +#include #include "async-win32.h" #include "debug.h" #include -#include +#include "time.h" #include "refcount.h" #include // NTSTATUS #include // STATUS_SUCCESS @@ -38,7 +37,8 @@ namespace kj { Win32IocpEventPort::Win32IocpEventPort() - : iocp(newIocpHandle()), thread(openCurrentThread()), timerImpl(readClock()) {} + : clock(systemPreciseMonotonicClock()), + iocp(newIocpHandle()), thread(openCurrentThread()), timerImpl(clock.now()) {} Win32IocpEventPort::~Win32IocpEventPort() noexcept(false) {} @@ -157,17 +157,17 @@ Own Win32IocpEventPort::observeSignalState(HANDL return waitThreads.observeSignalState(handle); } -TimePoint Win32IocpEventPort::readClock() { - return origin() + std::chrono::duration_cast( - std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS; -} - bool Win32IocpEventPort::wait() { - waitIocp(timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, INFINITE - 1) + // It's possible that a wake event was received and discarded during ~IoPromiseAdapter. We + // need to check for that now. Otherwise, calling waitIocp may cause it to hang forever. + if (receivedWake()) { + return true; + } + waitIocp(timerImpl.timeoutToNextEvent(clock.now(), MILLISECONDS, INFINITE - 1) .map([](uint64_t t) -> DWORD { return t; }) .orDefault(INFINITE)); - timerImpl.advanceTo(readClock()); + timerImpl.advanceTo(clock.now()); return receivedWake(); } diff --git a/c++/src/kj/async-win32.h b/c++/src/kj/async-win32.h index b70c42e016..ddf4987a7a 100644 --- a/c++/src/kj/async-win32.h +++ b/c++/src/kj/async-win32.h @@ -19,27 +19,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ASYNC_WIN32_H_ -#define KJ_ASYNC_WIN32_H_ +#pragma once #if !_WIN32 #error "This file is Windows-specific. On Unix, include async-unix.h instead." #endif +// Include windows.h as lean as possible. (If you need more of the Windows API for your app, +// #include windows.h yourself before including this header.) +#include + #include "async.h" -#include "time.h" +#include "timer.h" #include "io.h" #include #include -// Include windows.h as lean as possible. (If you need more of the Windows API for your app, -// #include windows.h yourself before including this header.) -#define WIN32_LEAN_AND_MEAN 1 -#define NOSERVICE 1 -#define NOMCX 1 -#define NOIME 1 #include -#include "windows-sanity.h" +#include + +KJ_BEGIN_HEADER namespace kj { @@ -121,14 +120,14 @@ class Win32EventPort: public EventPort { // Returns a promise that completes the next time the handle enters the signaled state. // // Depending on the type of handle, the handle may automatically be reset to a non-signaled - // state before the promise resolves. The underlying implementaiton uses WaitForSingleObject() + // state before the promise resolves. The underlying implementation uses WaitForSingleObject() // or an equivalent wait call, so check the documentation for that to understand the semantics. // // If the handle is a mutex and it is abandoned without being unlocked, the promise breaks with // an exception. virtual Promise onSignaledOrAbandoned() = 0; - // Like onSingaled(), but instead of throwing when a mutex is abandoned, resolves to `true`. + // Like onSignaled(), but instead of throwing when a mutex is abandoned, resolves to `true`. // Resolves to `false` for non-abandoned signals. }; @@ -180,7 +179,7 @@ class Win32WaitObjectThreadPool { bool finishedMainThreadWait(DWORD returnCode); // Call immediately after invoking WaitForMultipleObjects() or similar in the main thread, - // passing the value returend by that call. Returns true if the event indicated by `returnCode` + // passing the value returned by that call. Returns true if the event indicated by `returnCode` // has been handled (i.e. it was WAIT_OBJECT_n or WAIT_ABANDONED_n where n is in-range for the // last call to prepareMainThreadWait()). }; @@ -210,6 +209,8 @@ class Win32IocpEventPort final: public Win32EventPort { class IoOperationImpl; class IoObserverImpl; + const MonotonicClock& clock; + AutoCloseHandle iocp; AutoCloseHandle thread; Win32WaitObjectThreadPool waitThreads; @@ -217,8 +218,6 @@ class Win32IocpEventPort final: public Win32EventPort { mutable std::atomic sentWake {false}; bool isAllowApc = false; - static TimePoint readClock(); - void waitIocp(DWORD timeoutMs); // Wait on the I/O completion port for up to timeoutMs and pump events. Does not advance the // timer; caller must do that. @@ -231,4 +230,4 @@ class Win32IocpEventPort final: public Win32EventPort { } // namespace kj -#endif // KJ_ASYNC_WIN32_H_ +KJ_END_HEADER diff --git a/c++/src/kj/async-xthread-test.c++ b/c++/src/kj/async-xthread-test.c++ new file mode 100644 index 0000000000..b6bd237e19 --- /dev/null +++ b/c++/src/kj/async-xthread-test.c++ @@ -0,0 +1,1044 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if _WIN32 +#include "win32-api-version.h" +#endif + +#include "async.h" +#include "debug.h" +#include "thread.h" +#include "mutex.h" +#include + +#if _WIN32 +#include +#include "windows-sanity.h" +inline void delay() { Sleep(10); } +#else +#include +inline void delay() { usleep(10000); } +#endif + +// This file is #included from async-unix-xthread-test.c++ and async-win32-xthread-test.c++ after +// defining KJ_XTHREAD_TEST_SETUP_LOOP to set up a loop with the corresponding EventPort. +#ifndef KJ_XTHREAD_TEST_SETUP_LOOP +#define KJ_XTHREAD_TEST_SETUP_LOOP \ + EventLoop loop; \ + WaitScope waitScope(loop) +#endif + +namespace kj { +namespace { + +KJ_TEST("synchonous simple cross-thread events") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + // We use `noexcept` so that any uncaught exceptions immediately terminate the process without + // unwinding. Otherwise, the unwind would likely deadlock waiting for some synchronization with + // the other thread. + Thread thread([&]() noexcept { + isChild = true; + + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + KJ_ASSERT(paf.promise.wait(waitScope) == 123); + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + KJ_ASSERT(!isChild); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test exception", exec->executeSync([&]() { + KJ_ASSERT(isChild); + KJ_FAIL_ASSERT("test exception") { break; } + })); + + uint i = exec->executeSync([&]() { + KJ_ASSERT(isChild); + fulfiller->fulfill(123); + return 456; + }); + KJ_EXPECT(i == 456); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("asynchronous simple cross-thread events") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + // We use `noexcept` so that any uncaught exceptions immediately terminate the process without + // unwinding. Otherwise, the unwind would likely deadlock waiting for some synchronization with + // the other thread. + Thread thread([&]() noexcept { + isChild = true; + + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + KJ_ASSERT(paf.promise.wait(waitScope) == 123); + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + KJ_ASSERT(!isChild); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test exception", exec->executeAsync([&]() { + KJ_ASSERT(isChild); + KJ_FAIL_ASSERT("test exception") { break; } + }).wait(waitScope)); + + Promise promise = exec->executeAsync([&]() { + KJ_ASSERT(isChild); + fulfiller->fulfill(123); + return 456u; + }); + KJ_EXPECT(promise.wait(waitScope) == 456); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("synchonous promise cross-thread events") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + Promise promise = nullptr; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + // We use `noexcept` so that any uncaught exceptions immediately terminate the process without + // unwinding. Otherwise, the unwind would likely deadlock waiting for some synchronization with + // the other thread. + Thread thread([&]() noexcept { + isChild = true; + + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + auto paf2 = newPromiseAndFulfiller(); + promise = kj::mv(paf2.promise); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + KJ_ASSERT(paf.promise.wait(waitScope) == 123); + + paf2.fulfiller->fulfill(321); + + // Make sure reply gets sent. + loop.run(); + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + KJ_ASSERT(!isChild); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test exception", exec->executeSync([&]() { + KJ_ASSERT(isChild); + return kj::Promise(KJ_EXCEPTION(FAILED, "test exception")); + })); + + uint i = exec->executeSync([&]() { + KJ_ASSERT(isChild); + fulfiller->fulfill(123); + return kj::mv(promise); + }); + KJ_EXPECT(i == 321); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("asynchronous promise cross-thread events") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + Promise promise = nullptr; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + // We use `noexcept` so that any uncaught exceptions immediately terminate the process without + // unwinding. Otherwise, the unwind would likely deadlock waiting for some synchronization with + // the other thread. + Thread thread([&]() noexcept { + isChild = true; + + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + auto paf2 = newPromiseAndFulfiller(); + promise = kj::mv(paf2.promise); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + KJ_ASSERT(paf.promise.wait(waitScope) == 123); + + paf2.fulfiller->fulfill(321); + + // Make sure reply gets sent. + loop.run(); + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + KJ_ASSERT(!isChild); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("test exception", exec->executeAsync([&]() { + KJ_ASSERT(isChild); + return kj::Promise(KJ_EXCEPTION(FAILED, "test exception")); + }).wait(waitScope)); + + Promise promise2 = exec->executeAsync([&]() { + KJ_ASSERT(isChild); + fulfiller->fulfill(123); + return kj::mv(promise); + }); + KJ_EXPECT(promise2.wait(waitScope) == 321); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("cancel cross-thread event before it runs") { + MutexGuarded> executor; // to get the Executor from the other thread + + // We use `noexcept` so that any uncaught exceptions immediately terminate the process without + // unwinding. Otherwise, the unwind would likely deadlock waiting for some synchronization with + // the other thread. + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + // We never run the loop here, so that when the event is canceled, it's still queued. + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + volatile bool called = false; + { + Promise promise = exec->executeAsync([&]() { called = true; return 123u; }); + delay(); + KJ_EXPECT(!promise.poll(waitScope)); + } + KJ_EXPECT(!called); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("cancel cross-thread event while it runs") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + + // We use `noexcept` so that any uncaught exceptions immediately terminate the process without + // unwinding. Otherwise, the unwind would likely deadlock waiting for some synchronization with + // the other thread. + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + paf.promise.wait(waitScope); + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + { + volatile bool called = false; + Promise promise = exec->executeAsync([&]() -> kj::Promise { + called = true; + return kj::NEVER_DONE; + }); + while (!called) { + delay(); + } + KJ_EXPECT(!promise.poll(waitScope)); + } + + exec->executeSync([&]() { fulfiller->fulfill(); }); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("cross-thread cancellation in both directions at once") { + MutexGuarded> childExecutor; + MutexGuarded> parentExecutor; + + MutexGuarded readyCount(0); + + thread_local uint threadNumber = 0; + thread_local bool receivedFinalCall = false; + + // Code to execute simultaneously in two threads... + // We mark this noexcept so that any exceptions thrown will immediately invoke the termination + // handler, skipping any destructors that would deadlock. + auto simultaneous = [&](MutexGuarded>& selfExecutor, + MutexGuarded>& otherExecutor, + uint threadCount) noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + *selfExecutor.lockExclusive() = getCurrentThreadExecutor(); + + const Executor* exec; + { + auto lock = otherExecutor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + // Create a ton of cross-thread promises to cancel. + Vector> promises; + for (uint i = 0; i < 1000; i++) { + promises.add(exec->executeAsync([&]() -> kj::Promise { + return kj::Promise(kj::NEVER_DONE) + .attach(kj::defer([wasThreadNumber = threadNumber]() { + // Make sure destruction happens in the correct thread. + KJ_ASSERT(threadNumber == wasThreadNumber); + })); + })); + } + + // Signal other thread that we're done queueing, and wait for it to signal same. + { + auto lock = readyCount.lockExclusive(); + ++*lock; + lock.wait([&](uint i) { return i >= threadCount; }); + } + + // Run event loop to start all executions queued by the other thread. + waitScope.poll(); + loop.run(); + + // Signal other thread that we've run the loop, and wait for it to signal same. + { + auto lock = readyCount.lockExclusive(); + ++*lock; + lock.wait([&](uint i) { return i >= threadCount * 2; }); + } + + // Cancel all the promises. + promises.clear(); + + // All our cancellations completed, but the other thread may still be waiting for some + // cancellations from us. We need to pump our event loop to make sure we continue handling + // those cancellation requests. In particular we'll queue a function to the other thread and + // wait for it to complete. The other thread will queue its own function to this thread just + // before completing the function we queued to it. + receivedFinalCall = false; + exec->executeAsync([&]() { receivedFinalCall = true; }).wait(waitScope); + + // To be safe, make sure we've actually executed the function that the other thread queued to + // us by repeatedly polling until `receivedFinalCall` becomes true in this thread. + while (!receivedFinalCall) { + waitScope.poll(); + loop.run(); + } + + // OK, signal other that we're all done. + *otherExecutor.lockExclusive() = nullptr; + + // Wait until other thread sets executor to null, as a way to tell us to quit. + selfExecutor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }; + + { + Thread thread([&]() { + threadNumber = 1; + simultaneous(childExecutor, parentExecutor, 2); + }); + + threadNumber = 0; + simultaneous(parentExecutor, childExecutor, 2); + } + + // Let's even have a three-thread version, with cyclic cancellation requests. + MutexGuarded> child2Executor; + *readyCount.lockExclusive() = 0; + + { + Thread thread1([&]() { + threadNumber = 1; + simultaneous(childExecutor, child2Executor, 3); + }); + + Thread thread2([&]() { + threadNumber = 2; + simultaneous(child2Executor, parentExecutor, 3); + }); + + threadNumber = 0; + simultaneous(parentExecutor, childExecutor, 3); + } +} + +KJ_TEST("cross-thread cancellation cycle") { + // Another multi-way cancellation test where we set up an actual cycle between three threads + // waiting on each other to complete a single event. + + MutexGuarded> child1Executor, child2Executor; + + Own> fulfiller1, fulfiller2; + + auto threadMain = [](MutexGuarded>& executor, + Own>& fulfiller) noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + paf.promise.wait(waitScope); + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }; + + Thread thread1([&]() noexcept { threadMain(child1Executor, fulfiller1); }); + Thread thread2([&]() noexcept { threadMain(child2Executor, fulfiller2); }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + auto& parentExecutor = getCurrentThreadExecutor(); + + const Executor* exec1; + { + auto lock = child1Executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec1 = &KJ_ASSERT_NONNULL(*lock); + } + const Executor* exec2; + { + auto lock = child2Executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec2 = &KJ_ASSERT_NONNULL(*lock); + } + + // Create an event that cycles through both threads and back to this one, and then cancel it. + bool cycleAllDestroyed = false; + { + auto paf = kj::newPromiseAndFulfiller(); + Promise promise = exec1->executeAsync([&]() -> kj::Promise { + return exec2->executeAsync([&]() -> kj::Promise { + return parentExecutor.executeAsync([&]() -> kj::Promise { + paf.fulfiller->fulfill(); + return kj::Promise(kj::NEVER_DONE).attach(kj::defer([&]() { + cycleAllDestroyed = true; + })); + }); + }); + }); + + // Wait until the cycle has come all the way around. + paf.promise.wait(waitScope); + + KJ_EXPECT(!promise.poll(waitScope)); + } + + KJ_EXPECT(cycleAllDestroyed); + + exec1->executeSync([&]() { fulfiller1->fulfill(); }); + exec2->executeSync([&]() { fulfiller2->fulfill(); }); + + *child1Executor.lockExclusive() = nullptr; + *child2Executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("call own thread's executor") { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto& executor = getCurrentThreadExecutor(); + + { + uint i = executor.executeSync([]() { + return 123u; + }); + KJ_EXPECT(i == 123); + } + + KJ_EXPECT_THROW_MESSAGE( + "can't call executeSync() on own thread's executor with a promise-returning function", + executor.executeSync([]() { return kj::evalLater([]() {}); })); + + { + uint i = executor.executeAsync([]() { + return 123u; + }).wait(waitScope); + KJ_EXPECT(i == 123); + } +} + +KJ_TEST("synchronous cross-thread event disconnected") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + Thread thread([&]() noexcept { + isChild = true; + + { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + paf.promise.wait(waitScope); + + // Exit the event loop! + } + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + Own exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = KJ_ASSERT_NONNULL(*lock).addRef(); + } + + KJ_EXPECT(!isChild); + + KJ_EXPECT(exec->isLive()); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "Executor's event loop exited before cross-thread event could complete", + exec->executeSync([&]() -> Promise { + fulfiller->fulfill(); + return kj::NEVER_DONE; + })); + + KJ_EXPECT(!exec->isLive()); + + KJ_EXPECT_THROW_MESSAGE( + "Executor's event loop has exited", + exec->executeSync([&]() {})); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("asynchronous cross-thread event disconnected") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + Thread thread([&]() noexcept { + isChild = true; + + { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + paf.promise.wait(waitScope); + + // Exit the event loop! + } + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + Own exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = KJ_ASSERT_NONNULL(*lock).addRef(); + } + + KJ_EXPECT(!isChild); + + KJ_EXPECT(exec->isLive()); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "Executor's event loop exited before cross-thread event could complete", + exec->executeAsync([&]() -> Promise { + fulfiller->fulfill(); + return kj::NEVER_DONE; + }).wait(waitScope)); + + KJ_EXPECT(!exec->isLive()); + + KJ_EXPECT_THROW_MESSAGE( + "Executor's event loop has exited", + exec->executeAsync([&]() {}).wait(waitScope)); + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("cross-thread event disconnected before it runs") { + MutexGuarded> executor; // to get the Executor from the other thread + thread_local bool isChild = false; // to assert which thread we're in + + Thread thread([&]() noexcept { + isChild = true; + + KJ_XTHREAD_TEST_SETUP_LOOP; + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + // Don't actually run the event loop. Destroy it when the other thread signals us to. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + Own exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = KJ_ASSERT_NONNULL(*lock).addRef(); + } + + KJ_EXPECT(!isChild); + + KJ_EXPECT(exec->isLive()); + + auto promise = exec->executeAsync([&]() { + KJ_LOG(ERROR, "shouldn't have executed"); + }); + KJ_EXPECT(!promise.poll(waitScope)); + + *executor.lockExclusive() = nullptr; + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "Executor's event loop exited before cross-thread event could complete", + promise.wait(waitScope)); + + KJ_EXPECT(!exec->isLive()); + })(); +} + +KJ_TEST("cross-thread event disconnected without holding Executor ref") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + thread_local bool isChild = false; // to assert which thread we're in + + Thread thread([&]() noexcept { + isChild = true; + + { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + paf.promise.wait(waitScope); + + // Exit the event loop! + } + + // Wait until parent thread sets executor to null, as a way to tell us to quit. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + KJ_EXPECT(!isChild); + + KJ_EXPECT(exec->isLive()); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "Executor's event loop exited before cross-thread event could complete", + exec->executeSync([&]() -> Promise { + fulfiller->fulfill(); + return kj::NEVER_DONE; + })); + + // Can't check `exec->isLive()` because it's been destroyed by now. + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("detached cross-thread event doesn't cause crash") { + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + paf.promise.wait(waitScope); + + // Without this poll(), we don't attempt to reply to the other thread? But this isn't required + // in other tests, for some reason? Oh well. + waitScope.poll(); + + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + }); + + ([&]() noexcept { + { + KJ_XTHREAD_TEST_SETUP_LOOP; + + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + exec->executeAsync([&]() -> kj::Promise { + // Make sure other thread gets time to exit its EventLoop. + delay(); + delay(); + delay(); + fulfiller->fulfill(); + return kj::READY_NOW; + }).detach([&](kj::Exception&& e) { + KJ_LOG(ERROR, e); + }); + + // Give the other thread a chance to wake up and start working on the event. + delay(); + + // Now we'll destroy our EventLoop. That *should* cause detached promises to be destroyed, + // thereby cancelling it, before disabling our own executor. However, at one point in the + // past, our executor was shut down first, followed by destroying detached promises, which + // led to an abort because the other thread had no way to reply back to this thread. + } + + *executor.lockExclusive() = nullptr; + })(); +} + +KJ_TEST("cross-thread event cancel requested while destination thread being destroyed") { + // This exercises the code in Executor::Impl::disconnect() which tears down the list of + // cross-thread events which have already been canceled. At one point this code had a bug which + // would cause it to throw if any events were present in the cancel list. + + MutexGuarded> executor; // to get the Executor from the other thread + Own> fulfiller; // accessed only from the subthread + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = newPromiseAndFulfiller(); + fulfiller = kj::mv(paf.fulfiller); + + *executor.lockExclusive() = getCurrentThreadExecutor(); + + // Wait for other thread to start a cross-thread task. + paf.promise.wait(waitScope); + + // Let the other thread know, out-of-band, that the task is running, so that it can now request + // cancellation. We do this by setting `executor` to null (but we could also use some separate + // MutexGuarded conditional variable instead). + *executor.lockExclusive() = nullptr; + + // Give other thread a chance to request cancellation of the promise. + delay(); + + // now we exit the event loop + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + const Executor* exec; + { + auto lock = executor.lockExclusive(); + lock.wait([&](kj::Maybe value) { return value != nullptr; }); + exec = &KJ_ASSERT_NONNULL(*lock); + } + + KJ_EXPECT(exec->isLive()); + + auto promise = exec->executeAsync([&]() -> Promise { + fulfiller->fulfill(); + return kj::NEVER_DONE; + }); + + // Wait for the other thread to signal to us that it has indeed started executing our task. + executor.lockExclusive().wait([](auto& val) { return val == nullptr; }); + + // Cancel the promise. + promise = nullptr; + })(); +} + +KJ_TEST("cross-thread fulfiller") { + MutexGuarded>>> fulfillerMutex; + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = kj::newPromiseAndCrossThreadFulfiller(); + *fulfillerMutex.lockExclusive() = kj::mv(paf.fulfiller); + + int result = paf.promise.wait(waitScope); + KJ_EXPECT(result == 123); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + Own> fulfiller; + { + auto lock = fulfillerMutex.lockExclusive(); + lock.wait([&](auto& value) { return value != nullptr; }); + fulfiller = kj::mv(KJ_ASSERT_NONNULL(*lock)); + } + + fulfiller->fulfill(123); + })(); +} + +KJ_TEST("cross-thread fulfiller rejects") { + MutexGuarded>>> fulfillerMutex; + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = kj::newPromiseAndCrossThreadFulfiller(); + *fulfillerMutex.lockExclusive() = kj::mv(paf.fulfiller); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("foo exception", paf.promise.wait(waitScope)); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + Own> fulfiller; + { + auto lock = fulfillerMutex.lockExclusive(); + lock.wait([&](auto& value) { return value != nullptr; }); + fulfiller = kj::mv(KJ_ASSERT_NONNULL(*lock)); + } + + fulfiller->reject(KJ_EXCEPTION(FAILED, "foo exception")); + })(); +} + +KJ_TEST("cross-thread fulfiller destroyed") { + MutexGuarded>>> fulfillerMutex; + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = kj::newPromiseAndCrossThreadFulfiller(); + *fulfillerMutex.lockExclusive() = kj::mv(paf.fulfiller); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "cross-thread PromiseFulfiller was destroyed without fulfilling the promise", + paf.promise.wait(waitScope)); + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + Own> fulfiller; + { + auto lock = fulfillerMutex.lockExclusive(); + lock.wait([&](auto& value) { return value != nullptr; }); + fulfiller = kj::mv(KJ_ASSERT_NONNULL(*lock)); + } + + fulfiller = nullptr; + })(); +} + +KJ_TEST("cross-thread fulfiller canceled") { + MutexGuarded>>> fulfillerMutex; + MutexGuarded done; + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = kj::newPromiseAndCrossThreadFulfiller(); + { + auto lock = fulfillerMutex.lockExclusive(); + *lock = kj::mv(paf.fulfiller); + lock.wait([](auto& value) { return value == nullptr; }); + } + + // cancel + paf.promise = nullptr; + + { + auto lock = done.lockExclusive(); + lock.wait([](bool value) { return value; }); + } + }); + + ([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + Own> fulfiller; + { + auto lock = fulfillerMutex.lockExclusive(); + lock.wait([&](auto& value) { return value != nullptr; }); + fulfiller = kj::mv(KJ_ASSERT_NONNULL(*lock)); + KJ_ASSERT(fulfiller->isWaiting()); + *lock = nullptr; + } + + // Should eventually show not waiting. + while (fulfiller->isWaiting()) { + delay(); + } + + *done.lockExclusive() = true; + })(); +} + +KJ_TEST("cross-thread fulfiller multiple fulfills") { + MutexGuarded>>> fulfillerMutex; + + Thread thread([&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + auto paf = kj::newPromiseAndCrossThreadFulfiller(); + *fulfillerMutex.lockExclusive() = kj::mv(paf.fulfiller); + + int result = paf.promise.wait(waitScope); + KJ_EXPECT(result == 123); + }); + + auto func = [&]() noexcept { + KJ_XTHREAD_TEST_SETUP_LOOP; + + PromiseFulfiller* fulfiller; + { + auto lock = fulfillerMutex.lockExclusive(); + lock.wait([&](auto& value) { return value != nullptr; }); + fulfiller = KJ_ASSERT_NONNULL(*lock).get(); + } + + fulfiller->fulfill(123); + }; + + kj::Thread thread1(func); + kj::Thread thread2(func); + kj::Thread thread3(func); + kj::Thread thread4(func); +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/async.c++ b/c++/src/kj/async.c++ index 7920a38988..bc3498bf4a 100644 --- a/c++/src/kj/async.c++ +++ b/c++/src/kj/async.c++ @@ -19,183 +19,1694 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#undef _FORTIFY_SOURCE +// If _FORTIFY_SOURCE is defined, longjmp will complain when it detects the stack +// pointer moving in the "wrong direction", thinking you're jumping to a non-existent +// stack frame. But we use longjmp to jump between different stacks to implement fibers, +// so this check isn't appropriate for us. + +#if _WIN32 || __CYGWIN__ +#include +#elif __APPLE__ +// getcontext() and friends are marked deprecated on MacOS but seemingly no replacement is +// provided. It appears as if they deprecated it solely because the standards bodies deprecated it, +// which they seemingly did mainly because the proper semantics are too difficult for them to +// define. I doubt MacOS would actually remove these functions as they are widely used. But if they +// do, then I guess we'll need to fall back to using setjmp()/longjmp(), and some sort of hack +// involving sigaltstack() (and generating a fake signal I guess) in order to initialize the fiber +// in the first place. Or we could use assembly, I suppose. Either way, ick. +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#define _XOPEN_SOURCE // Must be defined to see getcontext() on MacOS. +#endif + #include "async.h" #include "debug.h" #include "vector.h" #include "threadlocal.h" -#include -#include +#include "mutex.h" +#include "one-of.h" +#include "function.h" +#include "list.h" +#include +#include + +#if _WIN32 || __CYGWIN__ +#include // for Sleep(0) and fibers +#include +#else + +#if KJ_USE_FIBERS +#include +#include // for fibers +#endif + +#include // mmap(), for allocating new stacks +#include // sysconf() +#include +#endif + +#if !_WIN32 +#include // just for sched_yield() +#endif + +#if !KJ_NO_RTTI +#include +#if __GNUC__ +#include +#endif +#endif + +#include + +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) +// Clang's address sanitizer requires special hints when switching fibers, especially in order for +// stack-use-after-return handling to work right. +// +// TODO(someday): Does GCC's sanitizer, flagged by __SANITIZE_ADDRESS__, have these hints too? I +// don't know and am not in a position to test, so I'm assuming not for now. +#include +#else +// Nop the hints so that we don't have to put #ifdefs around every use. +#define __sanitizer_start_switch_fiber(...) +#define __sanitizer_finish_switch_fiber(...) +#endif + +#if _MSC_VER && !__clang__ +// MSVC's atomic intrinsics are weird and different, whereas the C++ standard atomics match the GCC +// builtins -- except for requiring the obnoxious std::atomic wrapper. So, on MSVC let's just +// #define the builtins based on the C++ library, reinterpret-casting native types to +// std::atomic... this is cheating but ugh, whatever. +template +static std::atomic* reinterpretAtomic(T* ptr) { return reinterpret_cast*>(ptr); } +#define __atomic_store_n(ptr, val, order) \ + std::atomic_store_explicit(reinterpretAtomic(ptr), val, order) +#define __atomic_load_n(ptr, order) \ + std::atomic_load_explicit(reinterpretAtomic(ptr), order) +#define __atomic_compare_exchange_n(ptr, expected, desired, weak, succ, fail) \ + std::atomic_compare_exchange_strong_explicit( \ + reinterpretAtomic(ptr), expected, desired, succ, fail) +#define __atomic_exchange_n(ptr, val, order) \ + std::atomic_exchange_explicit(reinterpretAtomic(ptr), val, order) +#define __ATOMIC_RELAXED std::memory_order_relaxed +#define __ATOMIC_ACQUIRE std::memory_order_acquire +#define __ATOMIC_RELEASE std::memory_order_release +#endif + +namespace kj { + +namespace { + +KJ_THREADLOCAL_PTR(DisallowAsyncDestructorsScope) disallowAsyncDestructorsScope = nullptr; + +} // namespace + +AsyncObject::~AsyncObject() { + if (disallowAsyncDestructorsScope != nullptr) { + // If we try to do the KJ_FAIL_REQUIRE here (declaring `~AsyncObject()` itself to be noexcept), + // it seems to have a non-negligible performance impact in the HTTP benchmark. My guess is that + // it's because it breaks inlining of `~AsyncObject()` into various subclass destructors that + // are defined inside this file, which are some of the biggest ones. By forcing the actual + // failure code out into a separate function we get a little performance boost. + failed(); + } +} + +void AsyncObject::failed() noexcept { + // Since the method is noexcept, this will abort the process. + KJ_FAIL_REQUIRE( + kj::str("KJ async object being destroyed when not allowed: ", + disallowAsyncDestructorsScope->reason)); +} + +DisallowAsyncDestructorsScope::DisallowAsyncDestructorsScope(kj::StringPtr reason) + : reason(reason), previousValue(disallowAsyncDestructorsScope) { + requireOnStack(this, "DisallowAsyncDestructorsScope must be allocated on the stack."); + disallowAsyncDestructorsScope = this; +} + +DisallowAsyncDestructorsScope::~DisallowAsyncDestructorsScope() { + disallowAsyncDestructorsScope = previousValue; +} + +AllowAsyncDestructorsScope::AllowAsyncDestructorsScope() + : previousValue(disallowAsyncDestructorsScope) { + requireOnStack(this, "AllowAsyncDestructorsScope must be allocated on the stack."); + disallowAsyncDestructorsScope = nullptr; +} +AllowAsyncDestructorsScope::~AllowAsyncDestructorsScope() { + disallowAsyncDestructorsScope = previousValue; +} + +// ======================================================================================= + +namespace { + +KJ_THREADLOCAL_PTR(EventLoop) threadLocalEventLoop = nullptr; + +#define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1) + +EventLoop& currentEventLoop() { + EventLoop* loop = threadLocalEventLoop; + KJ_REQUIRE(loop != nullptr, "No event loop is running on this thread."); + return *loop; +} + +class RootEvent: public _::Event { +public: + RootEvent(_::PromiseNode* node, void* traceAddr, SourceLocation location) + : Event(location), node(node), traceAddr(traceAddr) {} + + bool fired = false; + + Maybe> fire() override { + fired = true; + return nullptr; + } + + void traceEvent(_::TraceBuilder& builder) override { + node->tracePromise(builder, true); + builder.add(traceAddr); + } + +private: + _::PromiseNode* node; + void* traceAddr; +}; + +struct DummyFunctor { + void operator()() {}; +}; + +} // namespace + +// ======================================================================================= + +void END_CANCELER_STACK_START_CANCELEE_STACK() {} +// Dummy symbol used when reporting how a Canceler was canceled. We end up combining two stack +// traces into one and we use this as a separator. + +Canceler::~Canceler() noexcept(false) { + if (isEmpty()) return; + cancel(getDestructionReason( + reinterpret_cast(&END_CANCELER_STACK_START_CANCELEE_STACK), + Exception::Type::DISCONNECTED, __FILE__, __LINE__, "operation canceled"_kj)); +} + +void Canceler::cancel(StringPtr cancelReason) { + if (isEmpty()) return; + // We can't use getDestructionReason() here because if an exception is in-flight, it would use + // that exception, totally discarding the reason given by the caller. This would probably be + // unexpected. The caller can always use getDestructionReason() themselves if desired. + cancel(Exception(Exception::Type::DISCONNECTED, __FILE__, __LINE__, kj::str(cancelReason))); +} + +void Canceler::cancel(const Exception& exception) { + for (;;) { + KJ_IF_MAYBE(a, list) { + a->unlink(); + a->cancel(kj::cp(exception)); + } else { + break; + } + } +} + +void Canceler::release() { + for (;;) { + KJ_IF_MAYBE(a, list) { + a->unlink(); + } else { + break; + } + } +} + +Canceler::AdapterBase::AdapterBase(Canceler& canceler) + : prev(canceler.list), + next(canceler.list) { + canceler.list = *this; + KJ_IF_MAYBE(n, next) { + n->prev = next; + } +} + +Canceler::AdapterBase::~AdapterBase() noexcept(false) { + unlink(); +} + +void Canceler::AdapterBase::unlink() { + KJ_IF_MAYBE(p, prev) { + *p = next; + } + KJ_IF_MAYBE(n, next) { + n->prev = prev; + } + next = nullptr; + prev = nullptr; +} + +Canceler::AdapterImpl::AdapterImpl(kj::PromiseFulfiller& fulfiller, + Canceler& canceler, kj::Promise inner) + : AdapterBase(canceler), + fulfiller(fulfiller), + inner(inner.then( + [&fulfiller]() { fulfiller.fulfill(); }, + [&fulfiller](kj::Exception&& e) { fulfiller.reject(kj::mv(e)); }) + .eagerlyEvaluate(nullptr)) {} + +void Canceler::AdapterImpl::cancel(kj::Exception&& e) { + fulfiller.reject(kj::mv(e)); + inner = nullptr; +} + +// ======================================================================================= + +TaskSet::TaskSet(TaskSet::ErrorHandler& errorHandler, SourceLocation location) + : errorHandler(errorHandler), location(location) {} + +class TaskSet::Task final: public _::PromiseArenaMember, public _::Event { +public: + Task(_::OwnPromiseNode&& nodeParam, TaskSet& taskSet) + : Event(taskSet.location), taskSet(taskSet), node(kj::mv(nodeParam)) { + node->setSelfPointer(&node); + node->onReady(this); + } + + void destroy() override { freePromise(this); } + + OwnTask pop() { + KJ_IF_MAYBE(n, next) { n->get()->prev = prev; } + OwnTask self = kj::mv(KJ_ASSERT_NONNULL(*prev)); + KJ_ASSERT(self.get() == this); + *prev = kj::mv(next); + next = nullptr; + prev = nullptr; + return self; + } + + Maybe next; + Maybe* prev = nullptr; + + kj::String trace() { + void* space[32]; + _::TraceBuilder builder(space); + node->tracePromise(builder, false); + return kj::str("task: ", builder); + } + +protected: + Maybe> fire() override { + // Get the result. + _::ExceptionOr<_::Void> result; + node->get(result); + + // Delete the node, catching any exceptions. + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() { + node = nullptr; + })) { + result.addException(kj::mv(*exception)); + } + + // Remove from the task list. Do this before calling taskFailed(), so that taskFailed() can + // safely call clear(). + auto self = pop(); + + // We'll also process onEmpty() now, just in case `taskFailed()` actually destroys the whole + // `TaskSet`. + KJ_IF_MAYBE(f, taskSet.emptyFulfiller) { + if (taskSet.tasks == nullptr) { + f->get()->fulfill(); + taskSet.emptyFulfiller = nullptr; + } + } + + // Call the error handler if there was an exception. + KJ_IF_MAYBE(e, result.exception) { + taskSet.errorHandler.taskFailed(kj::mv(*e)); + } + + return Own(mv(self)); + } + + void traceEvent(_::TraceBuilder& builder) override { + // Pointing out the ErrorHandler's taskFailed() implementation will usually identify the + // particular TaskSet that contains this event. + builder.add(_::getMethodStartAddress(taskSet.errorHandler, &ErrorHandler::taskFailed)); + } + +private: + TaskSet& taskSet; + _::OwnPromiseNode node; +}; + +TaskSet::~TaskSet() noexcept(false) { + // You could argue it is dubious, but some applications would like for the destructor of a + // task to be able to schedule new tasks. So when we cancel our tasks... we might find new + // tasks added! We'll have to repeatedly cancel. Additionally, we need to make sure that we destroy + // the items in a loop to prevent any issues with stack overflow. + while (tasks != nullptr) { + auto removed = KJ_REQUIRE_NONNULL(tasks)->pop(); + } +} + +void TaskSet::add(Promise&& promise) { + auto task = _::appendPromise(_::PromiseNode::from(kj::mv(promise)), *this); + KJ_IF_MAYBE(head, tasks) { + head->get()->prev = &task->next; + task->next = kj::mv(tasks); + } + task->prev = &tasks; + tasks = kj::mv(task); +} + +kj::String TaskSet::trace() { + kj::Vector traces; + + Maybe* ptr = &tasks; + for (;;) { + KJ_IF_MAYBE(task, *ptr) { + traces.add(task->get()->trace()); + ptr = &task->get()->next; + } else { + break; + } + } + + return kj::strArray(traces, "\n"); +} + +Promise TaskSet::onEmpty() { + KJ_IF_MAYBE(fulfiller, emptyFulfiller) { + if (fulfiller->get()->isWaiting()) { + KJ_FAIL_REQUIRE("onEmpty() can only be called once at a time"); + } + } + + if (tasks == nullptr) { + return READY_NOW; + } else { + auto paf = newPromiseAndFulfiller(); + emptyFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } +} + +void TaskSet::clear() { + tasks = nullptr; + + KJ_IF_MAYBE(fulfiller, emptyFulfiller) { + fulfiller->get()->fulfill(); + } +} + +// ======================================================================================= + +namespace { + +#if _WIN32 || __CYGWIN__ +thread_local void* threadMainFiber = nullptr; + +void* getMainWin32Fiber() { + return threadMainFiber; +} +#endif + +inline void ensureThreadCanRunFibers() { +#if _WIN32 || __CYGWIN__ + // Make sure the current thread has been converted to a fiber. + void* fiber = threadMainFiber; + if (fiber == nullptr) { + // Thread not initialized. Convert it to a fiber now. + // Note: Unfortunately, if the application has already converted the thread to a fiber, I + // guess this will fail. But trying to call GetCurrentFiber() when the thread isn't a fiber + // doesn't work (it returns null on WINE but not on real windows, ugh). So I guess we're + // just incompatible with the application doing anything with fibers, which is sad. + threadMainFiber = fiber = ConvertThreadToFiber(nullptr); + } +#endif +} + +} // namespace + +namespace _ { + +class FiberStack final { + // A class containing a fiber stack impl. This is separate from fiber + // promises since it lets us move the stack itself around and reuse it. + +public: + FiberStack(size_t stackSize); + ~FiberStack() noexcept(false); + + struct SynchronousFunc { + kj::FunctionParam& func; + kj::Maybe exception; + }; + + void initialize(FiberBase& fiber); + void initialize(SynchronousFunc& syncFunc); + + void reset() { + main = {}; + } + + void switchToFiber(); + void switchToMain(); + + void trace(TraceBuilder& builder) { + // TODO(someday): Trace through fiber stack? Can it be done??? + builder.add(getMethodStartAddress(*this, &FiberStack::trace)); + } + +private: + size_t stackSize; + OneOf main; + + friend class FiberBase; + friend class FiberPool::Impl; + + struct StartRoutine; + +#if KJ_USE_FIBERS +#if _WIN32 || __CYGWIN__ + void* osFiber; +#else + struct Impl; + Impl* impl; +#endif +#endif + + [[noreturn]] void run(); + + bool isReset() { return main == nullptr; } +}; + +} // namespace _ + +#if __linux__ +// TODO(someday): Support core-local freelists on OSs other than Linux. The only tricky part is +// finding what to use instead of sched_getcpu() to get the current CPU ID. +#define USE_CORE_LOCAL_FREELISTS 1 +#endif + +#if USE_CORE_LOCAL_FREELISTS +static const size_t CACHE_LINE_SIZE = 64; +// Most modern architectures have 64-byte cache lines. +#endif + +class FiberPool::Impl final: private Disposer { +public: + Impl(size_t stackSize): stackSize(stackSize) {} + ~Impl() noexcept(false) { +#if USE_CORE_LOCAL_FREELISTS + if (coreLocalFreelists != nullptr) { + KJ_DEFER(free(coreLocalFreelists)); + + for (uint i: kj::zeroTo(nproc)) { + for (auto stack: coreLocalFreelists[i].stacks) { + if (stack != nullptr) { + delete stack; + } + } + } + } +#endif + + // Make sure we're not leaking anything from the global freelist either. + auto lock = freelist.lockExclusive(); + auto dangling = kj::mv(*lock); + for (auto& stack: dangling) { + delete stack; + } + } + + void setMaxFreelist(size_t count) { + maxFreelist = count; + } + + size_t getFreelistSize() const { + return freelist.lockShared()->size(); + } + + void useCoreLocalFreelists() { +#if USE_CORE_LOCAL_FREELISTS + if (coreLocalFreelists != nullptr) { + // Ignore repeat call. + return; + } + + int nproc_; + KJ_SYSCALL(nproc_ = sysconf(_SC_NPROCESSORS_CONF)); + nproc = nproc_; + + void* allocPtr; + size_t totalSize = nproc * sizeof(CoreLocalFreelist); + int error = posix_memalign(&allocPtr, CACHE_LINE_SIZE, totalSize); + if (error != 0) { + KJ_FAIL_SYSCALL("posix_memalign", error); + } + memset(allocPtr, 0, totalSize); + coreLocalFreelists = reinterpret_cast(allocPtr); +#endif + } + + Own<_::FiberStack> takeStack() const { + // Get a stack from the pool. The disposer on the returned Own pointer will return the stack + // to the pool, provided that reset() has been called to indicate that the stack is not in + // a weird state. + +#if USE_CORE_LOCAL_FREELISTS + KJ_IF_MAYBE(core, lookupCoreLocalFreelist()) { + for (auto& stackPtr: core->stacks) { + _::FiberStack* result = __atomic_exchange_n(&stackPtr, nullptr, __ATOMIC_ACQUIRE); + if (result != nullptr) { + // Found a stack in this slot! + return { result, *this }; + } + } + // No stacks found, fall back to global freelist. + } +#endif + + { + auto lock = freelist.lockExclusive(); + if (!lock->empty()) { + _::FiberStack* result = lock->back(); + lock->pop_back(); + return { result, *this }; + } + } + + _::FiberStack* result = new _::FiberStack(stackSize); + return { result, *this }; + } + +private: + size_t stackSize; + size_t maxFreelist = kj::maxValue; + MutexGuarded> freelist; + +#if USE_CORE_LOCAL_FREELISTS + struct CoreLocalFreelist { + union { + _::FiberStack* stacks[2]; + // For now, we don't try to freelist more than 2 stacks per core. If you have three or more + // threads interleaved on a core, chances are you have bigger problems... + + byte padToCacheLine[CACHE_LINE_SIZE]; + // We don't want two core-local freelists to live in the same cache line, otherwise the + // cores will fight over ownership of that line. + }; + }; + + uint nproc; + CoreLocalFreelist* coreLocalFreelists = nullptr; + + kj::Maybe lookupCoreLocalFreelist() const { + if (coreLocalFreelists == nullptr) { + return nullptr; + } else { + int cpu = sched_getcpu(); + if (cpu >= 0) { + // TODO(perf): Perhaps two hyperthreads on the same physical core should share a freelist? + // But I don't know how to find out if the system uses hyperthreading. + return coreLocalFreelists[cpu]; + } else { + static bool logged = false; + if (!logged) { + KJ_LOG(ERROR, "invalid cpu number from sched_getcpu()?", cpu, nproc); + logged = true; + } + return nullptr; + } + } + } +#endif + + void disposeImpl(void* pointer) const { + _::FiberStack* stack = reinterpret_cast<_::FiberStack*>(pointer); + KJ_DEFER(delete stack); + + // Verify that the stack was reset before returning, otherwise it might be in a weird state + // where we don't want to reuse it. + if (stack->isReset()) { +#if USE_CORE_LOCAL_FREELISTS + KJ_IF_MAYBE(core, lookupCoreLocalFreelist()) { + for (auto& stackPtr: core->stacks) { + stack = __atomic_exchange_n(&stackPtr, stack, __ATOMIC_RELEASE); + if (stack == nullptr) { + // Cool, we inserted the stack into an unused slot. We're done. + return; + } + } + // All slots were occupied, so we inserted the new stack in the front, pushed the rest back, + // and now `stack` refers to the stack that fell off the end of the core-local list. That + // needs to go into the global freelist. + } +#endif + + auto lock = freelist.lockExclusive(); + lock->push_back(stack); + if (lock->size() > maxFreelist) { + stack = lock->front(); + lock->pop_front(); + } else { + stack = nullptr; + } + } + } +}; + +FiberPool::FiberPool(size_t stackSize) : impl(kj::heap(stackSize)) {} +FiberPool::~FiberPool() noexcept(false) {} + +void FiberPool::setMaxFreelist(size_t count) { + impl->setMaxFreelist(count); +} + +size_t FiberPool::getFreelistSize() const { + return impl->getFreelistSize(); +} + +void FiberPool::useCoreLocalFreelists() { + impl->useCoreLocalFreelists(); +} + +void FiberPool::runSynchronously(kj::FunctionParam func) const { + ensureThreadCanRunFibers(); + + _::FiberStack::SynchronousFunc syncFunc { func, nullptr }; + + { + auto stack = impl->takeStack(); + stack->initialize(syncFunc); + stack->switchToFiber(); + stack->reset(); // safe to reuse + } + + KJ_IF_MAYBE(e, syncFunc.exception) { + kj::throwRecoverableException(kj::mv(*e)); + } +} + +namespace _ { // private + +class LoggingErrorHandler: public TaskSet::ErrorHandler { +public: + static LoggingErrorHandler instance; + + void taskFailed(kj::Exception&& exception) override { + KJ_LOG(ERROR, "Uncaught exception in daemonized task.", exception); + } +}; + +LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler(); + +} // namespace _ (private) + +// ======================================================================================= + +struct Executor::Impl { + Impl(EventLoop& loop): state(loop) {} + + struct State { + // Queues of notifications from other threads that need this thread's attention. + + State(EventLoop& loop): loop(loop) {} + + kj::Maybe loop; + // Becomes null when the loop is destroyed. + + List<_::XThreadEvent, &_::XThreadEvent::targetLink> start; + List<_::XThreadEvent, &_::XThreadEvent::targetLink> cancel; + List<_::XThreadEvent, &_::XThreadEvent::replyLink> replies; + // Lists of events that need actioning by this thread. + + List<_::XThreadEvent, &_::XThreadEvent::targetLink> executing; + // Events that have already been dispatched and are happily executing. This list is maintained + // so that they can be canceled if the event loop exits. + + List<_::XThreadPaf, &_::XThreadPaf::link> fulfilled; + // Set of XThreadPafs that have been fulfilled by another thread. + + bool waitingForCancel = false; + // True if this thread is currently blocked waiting for some other thread to pump its + // cancellation queue. If that other thread tries to block on *this* thread, then it could + // deadlock -- it must take precautions against this. + + bool isDispatchNeeded() const { + return !start.empty() || !cancel.empty() || !replies.empty() || !fulfilled.empty(); + } + + void dispatchAll(Vector<_::XThreadEvent*>& eventsToCancelOutsideLock) { + for (auto& event: start) { + start.remove(event); + executing.add(event); + event.state = _::XThreadEvent::EXECUTING; + event.armBreadthFirst(); + } + + dispatchCancels(eventsToCancelOutsideLock); + + for (auto& event: replies) { + replies.remove(event); + event.onReadyEvent.armBreadthFirst(); + } + + for (auto& event: fulfilled) { + fulfilled.remove(event); + event.state = _::XThreadPaf::DISPATCHED; + event.onReadyEvent.armBreadthFirst(); + } + } + + void dispatchCancels(Vector<_::XThreadEvent*>& eventsToCancelOutsideLock) { + for (auto& event: cancel) { + cancel.remove(event); + + if (event.promiseNode == nullptr) { + event.setDoneState(); + } else { + // We can't destroy the promiseNode while the mutex is locked, because we don't know + // what the destructor might do. But, we *must* destroy it before acknowledging + // cancellation. So we have to add it to a list to destroy later. + eventsToCancelOutsideLock.add(&event); + } + } + } + }; + + kj::MutexGuarded state; + // After modifying state from another thread, the loop's port.wake() must be called. + + void processAsyncCancellations(Vector<_::XThreadEvent*>& eventsToCancelOutsideLock) { + // After calling dispatchAll() or dispatchCancels() with the lock held, it may be that some + // cancellations require dropping the lock before destroying the promiseNode. In that case + // those cancellations will be added to the eventsToCancelOutsideLock Vector passed to the + // method. That vector must then be passed to processAsyncCancellations() as soon as the lock + // is released. + + for (auto& event: eventsToCancelOutsideLock) { + event->promiseNode = nullptr; + event->disarm(); + } + + // Now we need to mark all the events "done" under lock. + auto lock = state.lockExclusive(); + for (auto& event: eventsToCancelOutsideLock) { + event->setDoneState(); + } + } + + void disconnect() { + state.lockExclusive()->loop = nullptr; + + // Now that `loop` is set null in `state`, other threads will no longer try to manipulate our + // lists, so we can access them without a lock. That's convenient because a bunch of the things + // we want to do with them would require dropping the lock to avoid deadlocks. We'd end up + // copying all the lists over into separate vectors first, dropping the lock, operating on + // them, and then locking again. + auto& s = state.getWithoutLock(); + + // We do, however, take and release the lock on the way out, to make sure anyone performing + // a conditional wait for state changes gets a chance to have their wait condition re-checked. + KJ_DEFER(state.lockExclusive()); + + for (auto& event: s.start) { + KJ_ASSERT(event.state == _::XThreadEvent::QUEUED, event.state) { break; } + s.start.remove(event); + event.setDisconnected(); + event.sendReply(); + event.setDoneState(); + } + + for (auto& event: s.executing) { + KJ_ASSERT(event.state == _::XThreadEvent::EXECUTING, event.state) { break; } + s.executing.remove(event); + event.promiseNode = nullptr; + event.setDisconnected(); + event.sendReply(); + event.setDoneState(); + } + + for (auto& event: s.cancel) { + KJ_ASSERT(event.state == _::XThreadEvent::CANCELING, event.state) { break; } + s.cancel.remove(event); + event.promiseNode = nullptr; + event.setDoneState(); + } + + // The replies list "should" be empty, because any locally-initiated tasks should have been + // canceled before destroying the EventLoop. + if (!s.replies.empty()) { + KJ_LOG(ERROR, "EventLoop destroyed with cross-thread event replies outstanding"); + for (auto& event: s.replies) { + s.replies.remove(event); + } + } + + // Similarly for cross-thread fulfillers. The waiting tasks should have been canceled. + if (!s.fulfilled.empty()) { + KJ_LOG(ERROR, "EventLoop destroyed with cross-thread fulfiller replies outstanding"); + for (auto& event: s.fulfilled) { + s.fulfilled.remove(event); + event.state = _::XThreadPaf::DISPATCHED; + } + } + }}; + +namespace _ { // (private) + +XThreadEvent::XThreadEvent( + ExceptionOrValue& result, const Executor& targetExecutor, EventLoop& loop, + void* funcTracePtr, SourceLocation location) + : Event(loop, location), result(result), funcTracePtr(funcTracePtr), + targetExecutor(targetExecutor.addRef()) {} + +void XThreadEvent::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // We can't safely trace into another thread, so we'll stop here. + builder.add(funcTracePtr); +} + +void XThreadEvent::ensureDoneOrCanceled() { + if (__atomic_load_n(&state, __ATOMIC_ACQUIRE) != DONE) { + auto lock = targetExecutor->impl->state.lockExclusive(); + + const EventLoop* loop; + KJ_IF_MAYBE(l, lock->loop) { + loop = l; + } else { + // Target event loop is already dead, so we know it's already working on transitioning all + // events to the DONE state. We can just wait. + lock.wait([&](auto&) { return state == DONE; }); + return; + } + + switch (state) { + case UNUSED: + // Nothing to do. + break; + case QUEUED: + lock->start.remove(*this); + // No wake needed since we removed work rather than adding it. + state = DONE; + break; + case EXECUTING: { + lock->executing.remove(*this); + lock->cancel.add(*this); + state = CANCELING; + KJ_IF_MAYBE(p, loop->port) { + p->wake(); + } + + Maybe maybeSelfExecutor = nullptr; + if (threadLocalEventLoop != nullptr) { + KJ_IF_MAYBE(e, threadLocalEventLoop->executor) { + maybeSelfExecutor = **e; + } + } + + KJ_IF_MAYBE(selfExecutor, maybeSelfExecutor) { + // If, while waiting for other threads to process our cancellation request, we have + // cancellation requests queued back to this thread, we must process them. Otherwise, + // we could deadlock with two threads waiting on each other to process cancellations. + // + // We don't have a terribly good way to detect this, except to check if the remote + // thread is itself waiting for cancellations and, if so, wake ourselves up to check for + // cancellations to process. This will busy-loop but at least it should eventually + // resolve assuming fair scheduling. + // + // To make things extra-annoying, in order to update our waitingForCancel flag, we have + // to lock our own executor state, but we can't take both locks at once, so we have to + // release the other lock in the meantime. + + // Make sure we unset waitingForCancel on the way out. + KJ_DEFER({ + lock = {}; + + Vector<_::XThreadEvent*> eventsToCancelOutsideLock; + KJ_DEFER(selfExecutor->impl->processAsyncCancellations(eventsToCancelOutsideLock)); + + auto selfLock = selfExecutor->impl->state.lockExclusive(); + selfLock->waitingForCancel = false; + selfLock->dispatchCancels(eventsToCancelOutsideLock); + + // We don't need to re-take the lock on the other executor here; it's not used again + // after this scope. + }); + + while (state != DONE) { + bool otherThreadIsWaiting = lock->waitingForCancel; + + // Make sure our waitingForCancel is on and dispatch any pending cancellations on this + // thread. + lock = {}; + { + Vector<_::XThreadEvent*> eventsToCancelOutsideLock; + KJ_DEFER(selfExecutor->impl->processAsyncCancellations(eventsToCancelOutsideLock)); + + auto selfLock = selfExecutor->impl->state.lockExclusive(); + selfLock->waitingForCancel = true; + + // Note that we don't have to proactively delete the PromiseNodes extracted from + // the canceled events because those nodes belong to this thread and can't possibly + // continue executing while we're blocked here. + selfLock->dispatchCancels(eventsToCancelOutsideLock); + } + + if (otherThreadIsWaiting) { + // We know the other thread was waiting for cancellations to complete a moment ago. + // We may have just processed the necessary cancellations in this thread, in which + // case the other thread needs a chance to receive control and notice this. Or, it + // may be that the other thread is waiting for some third thread to take action. + // Either way, we should yield control here to give things a chance to settle. + // Otherwise we could end up in a tight busy loop. +#if _WIN32 + Sleep(0); +#else + sched_yield(); +#endif + } + + // OK now we can take the original lock again. + lock = targetExecutor->impl->state.lockExclusive(); + + // OK, now we can wait for the other thread to either process our cancellation or + // indicate that it is waiting for remote cancellation. + lock.wait([&](const Executor::Impl::State& executorState) { + return state == DONE || executorState.waitingForCancel; + }); + } + } else { + // We have no executor of our own so we don't have to worry about cancellation cycles + // causing deadlock. + // + // NOTE: I don't think we can actually get here, because it implies that this is a + // synchronous execution, which means there's no way to cancel it. + lock.wait([&](auto&) { return state == DONE; }); + } + KJ_DASSERT(!targetLink.isLinked()); + break; + } + case CANCELING: + KJ_FAIL_ASSERT("impossible state: CANCELING should only be set within the above case"); + case DONE: + // Became done while we waited for lock. Nothing to do. + break; + } + } + + KJ_IF_MAYBE(e, replyExecutor) { + // Since we know we reached the DONE state (or never left UNUSED), we know that the remote + // thread is all done playing with our `replyPrev` pointer. Only the current thread could + // possibly modify it after this point. So we can skip the lock if it's already null. + if (replyLink.isLinked()) { + auto lock = e->impl->state.lockExclusive(); + lock->replies.remove(*this); + } + } +} + +void XThreadEvent::sendReply() { + KJ_IF_MAYBE(e, replyExecutor) { + // Queue the reply. + const EventLoop* replyLoop; + { + auto lock = e->impl->state.lockExclusive(); + KJ_IF_MAYBE(l, lock->loop) { + lock->replies.add(*this); + replyLoop = l; + } else { + // Calling thread exited without cancelling the promise. This is UB. In fact, + // `replyExecutor` is probably already destroyed and we are in use-after-free territory + // already. Better abort. + KJ_LOG(FATAL, + "the thread which called kj::Executor::executeAsync() apparently exited its own " + "event loop without canceling the cross-thread promise first; this is undefined " + "behavior so I will crash now"); + abort(); + } + } + + // Note that it's safe to assume `replyLoop` still exists even though we dropped the lock + // because that thread would have had to cancel any promises before destroying its own + // EventLoop, and when it tries to destroy this promise, it will wait for `state` to become + // `DONE`, which we don't set until later on. That's nice because wake() probably makes a + // syscall and we'd rather not hold the lock through syscalls. + KJ_IF_MAYBE(p, replyLoop->port) { + p->wake(); + } + } +} + +void XThreadEvent::done() { + KJ_ASSERT(targetExecutor.get() == ¤tEventLoop().getExecutor(), + "calling done() from wrong thread?"); + + sendReply(); + + { + auto lock = targetExecutor->impl->state.lockExclusive(); + + switch (state) { + case EXECUTING: + lock->executing.remove(*this); + break; + case CANCELING: + // Sending thread requested cancellation, but we're done anyway, so it doesn't matter at this + // point. + lock->cancel.remove(*this); + break; + default: + KJ_FAIL_ASSERT("can't call done() from this state", (uint)state); + } + + setDoneState(); + } +} + +inline void XThreadEvent::setDoneState() { + __atomic_store_n(&state, DONE, __ATOMIC_RELEASE); +} + +void XThreadEvent::setDisconnected() { + result.addException(KJ_EXCEPTION(DISCONNECTED, + "Executor's event loop exited before cross-thread event could complete")); +} + +class XThreadEvent::DelayedDoneHack: public Disposer { + // Crazy hack: In fire(), we want to call done() if the event is finished. But done() signals + // the requesting thread to wake up and possibly delete the XThreadEvent. But the caller (the + // EventLoop) still has to set `event->firing = false` after `fire()` returns, so this would be + // a race condition use-after-free. + // + // It just so happens, though, that fire() is allowed to return an optional `Own` to drop, + // and the caller drops that pointer immediately after setting event->firing = false. So we + // return a pointer whose disposer calls done(). + // + // It's not quite as much of a hack as it seems: The whole reason fire() returns an Own is + // so that the event can delete itself, but do so after the caller sets event->firing = false. + // It just happens to be that in this case, the event isn't deleting itself, but rather releasing + // itself back to the other thread. + +protected: + void disposeImpl(void* pointer) const override { + reinterpret_cast(pointer)->done(); + } +}; + +Maybe> XThreadEvent::fire() { + static constexpr DelayedDoneHack DISPOSER {}; + + KJ_IF_MAYBE(n, promiseNode) { + n->get()->get(result); + promiseNode = nullptr; // make sure to destroy in the thread that created it + return Own(this, DISPOSER); + } else { + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + promiseNode = execute(); + })) { + result.addException(kj::mv(*exception)); + }; + KJ_IF_MAYBE(n, promiseNode) { + n->get()->onReady(this); + } else { + return Own(this, DISPOSER); + } + } + + return nullptr; +} + +void XThreadEvent::traceEvent(TraceBuilder& builder) { + KJ_IF_MAYBE(n, promiseNode) { + n->get()->tracePromise(builder, true); + } + + // We can't safely trace into another thread, so we'll stop here. + builder.add(funcTracePtr); +} -#if KJ_USE_FUTEX -#include -#include -#include -#endif +void XThreadEvent::onReady(Event* event) noexcept { + onReadyEvent.init(event); +} -#if !KJ_NO_RTTI -#include -#if __GNUC__ -#include -#include -#endif -#endif +XThreadPaf::XThreadPaf() + : state(WAITING), executor(getCurrentThreadExecutor()) {} +XThreadPaf::~XThreadPaf() noexcept(false) {} -namespace kj { +void XThreadPaf::destroy() { + auto oldState = WAITING; -namespace { + if (__atomic_load_n(&state, __ATOMIC_ACQUIRE) == DISPATCHED) { + // Common case: Promise was fully fulfilled and dispatched, no need for locking. + delete this; + } else if (__atomic_compare_exchange_n(&state, &oldState, CANCELED, false, + __ATOMIC_ACQUIRE, __ATOMIC_ACQUIRE)) { + // State transitioned from WAITING to CANCELED, so now it's the fulfiller's job to destroy the + // object. + } else { + // Whoops, another thread is already in the process of fulfilling this promise. We'll have to + // wait for it to finish and transition the state to FULFILLED. + executor.impl->state.when([&](auto&) { + return state == FULFILLED || state == DISPATCHED; + }, [&](Executor::Impl::State& exState) { + if (state == FULFILLED) { + // The object is on the queue but was not yet dispatched. Remove it. + exState.fulfilled.remove(*this); + } + }); -KJ_THREADLOCAL_PTR(EventLoop) threadLocalEventLoop = nullptr; + // It's ours now, delete it. + delete this; + } +} -#define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1) +void XThreadPaf::onReady(Event* event) noexcept { + onReadyEvent.init(event); +} -EventLoop& currentEventLoop() { - EventLoop* loop = threadLocalEventLoop; - KJ_REQUIRE(loop != nullptr, "No event loop is running on this thread."); - return *loop; +void XThreadPaf::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // We can't safely trace into another thread, so we'll stop here. + // Maybe returning the address of get() will give us a function name with meaningful type + // information. + builder.add(getMethodStartAddress(implicitCast(*this), &PromiseNode::get)); +} + +XThreadPaf::FulfillScope::FulfillScope(XThreadPaf** pointer) { + obj = __atomic_exchange_n(pointer, static_cast(nullptr), __ATOMIC_ACQUIRE); + auto oldState = WAITING; + if (obj == nullptr) { + // Already fulfilled (possibly by another thread). + } else if (__atomic_compare_exchange_n(&obj->state, &oldState, FULFILLING, false, + __ATOMIC_ACQUIRE, __ATOMIC_ACQUIRE)) { + // Transitioned to FULFILLING, good. + } else { + // The waiting thread must have canceled. + KJ_ASSERT(oldState == CANCELED); + + // It's our responsibility to clean up, then. + delete obj; + + // Set `obj` null so that we don't try to fill it in or delete it later. + obj = nullptr; + } +} +XThreadPaf::FulfillScope::~FulfillScope() noexcept(false) { + if (obj != nullptr) { + auto lock = obj->executor.impl->state.lockExclusive(); + KJ_IF_MAYBE(l, lock->loop) { + lock->fulfilled.add(*obj); + __atomic_store_n(&obj->state, FULFILLED, __ATOMIC_RELEASE); + KJ_IF_MAYBE(p, l->port) { + // TODO(perf): It's annoying we have to call wake() with the lock held, but we have to + // prevent the destination EventLoop from being destroyed first. + p->wake(); + } + } else { + KJ_LOG(FATAL, + "the thread which called kj::newPromiseAndCrossThreadFulfiller() apparently exited " + "its own event loop without canceling the cross-thread promise first; this is " + "undefined behavior so I will crash now"); + abort(); + } + } +} + +kj::Exception XThreadPaf::unfulfilledException() { + // TODO(cleanup): Share code with regular PromiseAndFulfiller for stack tracing here. + return kj::Exception(kj::Exception::Type::FAILED, __FILE__, __LINE__, kj::heapString( + "cross-thread PromiseFulfiller was destroyed without fulfilling the promise.")); } -class BoolEvent: public _::Event { +class ExecutorImpl: public Executor, public AtomicRefcounted { public: - bool fired = false; + using Executor::Executor; - Maybe> fire() override { - fired = true; - return nullptr; + kj::Own addRef() const override { + return kj::atomicAddRef(*this); } }; -class YieldPromiseNode final: public _::PromiseNode { -public: - void onReady(_::Event& event) noexcept override { - event.armBreadthFirst(); +} // namespace _ + +Executor::Executor(EventLoop& loop, Badge): impl(kj::heap(loop)) {} +Executor::~Executor() noexcept(false) {} + +bool Executor::isLive() const { + return impl->state.lockShared()->loop != nullptr; +} + +void Executor::send(_::XThreadEvent& event, bool sync) const { + KJ_ASSERT(event.state == _::XThreadEvent::UNUSED); + + if (sync) { + EventLoop* thisThread = threadLocalEventLoop; + if (thisThread != nullptr && + thisThread->executor.map([this](auto& e) { return e == this; }).orDefault(false)) { + // Invoking a sync request on our own thread. Just execute it directly; if we try to queue + // it to the loop, we'll deadlock. + auto promiseNode = event.execute(); + + // If the function returns a promise, we have no way to pump the event loop to wait for it, + // because the event loop may already be pumping somewhere up the stack. + KJ_ASSERT(promiseNode == nullptr, + "can't call executeSync() on own thread's executor with a promise-returning function"); + + return; + } + } else { + event.replyExecutor = getCurrentThreadExecutor(); + + // Note that async requests will "just work" even if the target executor is our own thread's + // executor. In theory we could detect this case to avoid some locking and signals but that + // would be extra code complexity for probably little benefit. } - void get(_::ExceptionOrValue& output) noexcept override { - output.as<_::Void>() = _::Void(); + + auto lock = impl->state.lockExclusive(); + const EventLoop* loop; + KJ_IF_MAYBE(l, lock->loop) { + loop = l; + } else { + event.setDisconnected(); + return; } -}; -class NeverDonePromiseNode final: public _::PromiseNode { -public: - void onReady(_::Event& event) noexcept override { - // ignore + event.state = _::XThreadEvent::QUEUED; + lock->start.add(event); + + KJ_IF_MAYBE(p, loop->port) { + p->wake(); + } else { + // Event loop will be waiting on executor.wait(), which will be woken when we unlock the mutex. } - void get(_::ExceptionOrValue& output) noexcept override { - KJ_FAIL_REQUIRE("Not ready."); + + if (sync) { + lock.wait([&](auto&) { return event.state == _::XThreadEvent::DONE; }); } -}; +} -} // namespace +void Executor::wait() { + Vector<_::XThreadEvent*> eventsToCancelOutsideLock; + KJ_DEFER(impl->processAsyncCancellations(eventsToCancelOutsideLock)); + + auto lock = impl->state.lockExclusive(); + + lock.wait([](const Impl::State& state) { + return state.isDispatchNeeded(); + }); + + lock->dispatchAll(eventsToCancelOutsideLock); +} + +bool Executor::poll() { + Vector<_::XThreadEvent*> eventsToCancelOutsideLock; + KJ_DEFER(impl->processAsyncCancellations(eventsToCancelOutsideLock)); + + auto lock = impl->state.lockExclusive(); + if (lock->isDispatchNeeded()) { + lock->dispatchAll(eventsToCancelOutsideLock); + return true; + } else { + return false; + } +} + +EventLoop& Executor::getLoop() const { + KJ_IF_MAYBE(l, impl->state.lockShared()->loop) { + return *l; + } else { + kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "Executor's event loop has exited")); + } +} + +const Executor& getCurrentThreadExecutor() { + return currentEventLoop().getExecutor(); +} + +// ======================================================================================= +// Fiber implementation. namespace _ { // private -class TaskSetImpl { -public: - inline TaskSetImpl(TaskSet::ErrorHandler& errorHandler) - : errorHandler(errorHandler) {} - - ~TaskSetImpl() noexcept(false) { - // std::map doesn't like it when elements' destructors throw, so carefully disassemble it. - if (!tasks.empty()) { - Vector> deleteMe(tasks.size()); - for (auto& entry: tasks) { - deleteMe.add(kj::mv(entry.second)); - } +#if KJ_USE_FIBERS +#if !(_WIN32 || __CYGWIN__) +struct FiberStack::Impl { + // This struct serves two purposes: + // - It contains OS-specific state that we don't want to declare in the header. + // - It is allocated at the top of the fiber's stack area, so the Impl pointer also serves to + // track where the stack was allocated. + + jmp_buf fiberJmpBuf; + jmp_buf originalJmpBuf; + +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) + // Stuff that we need to pass to __sanitizer_start_switch_fiber() / + // __sanitizer_finish_switch_fiber() when using ASAN. + + void* originalFakeStack = nullptr; + void* fiberFakeStack = nullptr; + // Pointer to ASAN "fake stack" associated with the fiber and its calling stack. Filled in by + // __sanitizer_start_switch_fiber() before switching away, consumed by + // __sanitizer_finish_switch_fiber() upon switching back. + + void const* originalBottom; + size_t originalSize; + // Size and location of the original stack before switching fibers. These are filled in by + // __sanitizer_finish_switch_fiber() after the switch, and must be passed to + // __sanitizer_start_switch_fiber() when switching back later. +#endif + + static Impl* alloc(size_t stackSize, ucontext_t* context) { +#ifndef MAP_ANONYMOUS +#define MAP_ANONYMOUS MAP_ANON +#endif + size_t pageSize = getPageSize(); + size_t allocSize = stackSize + pageSize; // size plus guard page and impl + + // Allocate virtual address space for the stack but make it inaccessible initially. + // TODO(someday): Does it make sense to use MAP_GROWSDOWN on Linux? It's a kind of bizarre flag + // that causes the mapping to automatically allocate extra pages (beyond the range specified) + // until it hits something... Note that on FreeBSD, MAP_STACK has the effect that + // MAP_GROWSDOWN has on Linux. (MAP_STACK, meanwhile, has no effect on Linux.) + void* stackMapping = mmap(nullptr, allocSize, PROT_NONE, + MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); + if (stackMapping == MAP_FAILED) { + KJ_FAIL_SYSCALL("mmap(new stack)", errno); } + KJ_ON_SCOPE_FAILURE({ + KJ_SYSCALL(munmap(stackMapping, allocSize)) { break; } + }); + + void* stack = reinterpret_cast(stackMapping) + pageSize; + // Now mark everything except the guard page as read-write. We assume the stack grows down, so + // the guard page is at the beginning. No modern architecture uses stacks that grow up. + KJ_SYSCALL(mprotect(stack, stackSize, PROT_READ | PROT_WRITE)); + + // Stick `Impl` at the top of the stack. + Impl* impl = (reinterpret_cast(reinterpret_cast(stack) + stackSize) - 1); + + // Note: mmap() allocates zero'd pages so we don't have to memset() anything here. + + KJ_SYSCALL(getcontext(context)); +#if __APPLE__ && __aarch64__ + // Per issue #1386, apple on arm64 zeros the entire configured stack. + // But this is redundant, since we just allocated the stack with mmap() which + // returns zero'd pages. Re-zeroing is both slow and results in prematurely + // allocating pages we may not need -- it's normal for stacks to rely heavily + // on lazy page allocation to avoid wasting memory. Instead, we lie: + // we allocate the full size, but tell the ucontext the stack is the last + // page only. This appears to work as no particular bounds checks or + // anything are set up based on what we say here. + context->uc_stack.ss_size = min(pageSize, stackSize) - sizeof(Impl); + context->uc_stack.ss_sp = reinterpret_cast(stack) + stackSize - min(pageSize, stackSize); +#else + context->uc_stack.ss_size = stackSize - sizeof(Impl); + context->uc_stack.ss_sp = reinterpret_cast(stack); +#endif + context->uc_stack.ss_flags = 0; + // We don't use uc_link since our fiber start routine runs forever in a loop to allow for + // reuse. When we're done with the fiber, we just destroy it, without switching to it's + // stack. This is safe since the start routine doesn't allocate any memory or RAII objects + // before looping. + context->uc_link = 0; + + return impl; } - class Task final: public Event { - public: - Task(TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam) - : taskSet(taskSet), node(kj::mv(nodeParam)) { - node->setSelfPointer(&node); - node->onReady(*this); - } - - protected: - Maybe> fire() override { - // Get the result. - _::ExceptionOr<_::Void> result; - node->get(result); - - // Delete the node, catching any exceptions. - KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() { - node = nullptr; - })) { - result.addException(kj::mv(*exception)); - } + static void free(Impl* impl, size_t stackSize) { + size_t allocSize = stackSize + getPageSize(); + void* stack = reinterpret_cast(impl + 1) - allocSize; + KJ_SYSCALL(munmap(stack, allocSize)) { break; } + } - // Call the error handler if there was an exception. - KJ_IF_MAYBE(e, result.exception) { - taskSet.errorHandler.taskFailed(kj::mv(*e)); - } + static size_t getPageSize() { +#ifndef _SC_PAGESIZE +#define _SC_PAGESIZE _SC_PAGE_SIZE +#endif + static size_t result = sysconf(_SC_PAGESIZE); + return result; + } +}; +#endif +#endif - // Remove from the task map. - auto iter = taskSet.tasks.find(this); - KJ_ASSERT(iter != taskSet.tasks.end()); - Own self = kj::mv(iter->second); - taskSet.tasks.erase(iter); - return mv(self); - } +struct FiberStack::StartRoutine { +#if _WIN32 || __CYGWIN__ + static void WINAPI run(LPVOID ptr) { + // This is the static C-style function we pass to CreateFiber(). + reinterpret_cast(ptr)->run(); + } +#else + [[noreturn]] static void run(int arg1, int arg2) { + // This is the static C-style function we pass to makeContext(). - _::PromiseNode* getInnerForTrace() override { - return node; - } + // POSIX says the arguments are ints, not pointers. So we split our pointer in half in order to + // work correctly on 64-bit machines. Gross. + uintptr_t ptr = static_cast(arg1); + ptr |= static_cast(static_cast(arg2)) << (sizeof(ptr) * 4); - private: - TaskSetImpl& taskSet; - kj::Own<_::PromiseNode> node; - }; + auto& stack = *reinterpret_cast(ptr); + + __sanitizer_finish_switch_fiber(nullptr, + &stack.impl->originalBottom, &stack.impl->originalSize); - void add(Promise&& promise) { - auto task = heap(*this, kj::mv(promise.node)); - Task* ptr = task; - tasks.insert(std::make_pair(ptr, kj::mv(task))); + // We first switch to the fiber inside of the FiberStack constructor. This is just for + // initialization purposes, and we're expected to switch back immediately. + stack.switchToMain(); + + // OK now have a real job. + stack.run(); } +#endif +}; - kj::String trace() { - kj::Vector traces; - for (auto& entry: tasks) { - traces.add(entry.second->trace()); +void FiberStack::run() { + // Loop forever so that the fiber can be reused. + for (;;) { + KJ_SWITCH_ONEOF(main) { + KJ_CASE_ONEOF(event, FiberBase*) { + event->run(); + } + KJ_CASE_ONEOF(func, SynchronousFunc*) { + KJ_IF_MAYBE(exception, kj::runCatchingExceptions(func->func)) { + func->exception.emplace(kj::mv(*exception)); + } + } } - return kj::strArray(traces, "\n============================================\n"); + + // Wait for the fiber to be used again. Note the fiber might simply be destroyed without this + // ever returning. That's OK because we don't have any nontrivial destructors on the stack + // at this point. + switchToMain(); } +} -private: - TaskSet::ErrorHandler& errorHandler; +FiberStack::FiberStack(size_t stackSizeParam) + // Force stackSize to a reasonable minimum. + : stackSize(kj::max(stackSizeParam, 65536)) +{ - // TODO(perf): Use a linked list instead. - std::map> tasks; -}; +#if KJ_USE_FIBERS +#if _WIN32 || __CYGWIN__ + // We can create fibers before we convert the main thread into a fiber in FiberBase + KJ_WIN32(osFiber = CreateFiber(stackSize, &StartRoutine::run, this)); -class LoggingErrorHandler: public TaskSet::ErrorHandler { -public: - static LoggingErrorHandler instance; +#else + // Note: Nothing below here can throw. If that changes then we need to call Impl::free(impl) + // on exceptions... + ucontext_t context; + impl = Impl::alloc(stackSize, &context); + + // POSIX says the arguments are ints, not pointers. So we split our pointer in half in order to + // work correctly on 64-bit machines. Gross. + uintptr_t ptr = reinterpret_cast(this); + int arg1 = ptr & ((uintptr_t(1) << (sizeof(ptr) * 4)) - 1); + int arg2 = ptr >> (sizeof(ptr) * 4); + + makecontext(&context, reinterpret_cast(&StartRoutine::run), 2, arg1, arg2); + + __sanitizer_start_switch_fiber(&impl->originalFakeStack, impl, stackSize - sizeof(Impl)); + if (_setjmp(impl->originalJmpBuf) == 0) { + setcontext(&context); + } + __sanitizer_finish_switch_fiber(impl->originalFakeStack, nullptr, nullptr); +#endif +#else +#if KJ_NO_EXCEPTIONS + KJ_UNIMPLEMENTED("Fibers are not implemented because exceptions are disabled"); +#else + KJ_UNIMPLEMENTED( + "Fibers are not implemented on this platform because its C library lacks setcontext() " + "and friends. If you'd like to see fiber support added, file a bug to let us know. " + "We can likely make it happen using assembly, but didn't want to try unless it was " + "actually needed."); +#endif +#endif +} - void taskFailed(kj::Exception&& exception) override { - KJ_LOG(ERROR, "Uncaught exception in daemonized task.", exception); +FiberStack::~FiberStack() noexcept(false) { +#if KJ_USE_FIBERS +#if _WIN32 || __CYGWIN__ + DeleteFiber(osFiber); +#else + Impl::free(impl, stackSize); +#endif +#endif +} + +void FiberStack::initialize(FiberBase& fiber) { + KJ_REQUIRE(this->main == nullptr); + this->main = &fiber; +} + +void FiberStack::initialize(SynchronousFunc& func) { + KJ_REQUIRE(this->main == nullptr); + this->main = &func; +} + +FiberBase::FiberBase(size_t stackSize, _::ExceptionOrValue& result, SourceLocation location) + : Event(location), state(WAITING), stack(kj::heap(stackSize)), result(result) { + stack->initialize(*this); + ensureThreadCanRunFibers(); +} + +FiberBase::FiberBase(const FiberPool& pool, _::ExceptionOrValue& result, SourceLocation location) + : Event(location), state(WAITING), result(result) { + stack = pool.impl->takeStack(); + stack->initialize(*this); + ensureThreadCanRunFibers(); +} + +FiberBase::~FiberBase() noexcept(false) {} + +void FiberBase::cancel() { + // Called by `~Fiber()` to begin teardown. We can't do this work in `~FiberBase()` because the + // `Fiber` subclass contains members that may still be in-use until the fiber stops. + + switch (state) { + case WAITING: + // We can't just free the stack while the fiber is running. We need to force it to execute + // until finished, so we cause it to throw an exception. + state = CANCELED; + stack->switchToFiber(); + + // The fiber should only switch back to the main stack on completion, because any further + // calls to wait() would throw before trying to switch. + KJ_ASSERT(state == FINISHED); + + // The fiber shut down properly so the stack is safe to reuse. + stack->reset(); + break; + + case RUNNING: + case CANCELED: + // Bad news. + KJ_LOG(FATAL, "fiber tried to cancel itself"); + ::abort(); + break; + + case FINISHED: + // Normal completion, yay. + stack->reset(); + break; } -}; +} -LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler(); +Maybe> FiberBase::fire() { + KJ_ASSERT(state == WAITING); + state = RUNNING; + stack->switchToFiber(); + return nullptr; +} -class NullEventPort: public EventPort { -public: - bool wait() override { - KJ_FAIL_REQUIRE("Nothing to wait for; this thread would hang forever."); +void FiberStack::switchToFiber() { + // Switch from the main stack to the fiber. Returns once the fiber either calls switchToMain() + // or returns from its main function. +#if KJ_USE_FIBERS +#if _WIN32 || __CYGWIN__ + SwitchToFiber(osFiber); +#else + __sanitizer_start_switch_fiber(&impl->originalFakeStack, impl, stackSize - sizeof(Impl)); + if (_setjmp(impl->originalJmpBuf) == 0) { + _longjmp(impl->fiberJmpBuf, 1); + } + __sanitizer_finish_switch_fiber(impl->originalFakeStack, nullptr, nullptr); +#endif +#endif +} +void FiberStack::switchToMain() { + // Switch from the fiber to the main stack. Returns the next time the main stack calls + // switchToFiber(). +#if KJ_USE_FIBERS +#if _WIN32 || __CYGWIN__ + SwitchToFiber(getMainWin32Fiber()); +#else + // TODO(someady): In theory, the last time we switch away from the fiber, we should pass `nullptr` + // for the first argument here, so that ASAN destroys the fake stack. However, as currently + // designed, we don't actually know if we're switching away for the last time. It's understood + // that when we call switchToMain() in FiberStack::run(), then the main stack is allowed to + // destroy the fiber, or reuse it. I don't want to develop a mechanism to switch back to the + // fiber on final destruction just to get the hints right, so instead we leak the fake stack. + // This doesn't seem to cause any problems -- it's not even detected by ASAN as a memory leak. + // But if we wanted to run ASAN builds in production or something, it might be an issue. + __sanitizer_start_switch_fiber(&impl->fiberFakeStack, + impl->originalBottom, impl->originalSize); + if (_setjmp(impl->fiberJmpBuf) == 0) { + _longjmp(impl->originalJmpBuf, 1); } + __sanitizer_finish_switch_fiber(impl->fiberFakeStack, + &impl->originalBottom, &impl->originalSize); +#endif +#endif +} - bool poll() override { return false; } +void FiberBase::run() { +#if !KJ_NO_EXCEPTIONS + bool caughtCanceled = false; + state = RUNNING; + KJ_DEFER(state = FINISHED); - void wake() const override { - // TODO(someday): Implement using condvar. - kj::throwRecoverableException(KJ_EXCEPTION(UNIMPLEMENTED, - "Cross-thread events are not yet implemented for EventLoops with no EventPort.")); + WaitScope waitScope(currentEventLoop(), *this); + + try { + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + runImpl(waitScope); + })) { + result.addException(kj::mv(*exception)); + } + } catch (CanceledException) { + if (state != CANCELED) { + // no idea who would throw this but it's not really our problem + result.addException(KJ_EXCEPTION(FAILED, "Caught CanceledException, but fiber wasn't canceled")); + } + caughtCanceled = true; } - static NullEventPort instance; -}; + if (state == CANCELED && !caughtCanceled) { + KJ_LOG(ERROR, "Canceled fiber apparently caught CanceledException and didn't rethrow it. " + "Generally, applications should not catch CanceledException, but if they do, they must always rethrow."); + } + + onReadyEvent.arm(); +#endif +} + +void FiberBase::onReady(_::Event* event) noexcept { + onReadyEvent.init(event); +} -NullEventPort NullEventPort::instance = NullEventPort(); +void FiberBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + if (stopAtNextEvent) return; + currentInner->tracePromise(builder, false); + stack->trace(builder); +} + +void FiberBase::traceEvent(TraceBuilder& builder) { + currentInner->tracePromise(builder, true); + stack->trace(builder); + onReadyEvent.traceEvent(builder); +} } // namespace _ (private) @@ -209,22 +1720,29 @@ void EventPort::wake() const { } EventLoop::EventLoop() - : port(_::NullEventPort::instance), - daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} + : daemons(kj::heap(_::LoggingErrorHandler::instance)) {} EventLoop::EventLoop(EventPort& port) : port(port), - daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} + daemons(kj::heap(_::LoggingErrorHandler::instance)) {} EventLoop::~EventLoop() noexcept(false) { - // Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop - // some more. + // Destroy all "daemon" tasks, noting that their destructors might register more daemon tasks. + while (!daemons->isEmpty()) { + auto oldDaemons = kj::mv(daemons); + daemons = kj::heap(_::LoggingErrorHandler::instance); + } daemons = nullptr; + KJ_IF_MAYBE(e, executor) { + // Cancel all outstanding cross-thread events. + e->get()->impl->disconnect(); + } + // The application _should_ destroy everything using the EventLoop before destroying the // EventLoop itself, so if there are events on the loop, this indicates a memory leak. KJ_REQUIRE(head == nullptr, "EventLoop destroyed with events still in the queue. Memory leak?", - head->trace()) { + head->traceEvent()) { // Unlink all the events and hope that no one ever fires them... _::Event* event = head; while (event != nullptr) { @@ -269,6 +1787,9 @@ bool EventLoop::turn() { } depthFirstInsertPoint = &head; + if (breadthFirstInsertPoint == &event->next) { + breadthFirstInsertPoint = &head; + } if (tail == &event->next) { tail = &head; } @@ -280,6 +1801,8 @@ bool EventLoop::turn() { { event->firing = true; KJ_DEFER(event->firing = false); + currentlyFiring = event; + KJ_DEFER(currentlyFiring = nullptr); eventToDestroy = event->fire(); } @@ -292,9 +1815,19 @@ bool EventLoop::isRunnable() { return head != nullptr; } +const Executor& EventLoop::getExecutor() { + KJ_IF_MAYBE(e, executor) { + return **e; + } else { + return *executor.emplace(kj::atomicRefcounted<_::ExecutorImpl>(*this, Badge())); + } +} + void EventLoop::setRunnable(bool runnable) { if (runnable != lastRunnableState) { - port.setRunnable(runnable); + KJ_IF_MAYBE(p, port) { + p->setRunnable(runnable); + } lastRunnableState = runnable; } } @@ -312,48 +1845,276 @@ void EventLoop::leaveScope() { threadLocalEventLoop = nullptr; } -namespace _ { // private +void EventLoop::wait() { + KJ_IF_MAYBE(p, port) { + if (p->wait()) { + // Another thread called wake(). Check for cross-thread events. + KJ_IF_MAYBE(e, executor) { + e->get()->poll(); + } + } + } else KJ_IF_MAYBE(e, executor) { + e->get()->wait(); + } else { + KJ_FAIL_REQUIRE("Nothing to wait for; this thread would hang forever."); + } +} + +void EventLoop::poll() { + KJ_IF_MAYBE(p, port) { + if (p->poll()) { + // Another thread called wake(). Check for cross-thread events. + KJ_IF_MAYBE(e, executor) { + e->get()->poll(); + } + } + } else KJ_IF_MAYBE(e, executor) { + e->get()->poll(); + } +} + +uint WaitScope::poll(uint maxTurnCount) { + KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); + KJ_REQUIRE(!loop.running, "poll() is not allowed from within event callbacks."); + + loop.running = true; + KJ_DEFER(loop.running = false); + + uint turnCount = 0; + runOnStackPool([&]() { + while (turnCount < maxTurnCount) { + if (loop.turn()) { + ++turnCount; + } else { + // No events in the queue. Poll for I/O. + loop.poll(); + + if (!loop.isRunnable()) { + // Still no events in the queue. We're done. + return; + } + } + } + }); + return turnCount; +} + +void WaitScope::cancelAllDetached() { + KJ_REQUIRE(fiber == nullptr, + "can't call cancelAllDetached() on a fiber WaitScope, only top-level"); + + while (!loop.daemons->isEmpty()) { + auto oldDaemons = kj::mv(loop.daemons); + loop.daemons = kj::heap(_::LoggingErrorHandler::instance); + // Destroying `oldDaemons` could theoretically add new ones. + } +} + +namespace _ { // private + +#if !KJ_NO_EXCEPTIONS +static kj::CanceledException fiberCanceledException() { + // Construct the exception to throw from wait() when the fiber has been canceled (because the + // promise returned by startFiber() was dropped before completion). + return kj::CanceledException { }; +}; +#endif + +void waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, WaitScope& waitScope, + SourceLocation location) { + EventLoop& loop = waitScope.loop; + KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); + +#if !KJ_NO_EXCEPTIONS + // we don't support fibers when running without exceptions, so just remove the whole block + KJ_IF_MAYBE(fiber, waitScope.fiber) { + if (fiber->state == FiberBase::CANCELED) { + throw fiberCanceledException(); + } + KJ_REQUIRE(fiber->state == FiberBase::RUNNING, + "This WaitScope can only be used within the fiber that created it."); + + node->setSelfPointer(&node); + node->onReady(fiber); + + fiber->currentInner = node; + KJ_DEFER(fiber->currentInner = nullptr); + + // Switch to the main stack to run the event loop. + fiber->state = FiberBase::WAITING; + fiber->stack->switchToMain(); + + // The main stack switched back to us, meaning either the event we registered with + // node->onReady() fired, or we are being canceled by FiberBase's destructor. + + if (fiber->state == FiberBase::CANCELED) { + throw fiberCanceledException(); + } + + KJ_ASSERT(fiber->state == FiberBase::RUNNING); + } else { +#endif + KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks."); + + RootEvent doneEvent(node, reinterpret_cast(&waitImpl), location); + node->setSelfPointer(&node); + node->onReady(&doneEvent); + + loop.running = true; + KJ_DEFER(loop.running = false); + + for (;;) { + waitScope.runOnStackPool([&]() { + uint counter = 0; + while (!doneEvent.fired) { + if (!loop.turn()) { + // No events in the queue. Wait for callback. + return; + } else if (++counter > waitScope.busyPollInterval) { + // Note: It's intentional that if busyPollInterval is kj::maxValue, we never poll. + counter = 0; + loop.poll(); + } + } + }); + + if (doneEvent.fired) { + break; + } else { + loop.wait(); + } + } + + loop.setRunnable(loop.isRunnable()); +#if !KJ_NO_EXCEPTIONS + } +#endif + + waitScope.runOnStackPool([&]() { + node->get(result); + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + node = nullptr; + })) { + result.addException(kj::mv(*exception)); + } + }); +} + +bool pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location) { + EventLoop& loop = waitScope.loop; + KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); + KJ_REQUIRE(waitScope.fiber == nullptr, "poll() is not supported in fibers."); + KJ_REQUIRE(!loop.running, "poll() is not allowed from within event callbacks."); + + RootEvent doneEvent(&node, reinterpret_cast(&pollImpl), location); + node.onReady(&doneEvent); + + loop.running = true; + KJ_DEFER(loop.running = false); + + waitScope.runOnStackPool([&]() { + while (!doneEvent.fired) { + if (!loop.turn()) { + // No events in the queue. Poll for I/O. + loop.poll(); + + if (!doneEvent.fired && !loop.isRunnable()) { + // No progress. Give up. + node.onReady(nullptr); + loop.setRunnable(false); + break; + } + } + } + }); + + if (!doneEvent.fired) { + return false; + } + + loop.setRunnable(loop.isRunnable()); + return true; +} + +Promise yield() { + class YieldPromiseNode final: public _::PromiseNode { + public: + void destroy() override {} -void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope) { - EventLoop& loop = waitScope.loop; - KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread."); - KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks."); + void onReady(_::Event* event) noexcept override { + if (event) event->armBreadthFirst(); + } + void get(_::ExceptionOrValue& output) noexcept override { + output.as<_::Void>() = _::Void(); + } + void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { + builder.add(reinterpret_cast(&kj::evalLater)); + } + }; - BoolEvent doneEvent; - node->setSelfPointer(&node); - node->onReady(doneEvent); + static YieldPromiseNode NODE; + return _::PromiseNode::to>(OwnPromiseNode(&NODE)); +} - loop.running = true; - KJ_DEFER(loop.running = false); +Promise yieldHarder() { + class YieldHarderPromiseNode final: public _::PromiseNode { + public: + void destroy() override {} - while (!doneEvent.fired) { - if (!loop.turn()) { - // No events in the queue. Wait for callback. - loop.port.wait(); + void onReady(_::Event* event) noexcept override { + if (event) event->armLast(); } - } - - loop.setRunnable(loop.isRunnable()); + void get(_::ExceptionOrValue& output) noexcept override { + output.as<_::Void>() = _::Void(); + } + void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { + builder.add(reinterpret_cast(&kj::evalLast)); + } + }; - node->get(result); - KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { - node = nullptr; - })) { - result.addException(kj::mv(*exception)); - } + static YieldHarderPromiseNode NODE; + return _::PromiseNode::to>(OwnPromiseNode(&NODE)); } -Promise yield() { - return Promise(false, kj::heap()); +OwnPromiseNode readyNow() { + class ReadyNowPromiseNode: public ImmediatePromiseNodeBase { + // This is like `ConstPromiseNode`, but the compiler won't let me pass a literal + // value of type `Void` as a template parameter. (Might require C++20?) + + public: + void destroy() override {} + void get(ExceptionOrValue& output) noexcept override { + output.as() = Void(); + } + }; + + static ReadyNowPromiseNode NODE; + return OwnPromiseNode(&NODE); } -Own neverDone() { - return kj::heap(); +OwnPromiseNode neverDone() { + class NeverDonePromiseNode final: public _::PromiseNode { + public: + void destroy() override {} + + void onReady(_::Event* event) noexcept override { + // ignore + } + void get(_::ExceptionOrValue& output) noexcept override { + KJ_FAIL_REQUIRE("Not ready."); + } + void tracePromise(_::TraceBuilder& builder, bool stopAtNextEvent) override { + builder.add(_::getMethodStartAddress(kj::NEVER_DONE, &_::NeverDone::wait)); + } + }; + + static NeverDonePromiseNode NODE; + return OwnPromiseNode(&NODE); } -void NeverDone::wait(WaitScope& waitScope) const { +void NeverDone::wait(WaitScope& waitScope, SourceLocation location) const { ExceptionOr dummy; - waitImpl(neverDone(), dummy, waitScope); + waitImpl(neverDone(), dummy, waitScope, location); KJ_UNREACHABLE; } @@ -363,33 +2124,37 @@ void detach(kj::Promise&& promise) { loop.daemons->add(kj::mv(promise)); } -Event::Event() - : loop(currentEventLoop()), next(nullptr), prev(nullptr) {} +Event::Event(SourceLocation location) + : loop(currentEventLoop()), next(nullptr), prev(nullptr), location(location) {} + +Event::Event(kj::EventLoop& loop, SourceLocation location) + : loop(loop), next(nullptr), prev(nullptr), location(location) {} Event::~Event() noexcept(false) { - if (prev != nullptr) { - if (loop.tail == &next) { - loop.tail = prev; - } - if (loop.depthFirstInsertPoint == &next) { - loop.depthFirstInsertPoint = prev; - } + live = 0; - *prev = next; - if (next != nullptr) { - next->prev = prev; - } - } + // Prevent compiler from eliding this store above. This line probably isn't needed because there + // are complex calls later in this destructor, and the compiler probably can't prove that they + // won't come back and examine `live`, so it won't elide the write anyway. However, an + // atomic_signal_fence is also sufficient to tell the compiler that a signal handler might access + // `live`, so it won't optimize away the write. Note that a signal fence does not produce + // any instructions, it just blocks compiler optimizations. + std::atomic_signal_fence(std::memory_order_acq_rel); + + disarm(); KJ_REQUIRE(!firing, "Promise callback destroyed itself."); - KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, - "Promise destroyed from a different thread than it was created in."); } void Event::armDepthFirst() { KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, "Event armed from different thread than it was created in. You must use " - "the thread-safe work queue to queue events cross-thread."); + "Executor to queue events cross-thread."); + if (live != MAGIC_LIVE_VALUE) { + ([this]() noexcept { + KJ_FAIL_ASSERT("tried to arm Event after it was destroyed", location); + })(); + } if (prev == nullptr) { next = *loop.depthFirstInsertPoint; @@ -401,6 +2166,9 @@ void Event::armDepthFirst() { loop.depthFirstInsertPoint = &next; + if (loop.breadthFirstInsertPoint == prev) { + loop.breadthFirstInsertPoint = &next; + } if (loop.tail == prev) { loop.tail = &next; } @@ -412,112 +2180,169 @@ void Event::armDepthFirst() { void Event::armBreadthFirst() { KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, "Event armed from different thread than it was created in. You must use " - "the thread-safe work queue to queue events cross-thread."); + "Executor to queue events cross-thread."); + if (live != MAGIC_LIVE_VALUE) { + ([this]() noexcept { + KJ_FAIL_ASSERT("tried to arm Event after it was destroyed", location); + })(); + } if (prev == nullptr) { - next = *loop.tail; - prev = loop.tail; + next = *loop.breadthFirstInsertPoint; + prev = loop.breadthFirstInsertPoint; *prev = this; if (next != nullptr) { next->prev = &next; } - loop.tail = &next; + loop.breadthFirstInsertPoint = &next; + + if (loop.tail == prev) { + loop.tail = &next; + } loop.setRunnable(true); } } -_::PromiseNode* Event::getInnerForTrace() { - return nullptr; -} +void Event::armLast() { + KJ_REQUIRE(threadLocalEventLoop == &loop || threadLocalEventLoop == nullptr, + "Event armed from different thread than it was created in. You must use " + "Executor to queue events cross-thread."); + if (live != MAGIC_LIVE_VALUE) { + ([this]() noexcept { + KJ_FAIL_ASSERT("tried to arm Event after it was destroyed", location); + })(); + } -#if !KJ_NO_RTTI -#if __GNUC__ -static kj::String demangleTypeName(const char* name) { - int status; - char* buf = abi::__cxa_demangle(name, nullptr, nullptr, &status); - kj::String result = kj::heapString(buf == nullptr ? name : buf); - free(buf); - return kj::mv(result); + if (prev == nullptr) { + next = *loop.breadthFirstInsertPoint; + prev = loop.breadthFirstInsertPoint; + *prev = this; + if (next != nullptr) { + next->prev = &next; + } + + // We don't update loop.breadthFirstInsertPoint because we want further inserts to go *before* + // this event. + + if (loop.tail == prev) { + loop.tail = &next; + } + + loop.setRunnable(true); + } } -#else -static kj::String demangleTypeName(const char* name) { - return kj::heapString(name); + +bool Event::isNext() { + return loop.running && loop.head == this; } -#endif -#endif -static kj::String traceImpl(Event* event, _::PromiseNode* node) { -#if KJ_NO_RTTI - return heapString("Trace not available because RTTI is disabled."); -#else - kj::Vector trace; +void Event::disarm() { + if (prev != nullptr) { + if (threadLocalEventLoop != &loop && threadLocalEventLoop != nullptr) { + KJ_LOG(FATAL, "Promise destroyed from a different thread than it was created in."); + // There's no way out of this place without UB, so abort now. + abort(); + } - if (event != nullptr) { - trace.add(demangleTypeName(typeid(*event).name())); - } + if (loop.tail == &next) { + loop.tail = prev; + } + if (loop.depthFirstInsertPoint == &next) { + loop.depthFirstInsertPoint = prev; + } + if (loop.breadthFirstInsertPoint == &next) { + loop.breadthFirstInsertPoint = prev; + } - while (node != nullptr) { - trace.add(demangleTypeName(typeid(*node).name())); - node = node->getInnerForTrace(); + *prev = next; + if (next != nullptr) { + next->prev = prev; + } + + prev = nullptr; + next = nullptr; } +} - return strArray(trace, "\n"); -#endif +String Event::traceEvent() { + void* space[32]; + TraceBuilder builder(space); + traceEvent(builder); + return kj::str(builder); } -kj::String Event::trace() { - return traceImpl(this, getInnerForTrace()); +String TraceBuilder::toString() { + auto result = finish(); + return kj::str(stringifyStackTraceAddresses(result), + stringifyStackTrace(result)); } } // namespace _ (private) -// ======================================================================================= - -TaskSet::TaskSet(ErrorHandler& errorHandler) - : impl(heap<_::TaskSetImpl>(errorHandler)) {} - -TaskSet::~TaskSet() noexcept(false) {} +ArrayPtr getAsyncTrace(ArrayPtr space) { + EventLoop* loop = threadLocalEventLoop; + if (loop == nullptr) return nullptr; + if (loop->currentlyFiring == nullptr) return nullptr; -void TaskSet::add(Promise&& promise) { - impl->add(kj::mv(promise)); + _::TraceBuilder builder(space); + loop->currentlyFiring->traceEvent(builder); + return builder.finish(); } -kj::String TaskSet::trace() { - return impl->trace(); +kj::String getAsyncTrace() { + void* space[32]; + auto trace = getAsyncTrace(space); + return kj::str(stringifyStackTraceAddresses(trace), stringifyStackTrace(trace)); } +// ======================================================================================= + namespace _ { // private kj::String PromiseBase::trace() { - return traceImpl(nullptr, node); + void* space[32]; + TraceBuilder builder(space); + node->tracePromise(builder, false); + return kj::str(builder); } -void PromiseNode::setSelfPointer(Own* selfPtr) noexcept {} +void PromiseNode::setSelfPointer(OwnPromiseNode* selfPtr) noexcept {} -PromiseNode* PromiseNode::getInnerForTrace() { return nullptr; } - -void PromiseNode::OnReadyEvent::init(Event& newEvent) { +void PromiseNode::OnReadyEvent::init(Event* newEvent) { if (event == _kJ_ALREADY_READY) { // A new continuation was added to a promise that was already ready. In this case, we schedule // breadth-first, to make it difficult for applications to accidentally starve the event loop // by repeatedly waiting on immediate promises. - newEvent.armBreadthFirst(); + if (newEvent) newEvent->armBreadthFirst(); } else { - event = &newEvent; + event = newEvent; } } void PromiseNode::OnReadyEvent::arm() { - if (event == nullptr) { - event = _kJ_ALREADY_READY; - } else { + KJ_ASSERT(event != _kJ_ALREADY_READY, "arm() should only be called once"); + + if (event != nullptr) { // A promise resolved and an event is already waiting on it. In this case, arm in depth-first // order so that the event runs immediately after the current one. This way, chained promises // execute together for better cache locality and lower latency. event->armDepthFirst(); } + + event = _kJ_ALREADY_READY; +} + +void PromiseNode::OnReadyEvent::armBreadthFirst() { + KJ_ASSERT(event != _kJ_ALREADY_READY, "armBreadthFirst() should only be called once"); + + if (event != nullptr) { + // A promise resolved and an event is already waiting on it. + event->armBreadthFirst(); + } + + event = _kJ_ALREADY_READY; } // ------------------------------------------------------------------- @@ -525,25 +2350,33 @@ void PromiseNode::OnReadyEvent::arm() { ImmediatePromiseNodeBase::ImmediatePromiseNodeBase() {} ImmediatePromiseNodeBase::~ImmediatePromiseNodeBase() noexcept(false) {} -void ImmediatePromiseNodeBase::onReady(Event& event) noexcept { - event.armBreadthFirst(); +void ImmediatePromiseNodeBase::onReady(Event* event) noexcept { + if (event) event->armBreadthFirst(); +} + +void ImmediatePromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // Maybe returning the address of get() will give us a function name with meaningful type + // information. + builder.add(getMethodStartAddress(implicitCast(*this), &PromiseNode::get)); } ImmediateBrokenPromiseNode::ImmediateBrokenPromiseNode(Exception&& exception) : exception(kj::mv(exception)) {} +void ImmediateBrokenPromiseNode::destroy() { freePromise(this); } + void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept { output.exception = kj::mv(exception); } // ------------------------------------------------------------------- -AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(Own&& dependencyParam) +AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(OwnPromiseNode&& dependencyParam) : dependency(kj::mv(dependencyParam)) { dependency->setSelfPointer(&dependency); } -void AttachmentPromiseNodeBase::onReady(Event& event) noexcept { +void AttachmentPromiseNodeBase::onReady(Event* event) noexcept { dependency->onReady(event); } @@ -551,8 +2384,11 @@ void AttachmentPromiseNodeBase::get(ExceptionOrValue& output) noexcept { dependency->get(output); } -PromiseNode* AttachmentPromiseNodeBase::getInnerForTrace() { - return dependency; +void AttachmentPromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + dependency->tracePromise(builder, stopAtNextEvent); + + // TODO(debug): Maybe use __builtin_return_address to get the locations that called fork() and + // addBranch()? } void AttachmentPromiseNodeBase::dropDependency() { @@ -562,12 +2398,12 @@ void AttachmentPromiseNodeBase::dropDependency() { // ------------------------------------------------------------------- TransformPromiseNodeBase::TransformPromiseNodeBase( - Own&& dependencyParam, void* continuationTracePtr) + OwnPromiseNode&& dependencyParam, void* continuationTracePtr) : dependency(kj::mv(dependencyParam)), continuationTracePtr(continuationTracePtr) { dependency->setSelfPointer(&dependency); } -void TransformPromiseNodeBase::onReady(Event& event) noexcept { +void TransformPromiseNodeBase::onReady(Event* event) noexcept { dependency->onReady(event); } @@ -580,8 +2416,15 @@ void TransformPromiseNodeBase::get(ExceptionOrValue& output) noexcept { } } -PromiseNode* TransformPromiseNodeBase::getInnerForTrace() { - return dependency; +void TransformPromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // Note that we null out the dependency just before calling our own continuation, which + // conveniently means that if we're currently executing the continuation when the trace is + // requested, it won't trace into the obsolete dependency. Nice. + if (dependency.get() != nullptr) { + dependency->tracePromise(builder, stopAtNextEvent); + } + + builder.add(continuationTracePtr); } void TransformPromiseNodeBase::dropDependency() { @@ -603,7 +2446,7 @@ void TransformPromiseNodeBase::getDepResult(ExceptionOrValue& output) { // ------------------------------------------------------------------- -ForkBranchBase::ForkBranchBase(Own&& hubParam): hub(kj::mv(hubParam)) { +ForkBranchBase::ForkBranchBase(OwnForkHubBase&& hubParam): hub(kj::mv(hubParam)) { if (hub->tailBranch == nullptr) { onReadyEvent.arm(); } else { @@ -635,20 +2478,28 @@ void ForkBranchBase::releaseHub(ExceptionOrValue& output) { } } -void ForkBranchBase::onReady(Event& event) noexcept { +void ForkBranchBase::onReady(Event* event) noexcept { onReadyEvent.init(event); } -PromiseNode* ForkBranchBase::getInnerForTrace() { - return hub->getInnerForTrace(); +void ForkBranchBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + if (stopAtNextEvent) return; + + if (hub.get() != nullptr) { + hub->inner->tracePromise(builder, false); + } + + // TODO(debug): Maybe use __builtin_return_address to get the locations that called fork() and + // addBranch()? } // ------------------------------------------------------------------- -ForkHubBase::ForkHubBase(Own&& innerParam, ExceptionOrValue& resultRef) - : inner(kj::mv(innerParam)), resultRef(resultRef) { +ForkHubBase::ForkHubBase(OwnPromiseNode&& innerParam, ExceptionOrValue& resultRef, + SourceLocation location) + : Event(location), inner(kj::mv(innerParam)), resultRef(resultRef) { inner->setSelfPointer(&inner); - inner->onReady(*this); + inner->onReady(this); } Maybe> ForkHubBase::fire() { @@ -673,25 +2524,33 @@ Maybe> ForkHubBase::fire() { return nullptr; } -_::PromiseNode* ForkHubBase::getInnerForTrace() { - return inner; +void ForkHubBase::traceEvent(TraceBuilder& builder) { + if (inner.get() != nullptr) { + inner->tracePromise(builder, true); + } + + if (headBranch != nullptr) { + // We'll trace down the first branch, I guess. + headBranch->onReadyEvent.traceEvent(builder); + } } // ------------------------------------------------------------------- -ChainPromiseNode::ChainPromiseNode(Own innerParam) - : state(STEP1), inner(kj::mv(innerParam)) { +ChainPromiseNode::ChainPromiseNode(OwnPromiseNode innerParam, SourceLocation location) + : Event(location), state(STEP1), inner(kj::mv(innerParam)) { inner->setSelfPointer(&inner); - inner->onReady(*this); + inner->onReady(this); } ChainPromiseNode::~ChainPromiseNode() noexcept(false) {} -void ChainPromiseNode::onReady(Event& event) noexcept { +void ChainPromiseNode::destroy() { freePromise(this); } + +void ChainPromiseNode::onReady(Event* event) noexcept { switch (state) { case STEP1: - KJ_REQUIRE(onReadyEvent == nullptr, "onReady() can only be called once."); - onReadyEvent = &event; + onReadyEvent = event; return; case STEP2: inner->onReady(event); @@ -700,7 +2559,7 @@ void ChainPromiseNode::onReady(Event& event) noexcept { KJ_UNREACHABLE; } -void ChainPromiseNode::setSelfPointer(Own* selfPtr) noexcept { +void ChainPromiseNode::setSelfPointer(OwnPromiseNode* selfPtr) noexcept { if (state == STEP2) { *selfPtr = kj::mv(inner); // deletes this! selfPtr->get()->setSelfPointer(selfPtr); @@ -714,8 +2573,15 @@ void ChainPromiseNode::get(ExceptionOrValue& output) noexcept { return inner->get(output); } -PromiseNode* ChainPromiseNode::getInnerForTrace() { - return inner; +void ChainPromiseNode::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + if (stopAtNextEvent && state == STEP1) { + // In STEP1, we are an Event -- when the inner node resolves, it will arm *this* object. + // In STEP2, we are not an Event -- when the inner node resolves, it directly arms our parent + // event. + return; + } + + inner->tracePromise(builder, stopAtNextEvent); } Maybe> ChainPromiseNode::fire() { @@ -737,11 +2603,11 @@ Maybe> ChainPromiseNode::fire() { // There is an exception. If there is also a value, delete it. kj::runCatchingExceptions([&]() { intermediate.value = nullptr; }); // Now set step2 to a rejected promise. - inner = heap(kj::mv(*exception)); + inner = allocPromise(kj::mv(*exception)); } else KJ_IF_MAYBE(value, intermediate.value) { // There is a value and no exception. The value is itself a promise. Adopt it as our // step2. - inner = kj::mv(value->node); + inner = _::PromiseNode::from(kj::mv(*value)); } else { // We can only get here if inner->get() returned neither an exception nor a // value, which never actually happens. @@ -755,29 +2621,51 @@ Maybe> ChainPromiseNode::fire() { *selfPtr = kj::mv(inner); selfPtr->get()->setSelfPointer(selfPtr); if (onReadyEvent != nullptr) { - selfPtr->get()->onReady(*onReadyEvent); + selfPtr->get()->onReady(onReadyEvent); } // Return our self-pointer so that the caller takes care of deleting it. - return Own(kj::mv(chain)); + return Own(kj::Own(kj::mv(chain))); } else { inner->setSelfPointer(&inner); if (onReadyEvent != nullptr) { - inner->onReady(*onReadyEvent); + inner->onReady(onReadyEvent); } return nullptr; } } +void ChainPromiseNode::traceEvent(TraceBuilder& builder) { + switch (state) { + case STEP1: + if (inner.get() != nullptr) { + inner->tracePromise(builder, true); + } + if (!builder.full() && onReadyEvent != nullptr) { + onReadyEvent->traceEvent(builder); + } + break; + case STEP2: + // This probably never happens -- a trace being generated after the meat of fire() already + // executed. If it does, though, we probably can't do anything here. We don't know if + // `onReadyEvent` is still valid because we passed it on to the phase-2 promise, and tracing + // just `inner` would probably be confusing. Let's just do nothing. + break; + } +} + // ------------------------------------------------------------------- -ExclusiveJoinPromiseNode::ExclusiveJoinPromiseNode(Own left, Own right) - : left(*this, kj::mv(left)), right(*this, kj::mv(right)) {} +ExclusiveJoinPromiseNode::ExclusiveJoinPromiseNode( + OwnPromiseNode left, OwnPromiseNode right, SourceLocation location) + : left(*this, kj::mv(left), location), right(*this, kj::mv(right), location) {} ExclusiveJoinPromiseNode::~ExclusiveJoinPromiseNode() noexcept(false) {} -void ExclusiveJoinPromiseNode::onReady(Event& event) noexcept { +void ExclusiveJoinPromiseNode::destroy() { freePromise(this); } + +void ExclusiveJoinPromiseNode::onReady(Event* event) noexcept { onReadyEvent.init(event); } @@ -785,19 +2673,25 @@ void ExclusiveJoinPromiseNode::get(ExceptionOrValue& output) noexcept { KJ_REQUIRE(left.get(output) || right.get(output), "get() called before ready."); } -PromiseNode* ExclusiveJoinPromiseNode::getInnerForTrace() { - auto result = left.getInnerForTrace(); - if (result == nullptr) { - result = right.getInnerForTrace(); +void ExclusiveJoinPromiseNode::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // TODO(debug): Maybe use __builtin_return_address to get the locations that called + // exclusiveJoin()? + + if (stopAtNextEvent) return; + + // Trace the left branch I guess. + if (left.dependency.get() != nullptr) { + left.dependency->tracePromise(builder, false); + } else if (right.dependency.get() != nullptr) { + right.dependency->tracePromise(builder, false); } - return result; } ExclusiveJoinPromiseNode::Branch::Branch( - ExclusiveJoinPromiseNode& joinNode, Own dependencyParam) - : joinNode(joinNode), dependency(kj::mv(dependencyParam)) { + ExclusiveJoinPromiseNode& joinNode, OwnPromiseNode dependencyParam, SourceLocation location) + : Event(location), joinNode(joinNode), dependency(kj::mv(dependencyParam)) { dependency->setSelfPointer(&dependency); - dependency->onReady(*this); + dependency->onReady(this); } ExclusiveJoinPromiseNode::Branch::~Branch() noexcept(false) {} @@ -812,32 +2706,41 @@ bool ExclusiveJoinPromiseNode::Branch::get(ExceptionOrValue& output) { } Maybe> ExclusiveJoinPromiseNode::Branch::fire() { - // Cancel the branch that didn't return first. Ignore exceptions caused by cancellation. - if (this == &joinNode.left) { - kj::runCatchingExceptions([&]() { joinNode.right.dependency = nullptr; }); + if (dependency) { + // Cancel the branch that didn't return first. Ignore exceptions caused by cancellation. + if (this == &joinNode.left) { + kj::runCatchingExceptions([&]() { joinNode.right.dependency = nullptr; }); + } else { + kj::runCatchingExceptions([&]() { joinNode.left.dependency = nullptr; }); + } + + joinNode.onReadyEvent.arm(); } else { - kj::runCatchingExceptions([&]() { joinNode.left.dependency = nullptr; }); + // The other branch already fired, and this branch was canceled. It's possible for both + // branches to fire if both were armed simultaneously. } - - joinNode.onReadyEvent.arm(); return nullptr; } -PromiseNode* ExclusiveJoinPromiseNode::Branch::getInnerForTrace() { - return dependency; +void ExclusiveJoinPromiseNode::Branch::traceEvent(TraceBuilder& builder) { + if (dependency.get() != nullptr) { + dependency->tracePromise(builder, true); + } + joinNode.onReadyEvent.traceEvent(builder); } // ------------------------------------------------------------------- ArrayJoinPromiseNodeBase::ArrayJoinPromiseNodeBase( - Array> promises, ExceptionOrValue* resultParts, size_t partSize) - : countLeft(promises.size()) { + Array promises, ExceptionOrValue* resultParts, size_t partSize, + SourceLocation location, ArrayJoinBehavior joinBehavior) + : joinBehavior(joinBehavior), countLeft(promises.size()) { // Make the branches. auto builder = heapArrayBuilder(promises.size()); for (uint i: indices(promises)) { ExceptionOrValue& output = *reinterpret_cast( reinterpret_cast(resultParts) + i * partSize); - builder.add(*this, kj::mv(promises[i]), output); + builder.add(*this, kj::mv(promises[i]), output, location); } branches = builder.finish(); @@ -847,70 +2750,106 @@ ArrayJoinPromiseNodeBase::ArrayJoinPromiseNodeBase( } ArrayJoinPromiseNodeBase::~ArrayJoinPromiseNodeBase() noexcept(false) {} -void ArrayJoinPromiseNodeBase::onReady(Event& event) noexcept { +void ArrayJoinPromiseNodeBase::onReady(Event* event) noexcept { onReadyEvent.init(event); } void ArrayJoinPromiseNodeBase::get(ExceptionOrValue& output) noexcept { - // If any of the elements threw exceptions, propagate them. for (auto& branch: branches) { - KJ_IF_MAYBE(exception, branch.getPart()) { + if (joinBehavior == ArrayJoinBehavior::LAZY) { + // This implements `joinPromises()`'s lazy evaluation semantics. + branch.dependency->get(branch.output); + } + + // If any of the elements threw exceptions, propagate them. + KJ_IF_MAYBE(exception, branch.output.exception) { output.addException(kj::mv(*exception)); } } + // We either failed fast, or waited for all promises. + KJ_DASSERT(countLeft == 0 || output.exception != nullptr); + if (output.exception == nullptr) { // No errors. The template subclass will need to fill in the result. getNoError(output); } } -PromiseNode* ArrayJoinPromiseNodeBase::getInnerForTrace() { - return branches.size() == 0 ? nullptr : branches[0].getInnerForTrace(); +void ArrayJoinPromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // TODO(debug): Maybe use __builtin_return_address to get the locations that called + // joinPromises()? + + if (stopAtNextEvent) return; + + // Trace the first branch I guess. + if (branches != nullptr) { + branches[0].dependency->tracePromise(builder, false); + } } ArrayJoinPromiseNodeBase::Branch::Branch( - ArrayJoinPromiseNodeBase& joinNode, Own dependencyParam, ExceptionOrValue& output) - : joinNode(joinNode), dependency(kj::mv(dependencyParam)), output(output) { + ArrayJoinPromiseNodeBase& joinNode, OwnPromiseNode dependencyParam, ExceptionOrValue& output, + SourceLocation location) + : Event(location), joinNode(joinNode), dependency(kj::mv(dependencyParam)), output(output) { dependency->setSelfPointer(&dependency); - dependency->onReady(*this); + dependency->onReady(this); } ArrayJoinPromiseNodeBase::Branch::~Branch() noexcept(false) {} Maybe> ArrayJoinPromiseNodeBase::Branch::fire() { - if (--joinNode.countLeft == 0) { + if (--joinNode.countLeft == 0 && !joinNode.armed) { joinNode.onReadyEvent.arm(); + joinNode.armed = true; + } + + if (joinNode.joinBehavior == ArrayJoinBehavior::EAGER) { + // This implements `joinPromisesFailFast()`'s eager-evaluation semantics. + dependency->get(output); + if (output.exception != nullptr && !joinNode.armed) { + joinNode.onReadyEvent.arm(); + joinNode.armed = true; + } } - return nullptr; -} -_::PromiseNode* ArrayJoinPromiseNodeBase::Branch::getInnerForTrace() { - return dependency->getInnerForTrace(); + return nullptr; } -Maybe ArrayJoinPromiseNodeBase::Branch::getPart() { - dependency->get(output); - return kj::mv(output.exception); +void ArrayJoinPromiseNodeBase::Branch::traceEvent(TraceBuilder& builder) { + dependency->tracePromise(builder, true); + joinNode.onReadyEvent.traceEvent(builder); } ArrayJoinPromiseNode::ArrayJoinPromiseNode( - Array> promises, Array> resultParts) - : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr<_::Void>)), + Array promises, Array> resultParts, + SourceLocation location, ArrayJoinBehavior joinBehavior) + : ArrayJoinPromiseNodeBase(kj::mv(promises), resultParts.begin(), sizeof(ExceptionOr<_::Void>), + location, joinBehavior), resultParts(kj::mv(resultParts)) {} ArrayJoinPromiseNode::~ArrayJoinPromiseNode() {} +void ArrayJoinPromiseNode::destroy() { freePromise(this); } + void ArrayJoinPromiseNode::getNoError(ExceptionOrValue& output) noexcept { output.as<_::Void>() = _::Void(); } } // namespace _ (private) -Promise joinPromises(Array>&& promises) { - return Promise(false, kj::heap<_::ArrayJoinPromiseNode>( - KJ_MAP(p, promises) { return kj::mv(p.node); }, - heapArray<_::ExceptionOr<_::Void>>(promises.size()))); +Promise joinPromises(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>(_::allocPromise<_::ArrayJoinPromiseNode>( + KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, + heapArray<_::ExceptionOr<_::Void>>(promises.size()), location, + _::ArrayJoinBehavior::LAZY)); +} + +Promise joinPromisesFailFast(Array>&& promises, SourceLocation location) { + return _::PromiseNode::to>(_::allocPromise<_::ArrayJoinPromiseNode>( + KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); }, + heapArray<_::ExceptionOr<_::Void>>(promises.size()), location, + _::ArrayJoinBehavior::EAGER)); } namespace _ { // (private) @@ -918,18 +2857,32 @@ namespace _ { // (private) // ------------------------------------------------------------------- EagerPromiseNodeBase::EagerPromiseNodeBase( - Own&& dependencyParam, ExceptionOrValue& resultRef) - : dependency(kj::mv(dependencyParam)), resultRef(resultRef) { + OwnPromiseNode&& dependencyParam, ExceptionOrValue& resultRef, SourceLocation location) + : Event(location), dependency(kj::mv(dependencyParam)), resultRef(resultRef) { dependency->setSelfPointer(&dependency); - dependency->onReady(*this); + dependency->onReady(this); } -void EagerPromiseNodeBase::onReady(Event& event) noexcept { +void EagerPromiseNodeBase::onReady(Event* event) noexcept { onReadyEvent.init(event); } -PromiseNode* EagerPromiseNodeBase::getInnerForTrace() { - return dependency; +void EagerPromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // TODO(debug): Maybe use __builtin_return_address to get the locations that called + // eagerlyEvaluate()? But note that if a non-null exception handler was passed to it, that + // creates a TransformPromiseNode which will report the location anyhow. + + if (stopAtNextEvent) return; + if (dependency.get() != nullptr) { + dependency->tracePromise(builder, stopAtNextEvent); + } +} + +void EagerPromiseNodeBase::traceEvent(TraceBuilder& builder) { + if (dependency.get() != nullptr) { + dependency->tracePromise(builder, true); + } + onReadyEvent.traceEvent(builder); } Maybe> EagerPromiseNodeBase::fire() { @@ -946,13 +2899,253 @@ Maybe> EagerPromiseNodeBase::fire() { // ------------------------------------------------------------------- -void AdapterPromiseNodeBase::onReady(Event& event) noexcept { +void AdapterPromiseNodeBase::onReady(Event* event) noexcept { onReadyEvent.init(event); } +void AdapterPromiseNodeBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + // Maybe returning the address of get() will give us a function name with meaningful type + // information. + builder.add(getMethodStartAddress(implicitCast(*this), &PromiseNode::get)); +} + +void END_FULFILLER_STACK_START_LISTENER_STACK() {} +// Dummy symbol used when reporting how a PromiseFulfiller was destroyed without fulfilling the +// promise. We end up combining two stack traces into one and we use this as a separator. + +void WeakFulfillerBase::disposeImpl(void* pointer) const { + if (inner == nullptr) { + // Already detached. + delete this; + } else { + if (inner->isWaiting()) { + // Let's find out if there's an exception being thrown. If so, we'll use it to reject the + // promise. + inner->reject(getDestructionReason( + reinterpret_cast(&END_FULFILLER_STACK_START_LISTENER_STACK), + kj::Exception::Type::FAILED, __FILE__, __LINE__, + "PromiseFulfiller was destroyed without fulfilling the promise."_kj)); + } + inner = nullptr; + } +} + +} // namespace _ (private) + // ------------------------------------------------------------------- +namespace _ { // (private) + Promise IdentityFunc>::operator()() const { return READY_NOW; } } // namespace _ (private) + +// ------------------------------------------------------------------- + +#if KJ_HAS_COROUTINE + +namespace _ { // (private) + +CoroutineBase::CoroutineBase(stdcoro::coroutine_handle<> coroutine, ExceptionOrValue& resultRef, + SourceLocation location) + : Event(location), + coroutine(coroutine), + resultRef(resultRef) {} +CoroutineBase::~CoroutineBase() noexcept(false) { + readMaybe(maybeDisposalResults)->destructorRan = true; +} + +void CoroutineBase::unhandled_exception() { + // Pretty self-explanatory, we propagate the exception to the promise which owns us, unless + // we're being destroyed, in which case we propagate it back to our disposer. Note that all + // unhandled exceptions end up here, not just ones after the first co_await. + + auto exception = getCaughtExceptionAsKj(); + + KJ_IF_MAYBE(disposalResults, maybeDisposalResults) { + // Exception during coroutine destruction. Only record the first one. + if (disposalResults->exception == nullptr) { + disposalResults->exception = kj::mv(exception); + } + } else if (isWaiting()) { + // Exception during coroutine execution. + resultRef.addException(kj::mv(exception)); + scheduleResumption(); + } else { + // Okay, what could this mean? We've already been fulfilled or rejected, but we aren't being + // destroyed yet. The only possibility is that we are unwinding the coroutine frame due to a + // successful completion, and something in the frame threw. We can't already be rejected, + // because rejecting a coroutine involves throwing, which would have unwound the frame prior + // to setting `waiting = false`. + // + // Since we know we're unwinding due to a successful completion, we also know that whatever + // Event we may have armed has not yet fired, because we haven't had a chance to return to + // the event loop. + + // final_suspend() has not been called. + KJ_IASSERT(!coroutine.done()); + + // Since final_suspend() hasn't been called, whatever Event is waiting on us has not fired, + // and will see this exception. + resultRef.addException(kj::mv(exception)); + } +} + +void CoroutineBase::onReady(Event* event) noexcept { + onReadyEvent.init(event); +} + +void CoroutineBase::tracePromise(TraceBuilder& builder, bool stopAtNextEvent) { + if (stopAtNextEvent) return; + + KJ_IF_MAYBE(promise, promiseNodeForTrace) { + promise->tracePromise(builder, stopAtNextEvent); + } + + // Maybe returning the address of coroutine() will give us a function name with meaningful type + // information. (Narrator: It doesn't.) + builder.add(GetFunctorStartAddress<>::apply(coroutine)); +}; + +Maybe> CoroutineBase::fire() { + // Call Awaiter::await_resume() and proceed with the coroutine. Note that this will not destroy + // the coroutine if control flows off the end of it, because we return suspend_always() from + // final_suspend(). + // + // It's tempting to arrange to check for exceptions right now and reject the promise that owns + // us without resuming the coroutine, which would save us from throwing an exception when we + // already know where it's going. But, we don't really know: unlike in the KJ_NO_EXCEPTIONS + // case, the `co_await` might be in a try-catch block, so we have no choice but to resume and + // throw later. + // + // TODO(someday): If we ever support coroutines with -fno-exceptions, we'll need to reject the + // enclosing coroutine promise here, if the Awaiter's result is exceptional. + + promiseNodeForTrace = nullptr; + + coroutine.resume(); + + return nullptr; +} + +void CoroutineBase::traceEvent(TraceBuilder& builder) { + KJ_IF_MAYBE(promise, promiseNodeForTrace) { + promise->tracePromise(builder, true); + } + + // Maybe returning the address of coroutine() will give us a function name with meaningful type + // information. (Narrator: It doesn't.) + builder.add(GetFunctorStartAddress<>::apply(coroutine)); + + onReadyEvent.traceEvent(builder); +} + +void CoroutineBase::destroy() { + // Called by PromiseDisposer to delete the object. Basically a wrapper around coroutine.destroy() + // with some stuff to propagate exceptions appropriately. + + // Objects in the coroutine frame might throw from their destructors, so unhandled_exception() + // will need some way to communicate those exceptions back to us. Separately, we also want + // confirmation that our own ~Coroutine() destructor ran. To solve this, we put a + // DisposalResults object on the stack and set a pointer to it in the Coroutine object. This + // indicates to unhandled_exception() and ~Coroutine() where to store the results of the + // destruction operation. + DisposalResults disposalResults; + maybeDisposalResults = &disposalResults; + + // Need to save this while `unwindDetector` is still valid. + bool shouldRethrow = !unwindDetector.isUnwinding(); + + do { + // Clang's implementation of the Coroutines TS does not destroy the Coroutine object or + // deallocate the coroutine frame if a destructor of an object on the frame threw an + // exception. This is despite the fact that it delivered the exception to _us_ via + // unhandled_exception(). Anyway, it appears we can work around this by running + // coroutine.destroy() a second time. + // + // On Clang, `disposalResults.exception != nullptr` implies `!disposalResults.destructorRan`. + // We could optimize out the separate `destructorRan` flag if we verify that other compilers + // behave the same way. + coroutine.destroy(); + } while (!disposalResults.destructorRan); + + // WARNING: `this` is now a dangling pointer. + + KJ_IF_MAYBE(exception, disposalResults.exception) { + if (shouldRethrow) { + kj::throwFatalException(kj::mv(*exception)); + } else { + // An exception is already unwinding the stack, so throwing this secondary exception would + // call std::terminate(). + } + } +} + +CoroutineBase::AwaiterBase::AwaiterBase(OwnPromiseNode node): node(kj::mv(node)) {} +CoroutineBase::AwaiterBase::AwaiterBase(AwaiterBase&&) = default; +CoroutineBase::AwaiterBase::~AwaiterBase() noexcept(false) { + // Make sure it's safe to generate an async stack trace between now and when the Coroutine is + // destroyed. + KJ_IF_MAYBE(coroutineEvent, maybeCoroutineEvent) { + coroutineEvent->promiseNodeForTrace = nullptr; + } + + unwindDetector.catchExceptionsIfUnwinding([this]() { + // No need to check for a moved-from state, node will just ignore the nullification. + node = nullptr; + }); +} + +void CoroutineBase::AwaiterBase::getImpl(ExceptionOrValue& result, void* awaitedAt) { + node->get(result); + + KJ_IF_MAYBE(exception, result.exception) { + // Manually extend the stack trace with the instruction address where the co_await occurred. + exception->addTrace(awaitedAt); + + // Pass kj::maxValue for ignoreCount here so that `throwFatalException()` dosen't try to + // extend the stack trace. There's no point in extending the trace beyond the single frame we + // added above, as the rest of the trace will always be async framework stuff that no one wants + // to see. + kj::throwFatalException(kj::mv(*exception), kj::maxValue); + } +} + +bool CoroutineBase::AwaiterBase::awaitSuspendImpl(CoroutineBase& coroutineEvent) { + node->setSelfPointer(&node); + node->onReady(&coroutineEvent); + + if (coroutineEvent.hasSuspendedAtLeastOnce && coroutineEvent.isNext()) { + // The result is immediately ready and this coroutine is running on the event loop's stack, not + // a user code stack. Let's cancel our event and immediately resume. It's important that we + // don't perform this optimization if this is the first suspension, because our caller may + // depend on running code before this promise's continuations fire. + coroutineEvent.disarm(); + + // We can resume ourselves by returning false. This accomplishes the same thing as if we had + // returned true from await_ready(). + return false; + } else { + // Otherwise, we must suspend. Store a reference to the promise we're waiting on for tracing + // purposes; coroutineEvent.fire() and/or ~Adapter() will null this out. + coroutineEvent.promiseNodeForTrace = *node; + maybeCoroutineEvent = coroutineEvent; + + coroutineEvent.hasSuspendedAtLeastOnce = true; + + return true; + } +} + +// --------------------------------------------------------- +// Helpers for coCapture() + +void throwMultipleCoCaptureInvocations() { + KJ_FAIL_REQUIRE("Attempted to invoke CaptureForCoroutine functor multiple times"); +} + +} // namespace _ (private) + +#endif // KJ_HAS_COROUTINE + } // namespace kj diff --git a/c++/src/kj/async.h b/c++/src/kj/async.h index 5a9d9bdae7..564b517154 100644 --- a/c++/src/kj/async.h +++ b/c++/src/kj/async.h @@ -19,16 +19,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ASYNC_H_ -#define KJ_ASYNC_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "async-prelude.h" -#include "exception.h" -#include "refcount.h" +#include +#include + +KJ_BEGIN_HEADER + +#ifndef KJ_USE_FIBERS + #if __BIONIC__ || __FreeBSD__ || __OpenBSD__ || KJ_NO_EXCEPTIONS + // These platforms don't support fibers. + #define KJ_USE_FIBERS 0 + #else + #define KJ_USE_FIBERS 1 + #endif +#else + #if KJ_NO_EXCEPTIONS && KJ_USE_FIBERS + #error "Fibers cannot be enabled when exceptions are disabled." + #endif +#endif namespace kj { @@ -44,12 +54,66 @@ class PromiseFulfiller; template struct PromiseFulfillerPair; +template +class FunctionParam; + template -using PromiseForResult = Promise<_::JoinPromises<_::ReturnType>>; +using PromiseForResult = _::ReducePromises<_::ReturnType>; // Evaluates to the type of Promise for the result of calling functor type Func with parameter type // T. If T is void, then the promise is for the result of calling Func with no arguments. If // Func itself returns a promise, the promises are joined, so you never get Promise>. +// ======================================================================================= + +class AsyncObject { + // You may optionally inherit privately from this to indicate that the type is a KJ async object, + // meaning it deals with KJ async I/O making it tied to a specific thread and event loop. This + // enables some additional debug checks, but does not otherwise have any effect on behavior as + // long as there are no bugs. + // + // (We prefer inheritance rather than composition here because inheriting an empty type adds zero + // size to the derived class.) + +public: + ~AsyncObject(); + +private: + KJ_NORETURN(static void failed() noexcept); +}; + +class DisallowAsyncDestructorsScope { + // Create this type on the stack in order to specify that during its scope, no KJ async objects + // should be destroyed. If AsyncObject's destructor is called in this scope, the process will + // crash with std::terminate(). + // + // This is useful as a sort of "sanitizer" to catch bugs. When tearing down an object that is + // intended to be passed between threads, you can set up one of these scopes to catch whether + // the object contains any async objects, which are not legal to pass across threads. + +public: + explicit DisallowAsyncDestructorsScope(kj::StringPtr reason); + ~DisallowAsyncDestructorsScope(); + KJ_DISALLOW_COPY_AND_MOVE(DisallowAsyncDestructorsScope); + +private: + kj::StringPtr reason; + DisallowAsyncDestructorsScope* previousValue; + + friend class AsyncObject; +}; + +class AllowAsyncDestructorsScope { + // Negates the effect of DisallowAsyncDestructorsScope. + +public: + AllowAsyncDestructorsScope(); + ~AllowAsyncDestructorsScope(); + KJ_DISALLOW_COPY_AND_MOVE(AllowAsyncDestructorsScope); + +private: + DisallowAsyncDestructorsScope* previousValue; +}; + // ======================================================================================= // Promises @@ -137,8 +201,8 @@ class Promise: protected _::PromiseBase { inline Promise(decltype(nullptr)) {} template - PromiseForResult then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException()) - KJ_WARN_UNUSED_RESULT; + PromiseForResult then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException(), + SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Register a continuation function to be executed when the promise completes. The continuation // (`func`) takes the promised value (an rvalue of type `T`) as its parameter. The continuation // may return a new value; `then()` itself returns a promise for the continuation's eventual @@ -199,11 +263,11 @@ class Promise: protected _::PromiseBase { // You must still wait on the returned promise if you want the task to execute. template - Promise catch_(ErrorFunc&& errorHandler) KJ_WARN_UNUSED_RESULT; + Promise catch_(ErrorFunc&& errorHandler, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Equivalent to `.then(identityFunc, errorHandler)`, where `identifyFunc` is a function that // just returns its input. - T wait(WaitScope& waitScope); + T wait(WaitScope& waitScope, SourceLocation location = {}); // Run the event loop until the promise is fulfilled, then return its result. If the promise // is rejected, throw an exception. // @@ -232,22 +296,45 @@ class Promise: protected _::PromiseBase { // around them in arbitrary ways. Therefore, callers really need to know if a function they // are calling might wait(), and the `WaitScope&` parameter makes this clear. // - // TODO(someday): Implement fibers, and let them call wait() even when they are handling an - // event. + // Usually, there is only one `WaitScope` for each `EventLoop`, and it can only be used at the + // top level of the thread owning the loop. Calling `wait()` with this `WaitScope` is what + // actually causes the event loop to run at all. This top-level `WaitScope` cannot be used + // recursively, so cannot be used within an event callback. + // + // However, it is possible to obtain a `WaitScope` in lower-level code by using fibers. Use + // kj::startFiber() to start some code executing on an alternate call stack. That code will get + // its own `WaitScope` allowing it to operate in a synchronous style. In this case, `wait()` + // switches back to the main stack in order to run the event loop, returning to the fiber's stack + // once the awaited promise resolves. + + bool poll(WaitScope& waitScope, SourceLocation location = {}); + // Returns true if a call to wait() would complete without blocking, false if it would block. + // + // If the promise is not yet resolved, poll() will pump the event loop and poll for I/O in an + // attempt to resolve it. Only when there is nothing left to do will it return false. + // + // Generally, poll() is most useful in tests. Often, you may want to verify that a promise does + // not resolve until some specific event occurs. To do so, poll() the promise before the event to + // verify it isn't resolved, then trigger the event, then poll() again to verify that it resolves. + // The first poll() verifies that the promise doesn't resolve early, which would otherwise be + // hard to do deterministically. The second poll() allows you to check that the promise has + // resolved and avoid a wait() that might deadlock in the case that it hasn't. + // + // poll() is not supported in fibers; it will throw an exception. - ForkedPromise fork() KJ_WARN_UNUSED_RESULT; + ForkedPromise fork(SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Forks the promise, so that multiple different clients can independently wait on the result. // `T` must be copy-constructable for this to work. Or, in the special case where `T` is // `Own`, `U` must have a method `Own addRef()` which returns a new reference to the same // (or an equivalent) object (probably implemented via reference counting). - _::SplitTuplePromise split(); + _::SplitTuplePromise split(SourceLocation location = {}); // Split a promise for a tuple into a tuple of promises. // // E.g. if you have `Promise>`, `split()` returns // `kj::Tuple, Promise>`. - Promise exclusiveJoin(Promise&& other) KJ_WARN_UNUSED_RESULT; + Promise exclusiveJoin(Promise&& other, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Return a new promise that resolves when either the original promise resolves or `other` // resolves (whichever comes first). The promise that didn't resolve first is canceled. @@ -262,14 +349,15 @@ class Promise: protected _::PromiseBase { // runs -- after calling then(), use attach() to add necessary objects to the result. template - Promise eagerlyEvaluate(ErrorFunc&& errorHandler) KJ_WARN_UNUSED_RESULT; - Promise eagerlyEvaluate(decltype(nullptr)) KJ_WARN_UNUSED_RESULT; + Promise eagerlyEvaluate(ErrorFunc&& errorHandler, SourceLocation location = {}) + KJ_WARN_UNUSED_RESULT; + Promise eagerlyEvaluate(decltype(nullptr), SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; // Force eager evaluation of this promise. Use this if you are going to hold on to the promise // for awhile without consuming the result, but you want to make sure that the system actually // processes it. // // `errorHandler` is a function that takes `kj::Exception&&`, like the second parameter to - // `then()`, except that it must return void. We make you specify this because otherwise it's + // `then()`, or the parameter to `catch_()`. We make you specify this because otherwise it's // easy to forget to handle errors in a promise that you never use. You may specify nullptr for // the error handler if you are sure that ignoring errors is fine, or if you know that you'll // eventually wait on the promise somewhere. @@ -292,24 +380,10 @@ class Promise: protected _::PromiseBase { // This method does NOT consume the promise as other methods do. private: - Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {} + Promise(bool, _::OwnPromiseNode&& node): PromiseBase(kj::mv(node)) {} // Second parameter prevent ambiguity with immediate-value constructor. - template - friend class Promise; - friend class EventLoop; - template - friend Promise newAdaptedPromise(Params&&... adapterConstructorParams); - template - friend PromiseFulfillerPair newPromiseAndFulfiller(); - template - friend class _::ForkHub; - friend class _::TaskSetImpl; - friend Promise _::yield(); - friend class _::NeverDone; - template - friend Promise> joinPromises(Array>&& promises); - friend Promise joinPromises(Array>&& promises); + friend class _::PromiseNode; }; template @@ -323,6 +397,9 @@ class ForkedPromise { Promise addBranch(); // Add a new branch to the fork. The branch is equivalent to the original promise. + bool hasBranches(); + // Returns true if there are any branches that haven't been canceled. + private: Own<_::ForkHub<_::FixVoid>> hub; @@ -332,7 +409,7 @@ class ForkedPromise { friend class EventLoop; }; -constexpr _::Void READY_NOW = _::Void(); +constexpr _::ReadyNow READY_NOW = _::ReadyNow(); // Use this when you need a Promise that is already fulfilled -- this value can be implicitly // cast to `Promise`. @@ -341,6 +418,11 @@ constexpr _::NeverDone NEVER_DONE = _::NeverDone(); // implicitly converted to any promise type. You may also call `NEVER_DONE.wait()` to wait // forever (useful for servers). +template +Promise constPromise(); +// Construct a Promise which resolves to the given constant value. This function is equivalent to +// `Promise(value)` except that it avoids an allocation. + template PromiseForResult evalLater(Func&& func) KJ_WARN_UNUSED_RESULT; // Schedule for the given zero-parameter function to be executed in the event loop at some @@ -365,9 +447,127 @@ PromiseForResult evalNow(Func&& func) KJ_WARN_UNUSED_RESULT; // If `func()` throws an exception, the exception is caught and wrapped in a promise -- this is the // main reason why `evalNow()` is useful. +template +PromiseForResult evalLast(Func&& func) KJ_WARN_UNUSED_RESULT; +// Like `evalLater()`, except that the function doesn't run until the event queue is otherwise +// completely empty and the thread is about to suspend waiting for I/O. +// +// This is useful when you need to perform some disruptive action and you want to make sure that +// you don't interrupt some other task between two .then() continuations. For example, say you want +// to cancel a read() operation on a socket and know for sure that if any bytes were read, you saw +// them. It could be that a read() has completed and bytes have been transferred to the target +// buffer, but the .then() callback that handles the read result hasn't executed yet. If you +// cancel the promise at this inopportune moment, the bytes in the buffer are lost. If you do +// evalLast(), then you can be sure that any pending .then() callbacks had a chance to finish out +// and if you didn't receive the read result yet, then you know nothing has been read, and you can +// simply drop the promise. +// +// If evalLast() is called multiple times, functions are executed in LIFO order. If the first +// callback enqueues new events, then latter callbacks will not execute until those events are +// drained. + +ArrayPtr getAsyncTrace(ArrayPtr space); +kj::String getAsyncTrace(); +// If the event loop is currently running in this thread, get a trace back through the promise +// chain leading to the currently-executing event. The format is the same as kj::getStackTrace() +// from exception.c++. + +template +PromiseForResult retryOnDisconnect(Func&& func) KJ_WARN_UNUSED_RESULT; +// Promises to run `func()` asynchronously, retrying once if it fails with a DISCONNECTED exception. +// If the retry also fails, the exception is passed through. +// +// `func()` should return a `Promise`. `retryOnDisconnect(func)` returns the same promise, except +// with the retry logic added. + +template +PromiseForResult startFiber( + size_t stackSize, Func&& func, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; +// Executes `func()` in a fiber, returning a promise for the eventual reseult. `func()` will be +// passed a `WaitScope&` as its parameter, allowing it to call `.wait()` on promises. Thus, `func()` +// can be written in a synchronous, blocking style, instead of using `.then()`. This is often much +// easier to write and read, and may even be significantly faster if it allows the use of stack +// allocation rather than heap allocation. +// +// However, fibers have a major disadvantage: memory must be allocated for the fiber's call stack. +// The entire stack must be allocated at once, making it necessary to choose a stack size upfront +// that is big enough for whatever the fiber needs to do. Estimating this is often difficult. That +// said, over-estimating is not too terrible since pages of the stack will actually be allocated +// lazily when first accessed; actual memory usage will correspond to the "high watermark" of the +// actual stack usage. That said, this lazy allocation forces page faults, which can be quite slow. +// Worse, freeing a stack forces a TLB flush and shootdown -- all currently-executing threads will +// have to be interrupted to flush their CPU cores' TLB caches. +// +// In short, when performance matters, you should try to avoid creating fibers very frequently. + +class FiberPool final { + // A freelist pool of fibers with a set stack size. This improves CPU usage with fibers at + // the expense of memory usage. Fibers in this pool will always use the max amount of memory + // used until the pool is destroyed. + +public: + explicit FiberPool(size_t stackSize); + ~FiberPool() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(FiberPool); + + void setMaxFreelist(size_t count); + // Set the maximum number of stacks to add to the freelist. If the freelist is full, stacks will + // be deleted rather than returned to the freelist. + + void useCoreLocalFreelists(); + // EXPERIMENTAL: Call to tell FiberPool to try to use core-local stack freelists, which + // in theory should increase L1/L2 cache efficacy for freelisted stacks. In practice, as of + // this writing, no performance advantage has yet been demonstrated. Note that currently this + // feature is only supported on Linux (the flag has no effect on other operating systems). + + template + PromiseForResult startFiber( + Func&& func, SourceLocation location = {}) const KJ_WARN_UNUSED_RESULT; + // Executes `func()` in a fiber from this pool, returning a promise for the eventual result. + // `func()` will be passed a `WaitScope&` as its parameter, allowing it to call `.wait()` on + // promises. Thus, `func()` can be written in a synchronous, blocking style, instead of + // using `.then()`. This is often much easier to write and read, and may even be significantly + // faster if it allows the use of stack allocation rather than heap allocation. + + void runSynchronously(kj::FunctionParam func) const; + // Use one of the stacks in the pool to synchronously execute func(), returning the result that + // func() returns. This is not the usual use case for fibers, but can be a nice optimization + // in programs that have many threads that mostly only need small stacks, but occasionally need + // a much bigger stack to run some deeply recursive algorithm. If the algorithm is run on each + // thread's normal call stack, then every thread's stack will tend to grow to be very big + // (usually, stacks automatically grow as needed, but do not shrink until the thread exits + // completely). If the thread can share a small set of big stacks that they use only when calling + // the deeply recursive algorithm, and use small stacks for everything else, overall memory usage + // is reduced. + // + // TODO(someday): If func() returns a value, return it from runSynchronously? Current use case + // doesn't need it. + + size_t getFreelistSize() const; + // Get the number of stacks currently in the freelist. Does not count stacks that are active. + +private: + class Impl; + Own impl; + + friend class _::FiberStack; + friend class _::FiberBase; +}; + +template +Promise> joinPromises(Array>&& promises, SourceLocation location = {}); +// Join an array of promises into a promise for an array. Trailing continuations on promises are not +// evaluated until all promises have settled. Exceptions are propagated only after the last promise +// has settled. +// +// TODO(cleanup): It is likely that `joinPromisesFailFast()` is what everyone should be using. +// Deprecate this function. + template -Promise> joinPromises(Array>&& promises); -// Join an array of promises into a promise for an array. +Promise> joinPromisesFailFast(Array>&& promises, SourceLocation location = {}); +// Join an array of promises into a promise for an array. Trailing continuations on promises are +// evaluated eagerly. If any promise results in an exception, the exception is immediately +// propagated to the returned join promise. // ======================================================================================= // Hack for creating a lambda that holds an owned pointer. @@ -389,6 +589,10 @@ class CaptureByMove { MovedParam param; }; +template +inline CaptureByMove> mvCapture(MovedParam&& param, Func&& func) + KJ_DEPRECATED("Use C++14 generalized captures instead."); + template inline CaptureByMove> mvCapture(MovedParam&& param, Func&& func) { // Hack to create a "lambda" which captures a variable by moving it rather than copying or @@ -404,11 +608,130 @@ inline CaptureByMove> mvCapture(MovedParam&& param, Func return CaptureByMove>(kj::fwd(func), kj::mv(param)); } +// ======================================================================================= +// Hack for safely using a lambda as a coroutine. + +#if KJ_HAS_COROUTINE + +namespace _ { + +void throwMultipleCoCaptureInvocations(); + +template +struct CaptureForCoroutine { + kj::Maybe maybeFunctor; + + explicit CaptureForCoroutine(Functor&& f) : maybeFunctor(kj::mv(f)) {} + + template + static auto coInvoke(Functor functor, Args&&... args) + -> decltype(functor(kj::fwd(args)...)) { + // Since the functor is now in the local scope and no longer a member variable, it will be + // persisted in the coroutine state. + + // Note that `co_await functor(...)` can still return `void`. It just happens that + // `co_return voidReturn();` is explicitly allowed. + co_return co_await functor(kj::fwd(args)...); + } + + template + auto operator()(Args&&... args) { + if (maybeFunctor == nullptr) { + throwMultipleCoCaptureInvocations(); + } + auto localFunctor = kj::mv(*kj::_::readMaybe(maybeFunctor)); + maybeFunctor = nullptr; + return coInvoke(kj::mv(localFunctor), kj::fwd(args)...); + } +}; + +} // namespace _ + +template +auto coCapture(Functor&& f) { + // Assuming `f()` returns a Promise `p`, wrap `f` in such a way that it will outlive its + // returned Promise. Note that the returned object may only be invoked once. + // + // This function is meant to help address this pain point with functors that return a coroutine: + // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rcoro-capture + // + // The two most common patterns where this may be useful look like so: + // ``` + // void addTask(Value myValue) { + // auto myFun = [myValue]() -> kj::Promise { + // ... + // co_return; + // }; + // tasks.add(myFun()); + // } + // ``` + // and + // ``` + // kj::Promise afterPromise(kj::Promise promise, Value myValue) { + // auto myFun = [myValue]() -> kj::Promise { + // ... + // co_return; + // }; + // return promise.then(kj::mv(myFun)); + // } + // ``` + // + // Note that there are potentially more optimal alternatives to both of these patterns: + // ``` + // void addTask(Value myValue) { + // auto myFun = [](auto myValue) -> kj::Promise { + // ... + // co_return; + // }; + // tasks.add(myFun(myValue)); + // } + // ``` + // and + // ``` + // kj::Promise afterPromise(kj::Promise promise, Value myValue) { + // auto myFun = [&]() -> kj::Promise { + // ... + // co_return; + // }; + // co_await promise; + // co_await myFun(); + // co_return; + // } + // ``` + // + // For situations where you are trying to capture a specific local variable, kj::mvCapture() can + // also be useful: + // ``` + // kj::Promise reactToPromise(kj::Promise promise) { + // BigA a; + // TinyB b; + // + // doSomething(a, b); + // return promise.then(kj::mvCapture(b, [](TinyB b, MyType type) -> kj::Promise { + // ... + // co_return; + // }); + // } + // ``` + + return _::CaptureForCoroutine(kj::mv(f)); +} + +#endif // KJ_HAS_COROUTINE + // ======================================================================================= // Advanced promise construction +class PromiseRejector: private AsyncObject { + // Superclass of PromiseFulfiller containing the non-typed methods. Useful when you only really + // need to be able to reject a promise, and you need to operate on fulfillers of different types. +public: + virtual void reject(Exception&& exception) = 0; + virtual bool isWaiting() = 0; +}; + template -class PromiseFulfiller { +class PromiseFulfiller: public PromiseRejector { // A callback which can be used to fulfill a promise. Only the first call to fulfill() or // reject() matters; subsequent calls are ignored. @@ -432,7 +755,7 @@ class PromiseFulfiller { }; template <> -class PromiseFulfiller { +class PromiseFulfiller: public PromiseRejector { // Specialization of PromiseFulfiller for void promises. See PromiseFulfiller. public: @@ -448,7 +771,7 @@ class PromiseFulfiller { }; template -Promise newAdaptedPromise(Params&&... adapterConstructorParams); +_::ReducePromises newAdaptedPromise(Params&&... adapterConstructorParams); // Creates a new promise which owns an instance of `Adapter` which encapsulates the operation // that will eventually fulfill the promise. This is primarily useful for adapting non-KJ // asynchronous APIs to use promises. @@ -470,12 +793,12 @@ Promise newAdaptedPromise(Params&&... adapterConstructorParams); template struct PromiseFulfillerPair { - Promise<_::JoinPromises> promise; + _::ReducePromises promise; Own> fulfiller; }; template -PromiseFulfillerPair newPromiseAndFulfiller(); +PromiseFulfillerPair newPromiseAndFulfiller(SourceLocation location = {}); // Construct a Promise and a separate PromiseFulfiller which can be used to fulfill the promise. // If the PromiseFulfiller is destroyed before either of its methods are called, the Promise is // implicitly rejected. @@ -488,10 +811,164 @@ PromiseFulfillerPair newPromiseAndFulfiller(); // fulfiller will be of type `PromiseFulfiller>`. Thus you pass a `Promise` to the // `fulfill()` callback, and the promises are chained. +template +class CrossThreadPromiseFulfiller: public kj::PromiseFulfiller { + // Like PromiseFulfiller but the methods are `const`, indicating they can safely be called + // from another thread. + +public: + virtual void fulfill(T&& value) const = 0; + virtual void reject(Exception&& exception) const = 0; + virtual bool isWaiting() const = 0; + + void fulfill(T&& value) override { return constThis()->fulfill(kj::fwd(value)); } + void reject(Exception&& exception) override { return constThis()->reject(kj::mv(exception)); } + bool isWaiting() override { return constThis()->isWaiting(); } + +private: + const CrossThreadPromiseFulfiller* constThis() { return this; } +}; + +template <> +class CrossThreadPromiseFulfiller: public kj::PromiseFulfiller { + // Specialization of CrossThreadPromiseFulfiller for void promises. See + // CrossThreadPromiseFulfiller. + +public: + virtual void fulfill(_::Void&& value = _::Void()) const = 0; + virtual void reject(Exception&& exception) const = 0; + virtual bool isWaiting() const = 0; + + void fulfill(_::Void&& value) override { return constThis()->fulfill(kj::mv(value)); } + void reject(Exception&& exception) override { return constThis()->reject(kj::mv(exception)); } + bool isWaiting() override { return constThis()->isWaiting(); } + +private: + const CrossThreadPromiseFulfiller* constThis() { return this; } +}; + +template +struct PromiseCrossThreadFulfillerPair { + _::ReducePromises promise; + Own> fulfiller; +}; + +template +PromiseCrossThreadFulfillerPair newPromiseAndCrossThreadFulfiller(); +// Like `newPromiseAndFulfiller()`, but the fulfiller is allowed to be invoked from any thread, +// not just the one that called this method. Note that the Promise is still tied to the calling +// thread's event loop and *cannot* be used from another thread -- only the PromiseFulfiller is +// cross-thread. + +// ======================================================================================= +// Canceler + +class Canceler: private AsyncObject { + // A Canceler can wrap some set of Promises and then forcefully cancel them on-demand, or + // implicitly when the Canceler is destroyed. + // + // The cancellation is done in such a way that once cancel() (or the Canceler's destructor) + // returns, it's guaranteed that the promise has already been canceled and destroyed. This + // guarantee is important for enforcing ownership constraints. For example, imagine that Alice + // calls a method on Bob that returns a Promise. That Promise encapsulates a task that uses Bob's + // internal state. But, imagine that Alice does not own Bob, and indeed Bob might be destroyed + // at random without Alice having canceled the promise. In this case, it is necessary for Bob to + // ensure that the promise will be forcefully canceled. Bob can do this by constructing a + // Canceler and using it to wrap promises before returning them to callers. When Bob is + // destroyed, the Canceler is destroyed too, and all promises Bob wrapped with it throw errors. + // + // Note that another common strategy for cancellation is to use exclusiveJoin() to join a promise + // with some "cancellation promise" which only resolves if the operation should be canceled. The + // cancellation promise could itself be created by newPromiseAndFulfiller(), and thus + // calling the PromiseFulfiller cancels the operation. There is a major problem with this + // approach: upon invoking the fulfiller, an arbitrary amount of time may pass before the + // exclusive-joined promise actually resolves and cancels its other fork. During that time, the + // task might continue to execute. If it holds pointers to objects that have been destroyed, this + // might cause segfaults. Thus, it is safer to use a Canceler. + +public: + inline Canceler() {} + ~Canceler() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(Canceler); + + template + Promise wrap(Promise promise) { + return newAdaptedPromise>(*this, kj::mv(promise)); + } + + void cancel(StringPtr cancelReason); + void cancel(const Exception& exception); + // Cancel all previously-wrapped promises that have not already completed, causing them to throw + // the given exception. If you provide just a description message instead of an exception, then + // an exception object will be constructed from it -- but only if there are requests to cancel. + + void release(); + // Releases previously-wrapped promises, so that they will not be canceled regardless of what + // happens to this Canceler. + + bool isEmpty() const { return list == nullptr; } + // Indicates if any previously-wrapped promises are still executing. (If this returns true, then + // cancel() would be a no-op.) + +private: + class AdapterBase { + public: + AdapterBase(Canceler& canceler); + ~AdapterBase() noexcept(false); + + virtual void cancel(Exception&& e) = 0; + + void unlink(); + + private: + Maybe&> prev; + Maybe next; + friend class Canceler; + }; + + template + class AdapterImpl: public AdapterBase { + public: + AdapterImpl(PromiseFulfiller& fulfiller, + Canceler& canceler, Promise inner) + : AdapterBase(canceler), + fulfiller(fulfiller), + inner(inner.then( + [&fulfiller](T&& value) { fulfiller.fulfill(kj::mv(value)); }, + [&fulfiller](Exception&& e) { fulfiller.reject(kj::mv(e)); }) + .eagerlyEvaluate(nullptr)) {} + + void cancel(Exception&& e) override { + fulfiller.reject(kj::mv(e)); + inner = nullptr; + } + + private: + PromiseFulfiller& fulfiller; + Promise inner; + }; + + Maybe list; +}; + +template <> +class Canceler::AdapterImpl: public AdapterBase { +public: + AdapterImpl(kj::PromiseFulfiller& fulfiller, + Canceler& canceler, kj::Promise inner); + void cancel(kj::Exception&& e) override; + // These must be defined in async.c++ to prevent translation units compiled by MSVC from trying to + // link with symbols defined in async.c++ merely because they included async.h. + +private: + kj::PromiseFulfiller& fulfiller; + kj::Promise inner; +}; + // ======================================================================================= // TaskSet -class TaskSet { +class TaskSet: private AsyncObject { // Holds a collection of Promises and ensures that each executes to completion. Memory // associated with each promise is automatically freed when the promise completes. Destroying // the TaskSet itself automatically cancels all unfinished promises. @@ -499,7 +976,7 @@ class TaskSet { // This is useful for "daemon" objects that perform background tasks which aren't intended to // fulfill any particular external promise, but which may need to be canceled (and thus can't // use `Promise::detach()`). The daemon object holds a TaskSet to collect these tasks it is - // working on. This way, if the daemon itself is destroyed, the TaskSet is detroyed as well, + // working on. This way, if the daemon itself is destroyed, the TaskSet is destroyed as well, // and everything the daemon is doing is canceled. public: @@ -508,9 +985,9 @@ class TaskSet { virtual void taskFailed(kj::Exception&& exception) = 0; }; - TaskSet(ErrorHandler& errorHandler); - // `loop` will be used to wait on promises. `errorHandler` will be executed any time a task - // throws an exception, and will execute within the given EventLoop. + TaskSet(ErrorHandler& errorHandler, SourceLocation location = {}); + // `errorHandler` will be executed any time a task throws an exception, and will execute within + // the given EventLoop. ~TaskSet() noexcept(false); @@ -519,10 +996,146 @@ class TaskSet { kj::String trace(); // Return debug info about all promises currently in the TaskSet. + bool isEmpty() { return tasks == nullptr; } + // Check if any tasks are running. + + Promise onEmpty(); + // Returns a promise that fulfills the next time the TaskSet is empty. Only one such promise can + // exist at a time. + + void clear(); + // Cancel all tasks. + // + // As always, it is not safe to cancel the task that is currently running, so you could not call + // this from inside a task in the TaskSet. However, it IS safe to call this from the + // `taskFailed()` callback. + // + // Calling this will always trigger onEmpty(), if anyone is listening. + +private: + class Task; + using OwnTask = Own; + + TaskSet::ErrorHandler& errorHandler; + Maybe tasks; + Maybe>> emptyFulfiller; + SourceLocation location; +}; + +// ======================================================================================= +// Cross-thread execution. + +class Executor { + // Executes code on another thread's event loop. + // + // Use `kj::getCurrentThreadExecutor()` to get an executor that schedules calls on the current + // thread's event loop. You may then pass the reference to other threads to enable them to call + // back to this one. + +public: + Executor(EventLoop& loop, Badge); + ~Executor() noexcept(false); + + virtual kj::Own addRef() const = 0; + // Add a reference to this Executor. The Executor will not be destroyed until all references are + // dropped. This uses atomic refcounting for thread-safety. + // + // Use this when you can't guarantee that the target thread's event loop won't concurrently exit + // (including due to an uncaught exception!) while another thread is still using the Executor. + // Otherwise, the Executor object is destroyed when the owning event loop exits. + // + // If the target event loop has exited, then `execute{Async,Sync}` will throw DISCONNECTED + // exceptions. + + bool isLive() const; + // Returns true if the remote event loop still exists, false if it has been destroyed. In the + // latter case, `execute{Async,Sync}()` will definitely throw. Of course, if this returns true, + // it could still change to false at any moment, and `execute{Async,Sync}()` could still throw as + // a result. + // + // TODO(cleanup): Should we have tryExecute{Async,Sync}() that return Maybes that are null if + // the remote event loop exited? Currently there are multiple known use cases that check + // isLive() after catching a DISCONNECTED exception to decide whether it is due to the executor + // exiting, and then handling that case. This is borderline in violation of KJ exception + // philosophy, but right now I'm not excited about the extra template metaprogramming needed + // for "try" versions... + + template + PromiseForResult executeAsync(Func&& func, SourceLocation location = {}) const; + // Call from any thread to request that the given function be executed on the executor's thread, + // returning a promise for the result. + // + // The Promise returned by executeAsync() belongs to the requesting thread, not the executor + // thread. Hence, for example, continuations added to this promise with .then() will execute in + // the requesting thread. + // + // If func() itself returns a Promise, that Promise is *not* returned verbatim to the requesting + // thread -- after all, Promise objects cannot be used cross-thread. Instead, the executor thread + // awaits the promise. Once it resolves to a final result, that result is transferred to the + // requesting thread, resolving the promise that executeAsync() returned earlier. + // + // `func` will be destroyed in the requesting thread, after the final result has been returned + // from the executor thread. This means that it is safe for `func` to capture objects that cannot + // safely be destroyed from another thread. It is also safe for `func` to be an lvalue reference, + // so long as the functor remains live until the promise completes or is canceled, and the + // function is thread-safe. + // + // Of course, the body of `func` must be careful that any access it makes on these objects is + // safe cross-thread. For example, it must not attempt to access Promise-related objects + // cross-thread; you cannot create a `PromiseFulfiller` in one thread and then `fulfill()` it + // from another. Unfortunately, the usual convention of using const-correctness to enforce + // thread-safety does not work here, because applications can often ensure that `func` has + // exclusive access to captured objects, and thus can safely mutate them even in non-thread-safe + // ways; the const qualifier is not sufficient to express this. + // + // The final return value of `func` is transferred between threads, and hence is constructed and + // destroyed in separate threads. It is the app's responsibility to make sure this is OK. + // Alternatively, the app can perhaps arrange to send the return value back to the original + // thread for destruction, if needed. + // + // If the requesting thread destroys the returned Promise, the destructor will block waiting for + // the executor thread to acknowledge cancellation. This ensures that `func` can be destroyed + // before the Promise's destructor returns. + // + // Multiple calls to executeAsync() from the same requesting thread to the same target thread + // will be delivered in the same order in which they were requested. (However, if func() returns + // a promise, delivery of subsequent calls is not blocked on that promise. In other words, this + // call provides E-Order in the same way as Cap'n Proto.) + + template + _::UnwrapPromise> executeSync( + Func&& func, SourceLocation location = {}) const; + // Schedules `func()` to execute on the executor thread, and then blocks the requesting thread + // until `func()` completes. If `func()` returns a Promise, then the wait will continue until + // that promise resolves, and the final result will be returned to the requesting thread. + // + // The requesting thread does not need to have an EventLoop. If it does have an EventLoop, that + // loop will *not* execute while the thread is blocked. This method is particularly useful to + // allow non-event-loop threads to perform I/O via a separate event-loop thread. + // + // As with `executeAsync()`, `func` is always destroyed on the requesting thread, after the + // executor thread has signaled completion. The return value is transferred between threads. + private: - Own<_::TaskSetImpl> impl; + struct Impl; + Own impl; + // To avoid including mutex.h... + + friend class EventLoop; + friend class _::XThreadEvent; + friend class _::XThreadPaf; + + void send(_::XThreadEvent& event, bool sync) const; + void wait(); + bool poll(); + + EventLoop& getLoop() const; }; +const Executor& getCurrentThreadExecutor(); +// Get the executor for the current thread's event loop. This reference can then be passed to other +// threads. + // ======================================================================================= // The EventLoop class @@ -627,8 +1240,19 @@ class EventLoop { bool isRunnable(); // Returns true if run() would currently do anything, or false if the queue is empty. + const Executor& getExecutor(); + // Returns an Executor that can be used to schedule events on this EventLoop from another thread. + // + // Use the global function kj::getCurrentThreadExecutor() to get the current thread's EventLoop's + // Executor. + // + // Note that this is only needed for cross-thread scheduling. To schedule code to run later in + // the current thread, use `kj::evalLater()`, which will be more efficient. + private: - EventPort& port; + kj::Maybe port; + // If null, this thread doesn't receive I/O events from the OS. It can potentially receive + // events from other threads via the Executor. bool running = false; // True while looping -- wait() is then not allowed. @@ -639,19 +1263,35 @@ class EventLoop { _::Event* head = nullptr; _::Event** tail = &head; _::Event** depthFirstInsertPoint = &head; + _::Event** breadthFirstInsertPoint = &head; + + kj::Maybe> executor; + // Allocated the first time getExecutor() is requested, making cross-thread request possible. + + Own daemons; - Own<_::TaskSetImpl> daemons; + _::Event* currentlyFiring = nullptr; bool turn(); void setRunnable(bool runnable); void enterScope(); void leaveScope(); + void wait(); + void poll(); + friend void _::detach(kj::Promise&& promise); - friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, - WaitScope& waitScope); + friend void _::waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, + WaitScope& waitScope, SourceLocation location); + friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); friend class _::Event; friend class WaitScope; + friend class Executor; + friend class _::XThreadEvent; + friend class _::XThreadPaf; + friend class _::FiberBase; + friend class _::FiberStack; + friend ArrayPtr getAsyncTrace(ArrayPtr space); }; class WaitScope { @@ -665,18 +1305,80 @@ class WaitScope { public: inline explicit WaitScope(EventLoop& loop): loop(loop) { loop.enterScope(); } - inline ~WaitScope() { loop.leaveScope(); } - KJ_DISALLOW_COPY(WaitScope); + inline ~WaitScope() { if (fiber == nullptr) loop.leaveScope(); } + KJ_DISALLOW_COPY_AND_MOVE(WaitScope); + + uint poll(uint maxTurnCount = maxValue); + // Pumps the event queue and polls for I/O until there's nothing left to do (without blocking) or + // the maximum turn count has been reached. Returns the number of events popped off the event + // queue. + // + // Not supported in fibers. + + void setBusyPollInterval(uint count) { busyPollInterval = count; } + // Set the maximum number of events to run in a row before calling poll() on the EventPort to + // check for new I/O. + // + // This has no effect when used in a fiber. + + void runEventCallbacksOnStackPool(kj::Maybe pool) { runningStacksPool = pool; } + // Arranges to switch stacks while event callbacks are executing. This is an optimization that + // is useful for programs that use extremely high thread counts, where each thread has its own + // event loop, but each thread has relatively low event throughput, i.e. each thread spends + // most of its time waiting for I/O. Normally, the biggest problem with having lots of threads + // is that each thread must allocate a stack, and stacks can take a lot of memory if the + // application commonly makes deep calls. But, most of that stack space is only needed while + // the thread is executing, not while it's sleeping. So, if threads only switch to a big stack + // during execution, switching back when it's time to sleep, and if those stacks are freelisted + // so that they can be shared among threads, then a lot of memory is saved. + // + // We use the `FiberPool` type here because it implements a freelist of stacks, which is exactly + // what we happen to want! In our case, though, we don't use those stacks to implement fibers; + // we use them as the main thread stack. + // + // This has no effect if this WaitScope itself is for a fiber. + // + // Pass `nullptr` as the parameter to go back to running events on the main stack. + + void cancelAllDetached(); + // HACK: Immediately cancel all detached promises. + // + // New code should not use detached promises, and therefore should not need this. + // + // This method exists to help existing code deal with the problems of detached promises, + // especially at teardown time. + // + // This method may be removed in the future. private: EventLoop& loop; + uint busyPollInterval = kj::maxValue; + + kj::Maybe<_::FiberBase&> fiber; + kj::Maybe runningStacksPool; + + explicit WaitScope(EventLoop& loop, _::FiberBase& fiber) + : loop(loop), fiber(fiber) {} + + template + inline void runOnStackPool(Func&& func) { + KJ_IF_MAYBE(pool, runningStacksPool) { + pool->runSynchronously(kj::fwd(func)); + } else { + func(); + } + } + friend class EventLoop; - friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, - WaitScope& waitScope); + friend class _::FiberBase; + friend void _::waitImpl(_::OwnPromiseNode&& node, _::ExceptionOrValue& result, + WaitScope& waitScope, SourceLocation location); + friend bool _::pollImpl(_::PromiseNode& node, WaitScope& waitScope, SourceLocation location); }; } // namespace kj +#define KJ_ASYNC_H_INCLUDED #include "async-inl.h" -#endif // KJ_ASYNC_H_ +KJ_END_HEADER diff --git a/c++/src/kj/cidr.c++ b/c++/src/kj/cidr.c++ new file mode 100644 index 0000000000..6a3f767ba9 --- /dev/null +++ b/c++/src/kj/cidr.c++ @@ -0,0 +1,175 @@ +// Copyright (c) 2013-2017 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if _WIN32 +// Request Vista-level APIs. +#include +#endif + +#include "debug.h" +#include "cidr.h" + +#if _WIN32 +#include +#include +#include +#include +#define inet_pton InetPtonA +#define inet_ntop InetNtopA +#include +#define dup _dup +#else +#include +#include +#endif + +namespace kj { + +CidrRange::CidrRange(StringPtr pattern) { + size_t slashPos = KJ_REQUIRE_NONNULL(pattern.findFirst('/'), "invalid CIDR", pattern); + + bitCount = pattern.slice(slashPos + 1).parseAs(); + + KJ_STACK_ARRAY(char, addr, slashPos + 1, 128, 128); + memcpy(addr.begin(), pattern.begin(), slashPos); + addr[slashPos] = '\0'; + + if (pattern.findFirst(':') == nullptr) { + family = AF_INET; + KJ_REQUIRE(bitCount <= 32, "invalid CIDR", pattern); + } else { + family = AF_INET6; + KJ_REQUIRE(bitCount <= 128, "invalid CIDR", pattern); + } + + KJ_ASSERT(inet_pton(family, addr.begin(), bits) > 0, "invalid CIDR", pattern); + zeroIrrelevantBits(); +} + +CidrRange::CidrRange(int family, ArrayPtr bits, uint bitCount) + : family(family), bitCount(bitCount) { + if (family == AF_INET) { + KJ_REQUIRE(bitCount <= 32); + } else { + KJ_REQUIRE(bitCount <= 128); + } + KJ_REQUIRE(bits.size() * 8 >= bitCount); + size_t byteCount = (bitCount + 7) / 8; + memcpy(this->bits, bits.begin(), byteCount); + memset(this->bits + byteCount, 0, sizeof(this->bits) - byteCount); + + zeroIrrelevantBits(); +} + +CidrRange CidrRange::inet4(ArrayPtr bits, uint bitCount) { + return CidrRange(AF_INET, bits, bitCount); +} +CidrRange CidrRange::inet6( + ArrayPtr prefix, ArrayPtr suffix, + uint bitCount) { + KJ_REQUIRE(prefix.size() + suffix.size() <= 8); + + byte bits[16] = { 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, }; + + for (size_t i: kj::indices(prefix)) { + bits[i * 2] = prefix[i] >> 8; + bits[i * 2 + 1] = prefix[i] & 0xff; + } + + byte* suffixBits = bits + (16 - suffix.size() * 2); + for (size_t i: kj::indices(suffix)) { + suffixBits[i * 2] = suffix[i] >> 8; + suffixBits[i * 2 + 1] = suffix[i] & 0xff; + } + + return CidrRange(AF_INET6, bits, bitCount); +} + +bool CidrRange::matches(const struct sockaddr* addr) const { + const byte* otherBits; + + switch (family) { + case AF_INET: + if (addr->sa_family == AF_INET6) { + otherBits = reinterpret_cast(addr)->sin6_addr.s6_addr; + static constexpr byte V6MAPPED[12] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff }; + if (memcmp(otherBits, V6MAPPED, sizeof(V6MAPPED)) == 0) { + // We're an ipv4 range and the address is ipv6, but it's a "v6 mapped" address, meaning + // it's equivalent to an ipv4 address. Try to match against the ipv4 part. + otherBits = otherBits + sizeof(V6MAPPED); + } else { + return false; + } + } else if (addr->sa_family == AF_INET) { + otherBits = reinterpret_cast( + &reinterpret_cast(addr)->sin_addr.s_addr); + } else { + return false; + } + + break; + + case AF_INET6: + if (addr->sa_family != AF_INET6) return false; + + otherBits = reinterpret_cast(addr)->sin6_addr.s6_addr; + break; + + default: + KJ_UNREACHABLE; + } + + if (memcmp(bits, otherBits, bitCount / 8) != 0) return false; + + return bitCount == 128 || + bits[bitCount / 8] == (otherBits[bitCount / 8] & (0xff00 >> (bitCount % 8))); +} + +bool CidrRange::matchesFamily(int family) const { + switch (family) { + case AF_INET: + return this->family == AF_INET; + case AF_INET6: + // Even if we're a v4 CIDR, we can match v6 addresses in the v4-mapped range. + return true; + default: + return false; + } +} + +String CidrRange::toString() const { + char result[128]; + KJ_ASSERT(inet_ntop(family, (void*)bits, result, sizeof(result)) == result); + return kj::str(result, '/', bitCount); +} + +void CidrRange::zeroIrrelevantBits() { + // Mask out insignificant bits of partial byte. + if (bitCount < 128) { + bits[bitCount / 8] &= 0xff00 >> (bitCount % 8); + + // Zero the remaining bytes. + size_t n = bitCount / 8 + 1; + memset(bits + n, 0, sizeof(bits) - n); + } +} + +} // namespace kj diff --git a/c++/src/kj/cidr.h b/c++/src/kj/cidr.h new file mode 100644 index 0000000000..b334ecc7d4 --- /dev/null +++ b/c++/src/kj/cidr.h @@ -0,0 +1,62 @@ + +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "common.h" +#include + +KJ_BEGIN_HEADER + +struct sockaddr; + +namespace kj { + +class CidrRange { +public: + CidrRange(StringPtr pattern); + + static CidrRange inet4(ArrayPtr bits, uint bitCount); + static CidrRange inet6(ArrayPtr prefix, ArrayPtr suffix, + uint bitCount); + // Zeros are inserted between `prefix` and `suffix` to extend the address to 128 bits. + + uint getSpecificity() const { return bitCount; } + + bool matches(const struct sockaddr* addr) const; + bool matchesFamily(int family) const; + + String toString() const; + +private: + int family; + byte bits[16]; + uint bitCount; // how many bits in `bits` need to match + + CidrRange(int family, ArrayPtr bits, uint bitCount); + + void zeroIrrelevantBits(); +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/common-test.c++ b/c++/src/kj/common-test.c++ index d1a50c250c..9785612562 100644 --- a/c++/src/kj/common-test.c++ +++ b/c++/src/kj/common-test.c++ @@ -45,6 +45,20 @@ struct ImplicitToInt { } }; +struct Immovable { + Immovable() = default; + KJ_DISALLOW_COPY_AND_MOVE(Immovable); +}; + +struct CopyOrMove { + // Type that detects the difference between copy and move. + CopyOrMove(int i): i(i) {} + CopyOrMove(CopyOrMove&& other): i(other.i) { other.i = -1; } + CopyOrMove(const CopyOrMove&) = default; + + int i; +}; + TEST(Common, Maybe) { { Maybe m = 123; @@ -61,6 +75,98 @@ TEST(Common, Maybe) { ADD_FAILURE(); } EXPECT_EQ(123, m.orDefault(456)); + bool ranLazy = false; + EXPECT_EQ(123, m.orDefault([&] { + ranLazy = true; + return 456; + })); + EXPECT_FALSE(ranLazy); + + KJ_IF_MAYBE(v, m) { + int notUsedForRef = 5; + const int& ref = m.orDefault([&]() -> int& { return notUsedForRef; }); + + EXPECT_EQ(ref, *v); + EXPECT_EQ(&ref, v); + + const int& ref2 = m.orDefault([notUsed = 5]() -> int { return notUsed; }); + EXPECT_NE(&ref, &ref2); + EXPECT_EQ(ref2, 123); + } else { + ADD_FAILURE(); + } + } + + { + Maybe> m = kj::heap(123); + EXPECT_FALSE(m == nullptr); + EXPECT_TRUE(m != nullptr); + KJ_IF_MAYBE(v, m) { + EXPECT_EQ(123, (*v)->i); + } else { + ADD_FAILURE(); + } + KJ_IF_MAYBE(v, mv(m)) { + EXPECT_EQ(123, (*v)->i); + } else { + ADD_FAILURE(); + } + // We have moved the kj::Own away, so this should give us the default and leave the Maybe empty. + EXPECT_EQ(456, m.orDefault(heap(456))->i); + EXPECT_TRUE(m == nullptr); + + bool ranLazy = false; + EXPECT_EQ(123, mv(m).orDefault([&] { + ranLazy = true; + return heap(123); + })->i); + EXPECT_TRUE(ranLazy); + EXPECT_TRUE(m == nullptr); + + m = heap(123); + EXPECT_TRUE(m != nullptr); + ranLazy = false; + EXPECT_EQ(123, mv(m).orDefault([&] { + ranLazy = true; + return heap(456); + })->i); + EXPECT_FALSE(ranLazy); + EXPECT_TRUE(m == nullptr); + } + + { + Maybe empty; + int defaultValue = 5; + auto& ref1 = empty.orDefault([&defaultValue]() -> int& { + return defaultValue; + }); + EXPECT_EQ(&ref1, &defaultValue); + + auto ref2 = empty.orDefault([&]() -> int { return defaultValue; }); + EXPECT_NE(&ref2, &defaultValue); + } + + { + Maybe m = 0; + EXPECT_FALSE(m == nullptr); + EXPECT_TRUE(m != nullptr); + KJ_IF_MAYBE(v, m) { + EXPECT_EQ(0, *v); + } else { + ADD_FAILURE(); + } + KJ_IF_MAYBE(v, mv(m)) { + EXPECT_EQ(0, *v); + } else { + ADD_FAILURE(); + } + EXPECT_EQ(0, m.orDefault(456)); + bool ranLazy = false; + EXPECT_EQ(0, m.orDefault([&] { + ranLazy = true; + return 456; + })); + EXPECT_FALSE(ranLazy); } { @@ -76,6 +182,12 @@ TEST(Common, Maybe) { EXPECT_EQ(0, *v); // avoid unused warning } EXPECT_EQ(456, m.orDefault(456)); + bool ranLazy = false; + EXPECT_EQ(456, m.orDefault([&] { + ranLazy = true; + return 456; + })); + EXPECT_TRUE(ranLazy); } int i = 234; @@ -128,6 +240,24 @@ TEST(Common, Maybe) { EXPECT_EQ(234, m.orDefault(456)); } + { + const Maybe m2 = &i; + Maybe m = m2; + EXPECT_FALSE(m == nullptr); + EXPECT_TRUE(m != nullptr); + KJ_IF_MAYBE(v, m) { + EXPECT_EQ(&i, v); + } else { + ADD_FAILURE(); + } + KJ_IF_MAYBE(v, mv(m)) { + EXPECT_EQ(&i, v); + } else { + ADD_FAILURE(); + } + EXPECT_EQ(234, m.orDefault(456)); + } + { Maybe m = implicitCast(nullptr); EXPECT_TRUE(m == nullptr); @@ -144,35 +274,67 @@ TEST(Common, Maybe) { } { - Maybe m = &i; + Maybe mi = i; + Maybe m = mi; EXPECT_FALSE(m == nullptr); EXPECT_TRUE(m != nullptr); KJ_IF_MAYBE(v, m) { - EXPECT_NE(v, &i); - EXPECT_EQ(234, *v); + EXPECT_EQ(&KJ_ASSERT_NONNULL(mi), v); } else { ADD_FAILURE(); } KJ_IF_MAYBE(v, mv(m)) { - EXPECT_NE(v, &i); - EXPECT_EQ(234, *v); + EXPECT_EQ(&KJ_ASSERT_NONNULL(mi), v); } else { ADD_FAILURE(); } + EXPECT_EQ(234, m.orDefault(456)); } { - Maybe m = implicitCast(nullptr); + Maybe mi = nullptr; + Maybe m = mi; EXPECT_TRUE(m == nullptr); - EXPECT_FALSE(m != nullptr); KJ_IF_MAYBE(v, m) { + KJ_FAIL_EXPECT(*v); + } + } + + { + const Maybe mi = i; + Maybe m = mi; + EXPECT_FALSE(m == nullptr); + EXPECT_TRUE(m != nullptr); + KJ_IF_MAYBE(v, m) { + EXPECT_EQ(&KJ_ASSERT_NONNULL(mi), v); + } else { ADD_FAILURE(); - EXPECT_EQ(0, *v); // avoid unused warning } KJ_IF_MAYBE(v, mv(m)) { + EXPECT_EQ(&KJ_ASSERT_NONNULL(mi), v); + } else { ADD_FAILURE(); - EXPECT_EQ(0, *v); // avoid unused warning } + EXPECT_EQ(234, m.orDefault(456)); + } + + { + const Maybe mi = nullptr; + Maybe m = mi; + EXPECT_TRUE(m == nullptr); + KJ_IF_MAYBE(v, m) { + KJ_FAIL_EXPECT(*v); + } + } + + { + // Verify orDefault() works with move-only types. + Maybe m = nullptr; + kj::String s = kj::mv(m).orDefault(kj::str("foo")); + EXPECT_EQ("foo", s); + EXPECT_EQ("foo", kj::mv(m).orDefault([] { + return kj::str("foo"); + })); } { @@ -191,6 +353,82 @@ TEST(Common, Maybe) { ADD_FAILURE(); } } + + { + // Test usage of immovable types. + Maybe m; + KJ_EXPECT(m == nullptr); + m.emplace(); + KJ_EXPECT(m != nullptr); + m = nullptr; + KJ_EXPECT(m == nullptr); + } + + { + // Test that initializing Maybe from Maybe&& does a copy, not a move. + CopyOrMove x(123); + Maybe m(x); + Maybe m2 = kj::mv(m); + KJ_EXPECT(m == nullptr); // m is moved out of and cleared + KJ_EXPECT(x.i == 123); // but what m *referenced* was not moved out of + KJ_EXPECT(KJ_ASSERT_NONNULL(m2).i == 123); // m2 is a copy of what m referenced + } + + { + // Test that a moved-out-of Maybe is left empty after move constructor. + Maybe m = 123; + KJ_EXPECT(m != nullptr); + + Maybe n(kj::mv(m)); + KJ_EXPECT(m == nullptr); + KJ_EXPECT(n != nullptr); + } + + { + // Test that a moved-out-of Maybe is left empty after move constructor. + Maybe m = 123; + KJ_EXPECT(m != nullptr); + + Maybe n = kj::mv(m); + KJ_EXPECT(m == nullptr); + KJ_EXPECT(n != nullptr); + } + + { + // Test that a moved-out-of Maybe is left empty when moved to a Maybe. + int x = 123; + Maybe m = x; + KJ_EXPECT(m != nullptr); + + Maybe n(kj::mv(m)); + KJ_EXPECT(m == nullptr); + KJ_EXPECT(n != nullptr); + } + + { + // Test that a moved-out-of Maybe is left empty when moved to another Maybe. + int x = 123; + Maybe m = x; + KJ_EXPECT(m != nullptr); + + Maybe n(kj::mv(m)); + KJ_EXPECT(m == nullptr); + KJ_EXPECT(n != nullptr); + } + + { + Maybe m1 = 123; + Maybe m2 = 123; + Maybe m3 = 456; + Maybe m4 = nullptr; + Maybe m5 = nullptr; + + KJ_EXPECT(m1 == m2); + KJ_EXPECT(m1 != m3); + KJ_EXPECT(m1 != m4); + KJ_EXPECT(m4 == m5); + KJ_EXPECT(m4 != m1); + } } TEST(Common, MaybeConstness) { @@ -217,9 +455,99 @@ TEST(Common, MaybeConstness) { } } +#if __GNUC__ +TEST(Common, MaybeUnwrapOrReturn) { + { + auto func = [](Maybe i) -> int { + int& j = KJ_UNWRAP_OR_RETURN(i, -1); + KJ_EXPECT(&j == &KJ_ASSERT_NONNULL(i)); + return j + 2; + }; + + KJ_EXPECT(func(123) == 125); + KJ_EXPECT(func(nullptr) == -1); + } + + { + auto func = [&](Maybe maybe) -> int { + String str = KJ_UNWRAP_OR_RETURN(kj::mv(maybe), -1); + return str.parseAs(); + }; + + KJ_EXPECT(func(kj::str("123")) == 123); + KJ_EXPECT(func(nullptr) == -1); + } + + // Test void return. + { + int val = 0; + auto func = [&](Maybe i) { + val = KJ_UNWRAP_OR_RETURN(i); + }; + + func(123); + KJ_EXPECT(val == 123); + val = 321; + func(nullptr); + KJ_EXPECT(val == 321); + } + + // Test KJ_UNWRAP_OR + { + bool wasNull = false; + auto func = [&](Maybe i) -> int { + int& j = KJ_UNWRAP_OR(i, { + wasNull = true; + return -1; + }); + KJ_EXPECT(&j == &KJ_ASSERT_NONNULL(i)); + return j + 2; + }; + + KJ_EXPECT(func(123) == 125); + KJ_EXPECT(!wasNull); + KJ_EXPECT(func(nullptr) == -1); + KJ_EXPECT(wasNull); + } + + { + bool wasNull = false; + auto func = [&](Maybe maybe) -> int { + String str = KJ_UNWRAP_OR(kj::mv(maybe), { + wasNull = true; + return -1; + }); + return str.parseAs(); + }; + + KJ_EXPECT(func(kj::str("123")) == 123); + KJ_EXPECT(!wasNull); + KJ_EXPECT(func(nullptr) == -1); + KJ_EXPECT(wasNull); + } + + // Test void return. + { + int val = 0; + auto func = [&](Maybe i) { + val = KJ_UNWRAP_OR(i, { + return; + }); + }; + + func(123); + KJ_EXPECT(val == 123); + val = 321; + func(nullptr); + KJ_EXPECT(val == 321); + } + +} +#endif + class Foo { public: - KJ_DISALLOW_COPY(Foo); + KJ_DISALLOW_COPY_AND_MOVE(Foo); virtual ~Foo() {} protected: Foo() = default; @@ -228,14 +556,14 @@ protected: class Bar: public Foo { public: Bar() = default; - KJ_DISALLOW_COPY(Bar); + KJ_DISALLOW_COPY_AND_MOVE(Bar); virtual ~Bar() {} }; class Baz: public Foo { public: Baz() = delete; - KJ_DISALLOW_COPY(Baz); + KJ_DISALLOW_COPY_AND_MOVE(Baz); virtual ~Baz() {} }; @@ -464,6 +792,31 @@ TEST(Common, ArrayAsBytes) { } } +KJ_TEST("ArrayPtr operator ==") { + KJ_EXPECT(ArrayPtr({123, 456}) == ArrayPtr({123, 456})); + KJ_EXPECT(!(ArrayPtr({123, 456}) != ArrayPtr({123, 456}))); + KJ_EXPECT(ArrayPtr({123, 456}) != ArrayPtr({123, 321})); + KJ_EXPECT(ArrayPtr({123, 456}) != ArrayPtr({123})); + + KJ_EXPECT(ArrayPtr({123, 456}) == ArrayPtr({123, 456})); + KJ_EXPECT(!(ArrayPtr({123, 456}) != ArrayPtr({123, 456}))); + KJ_EXPECT(ArrayPtr({123, 456}) != ArrayPtr({123, 321})); + KJ_EXPECT(ArrayPtr({123, 456}) != ArrayPtr({123})); + + KJ_EXPECT((ArrayPtr({"foo", "bar"}) == + ArrayPtr({"foo", "bar"}))); + KJ_EXPECT(!(ArrayPtr({"foo", "bar"}) != + ArrayPtr({"foo", "bar"}))); + KJ_EXPECT((ArrayPtr({"foo", "bar"}) != + ArrayPtr({"foo", "baz"}))); + KJ_EXPECT((ArrayPtr({"foo", "bar"}) != + ArrayPtr({"foo"}))); + + // operator== should not use memcmp for double elements. + double d[1] = { nan() }; + KJ_EXPECT(ArrayPtr(d, 1) != ArrayPtr(d, 1)); +} + KJ_TEST("kj::range()") { uint expected = 5; for (uint i: range(5, 10)) { @@ -478,5 +831,99 @@ KJ_TEST("kj::range()") { KJ_EXPECT(expected == 8); } +KJ_TEST("kj::defer()") { + { + // rvalue reference + bool executed = false; + { + auto deferred = kj::defer([&executed]() { + executed = true; + }); + KJ_EXPECT(!executed); + } + + KJ_EXPECT(executed); + } + + { + // lvalue reference + bool executed = false; + auto executor = [&executed]() { + executed = true; + }; + + { + auto deferred = kj::defer(executor); + KJ_EXPECT(!executed); + } + + KJ_EXPECT(executed); + } + + { + // Cancellation via `cancel()`. + bool executed = false; + { + auto deferred = kj::defer([&executed]() { + executed = true; + }); + KJ_EXPECT(!executed); + + // Cancel and release the functor. + deferred.cancel(); + KJ_EXPECT(!executed); + } + + KJ_EXPECT(!executed); + } + + { + // Execution via `run()`. + size_t runCount = 0; + { + auto deferred = kj::defer([&runCount](){ + ++runCount; + }); + + // Run and release the functor. + deferred.run(); + KJ_EXPECT(runCount == 1); + } + + // `deferred` is already been run, so nothing is run when we destruct it. + KJ_EXPECT(runCount == 1); + } + +} + +KJ_TEST("kj::ArrayPtr startsWith / endsWith / findFirst / findLast") { + // Note: char-/byte- optimized versions are covered by string-test.c++. + + int rawArray[] = {12, 34, 56, 34, 12}; + ArrayPtr arr(rawArray); + + KJ_EXPECT(arr.startsWith({12, 34})); + KJ_EXPECT(arr.startsWith({12, 34, 56})); + KJ_EXPECT(!arr.startsWith({12, 34, 56, 78})); + KJ_EXPECT(arr.startsWith({12, 34, 56, 34, 12})); + KJ_EXPECT(!arr.startsWith({12, 34, 56, 34, 12, 12})); + + KJ_EXPECT(arr.endsWith({34, 12})); + KJ_EXPECT(arr.endsWith({56, 34, 12})); + KJ_EXPECT(!arr.endsWith({78, 56, 34, 12})); + KJ_EXPECT(arr.endsWith({12, 34, 56, 34, 12})); + KJ_EXPECT(!arr.endsWith({12, 12, 34, 56, 34, 12})); + + KJ_EXPECT(arr.findFirst(12).orDefault(100) == 0); + KJ_EXPECT(arr.findFirst(34).orDefault(100) == 1); + KJ_EXPECT(arr.findFirst(56).orDefault(100) == 2); + KJ_EXPECT(arr.findFirst(78).orDefault(100) == 100); + + KJ_EXPECT(arr.findLast(12).orDefault(100) == 4); + KJ_EXPECT(arr.findLast(34).orDefault(100) == 3); + KJ_EXPECT(arr.findLast(56).orDefault(100) == 2); + KJ_EXPECT(arr.findLast(78).orDefault(100) == 100); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/common.c++ b/c++/src/kj/common.c++ index 7d87587112..8960b2316a 100644 --- a/c++/src/kj/common.c++ +++ b/c++/src/kj/common.c++ @@ -22,9 +22,6 @@ #include "common.h" #include "debug.h" #include -#ifdef _MSC_VER -#include -#endif namespace kj { namespace _ { // private @@ -44,15 +41,9 @@ void unreachable() { KJ_FAIL_ASSERT("Supposedly-unreachable branch executed."); // Really make sure we abort. - abort(); + KJ_KNOWN_UNREACHABLE(abort()); } } // namespace _ (private) -#if _MSC_VER && !__clang__ - -float nan() { return std::numeric_limits::quiet_NaN(); } - -#endif - } // namespace kj diff --git a/c++/src/kj/common.h b/c++/src/kj/common.h index c2a7644c14..5385ee3197 100644 --- a/c++/src/kj/common.h +++ b/c++/src/kj/common.h @@ -23,56 +23,85 @@ // // This defines very simple utilities that are widely applicable. -#ifndef KJ_COMMON_H_ -#define KJ_COMMON_H_ +#pragma once -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header +#if defined(__GNUC__) || defined(__clang__) +#define KJ_BEGIN_SYSTEM_HEADER _Pragma("GCC system_header") +#elif defined(_MSC_VER) +#define KJ_BEGIN_SYSTEM_HEADER __pragma(warning(push, 0)) +#define KJ_END_SYSTEM_HEADER __pragma(warning(pop)) +#endif + +#ifndef KJ_BEGIN_SYSTEM_HEADER +#define KJ_BEGIN_SYSTEM_HEADER +#endif + +#ifndef KJ_END_SYSTEM_HEADER +#define KJ_END_SYSTEM_HEADER +#endif + +#if !defined(KJ_HEADER_WARNINGS) || !KJ_HEADER_WARNINGS +#define KJ_BEGIN_HEADER KJ_BEGIN_SYSTEM_HEADER +#define KJ_END_HEADER KJ_END_SYSTEM_HEADER +#else +#define KJ_BEGIN_HEADER +#define KJ_END_HEADER #endif +#ifdef __has_cpp_attribute +#define KJ_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +#define KJ_HAS_CPP_ATTRIBUTE(x) 0 +#endif + +#ifdef __has_feature +#define KJ_HAS_COMPILER_FEATURE(x) __has_feature(x) +#else +#define KJ_HAS_COMPILER_FEATURE(x) 0 +#endif + +KJ_BEGIN_HEADER + #ifndef KJ_NO_COMPILER_CHECK -#if __cplusplus < 201103L && !__CDT_PARSER__ && !_MSC_VER - #error "This code requires C++11. Either your compiler does not support it or it is not enabled." +// Technically, __cplusplus should be 201402L for C++14, but GCC 4.9 -- which is supported -- still +// had it defined to 201300L even with -std=c++14. +#if __cplusplus < 201300L && !__CDT_PARSER__ && !_MSC_VER + #error "This code requires C++14. Either your compiler does not support it or it is not enabled." #ifdef __GNUC__ // Compiler claims compatibility with GCC, so presumably supports -std. - #error "Pass -std=c++11 on the compiler command line to enable C++11." + #error "Pass -std=c++14 on the compiler command line to enable C++14." #endif #endif #ifdef __GNUC__ #if __clang__ - #if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 2) - #warning "This library requires at least Clang 3.2." - #elif defined(__apple_build_version__) && __apple_build_version__ <= 4250028 - #warning "This library requires at least Clang 3.2. XCode 4.6's Clang, which claims to be "\ - "version 4.2 (wat?), is actually built from some random SVN revision between 3.1 "\ - "and 3.2. Unfortunately, it is insufficient for compiling this library. You can "\ - "download the real Clang 3.2 (or newer) from the Clang web site. Step-by-step "\ - "instructions can be found in Cap'n Proto's documentation: "\ - "http://kentonv.github.io/capnproto/install.html#clang_32_on_mac_osx" - #elif __cplusplus >= 201103L && !__has_include() - #warning "Your compiler supports C++11 but your C++ standard library does not. If your "\ + #if __clang_major__ < 5 + #warning "This library requires at least Clang 5.0." + #elif __cplusplus >= 201402L && !__has_include() + #warning "Your compiler supports C++14 but your C++ standard library does not. If your "\ "system has libc++ installed (as should be the case on e.g. Mac OSX), try adding "\ "-stdlib=libc++ to your CXXFLAGS." #endif #else - #if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 7) - #warning "This library requires at least GCC 4.7." + #if __GNUC__ < 5 + #warning "This library requires at least GCC 5.0." #endif #endif #elif defined(_MSC_VER) - #if _MSC_VER < 1900 - #error "You need Visual Studio 2015 or better to compile this code." + #if _MSC_VER < 1910 && !defined(__clang__) + #error "You need Visual Studio 2017 or better to compile this code." #endif #else - #warning "I don't recognize your compiler. As of this writing, Clang and GCC are the only "\ - "known compilers with enough C++11 support for this library. "\ + #warning "I don't recognize your compiler. As of this writing, Clang, GCC, and Visual Studio "\ + "are the only known compilers with enough C++14 support for this library. "\ "#define KJ_NO_COMPILER_CHECK to make this warning go away." #endif #endif #include +#include #include +#include #if __linux__ && __cplusplus > 201200L // Hack around stdlib bug with C++14 that exists on some Linux systems. @@ -82,10 +111,19 @@ #undef _GLIBCXX_HAVE_GETS #endif -#if defined(_MSC_VER) +#if _WIN32 +// Windows likes to define macros for min() and max(). We just can't deal with this. +// If windows.h was included already, undef these. +#undef min +#undef max +// If windows.h was not included yet, define the macro that prevents min() and max() from being +// defined. #ifndef NOMINMAX #define NOMINMAX 1 #endif +#endif + +#if defined(_MSC_VER) #include // __popcnt #endif @@ -102,7 +140,19 @@ typedef unsigned char byte; // Detect whether RTTI and exceptions are enabled, assuming they are unless we have specific // evidence to the contrary. Clients can always define KJ_NO_RTTI or KJ_NO_EXCEPTIONS explicitly // to override these checks. -#ifdef __GNUC__ + +// TODO: Ideally we'd use __cpp_exceptions/__cpp_rtti not being defined as the first pass since +// that is the standard compliant way. However, it's unclear how to use those macros (or any +// others) to distinguish between the compiler supporting feature detection and the feature being +// disabled vs the compiler not supporting feature detection at all. +#if defined(__has_feature) + #if !defined(KJ_NO_RTTI) && !__has_feature(cxx_rtti) + #define KJ_NO_RTTI 1 + #endif + #if !defined(KJ_NO_EXCEPTIONS) && !__has_feature(cxx_exceptions) + #define KJ_NO_EXCEPTIONS 1 + #endif +#elif defined(__GNUC__) #if !defined(KJ_NO_RTTI) && !__GXX_RTTI #define KJ_NO_RTTI 1 #endif @@ -135,7 +185,22 @@ typedef unsigned char byte; #define KJ_DISALLOW_COPY(classname) \ classname(const classname&) = delete; \ classname& operator=(const classname&) = delete -// Deletes the implicit copy constructor and assignment operator. +// Deletes the implicit copy constructor and assignment operator. This inhibits the compiler from +// generating the implicit move constructor and assignment operator for this class, but allows the +// code author to supply them, if they make sense to implement. +// +// This macro should not be your first choice. Instead, prefer using KJ_DISALLOW_COPY_AND_MOVE, and only use +// this macro when you have determined that you must implement move semantics for your type. + +#define KJ_DISALLOW_COPY_AND_MOVE(classname) \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname(classname&&) = delete; \ + classname& operator=(classname&&) = delete +// Deletes the implicit copy and move constructors and assignment operators. This is useful in cases +// where the code author wants to provide an additional compile-time guard against subsequent +// maintainers casually adding move operations. This is particularly useful when implementing RAII +// classes that are intended to be completely immobile. #ifdef __GNUC__ #define KJ_LIKELY(condition) __builtin_expect(condition, true) @@ -152,7 +217,7 @@ typedef unsigned char byte; #define KJ_ALWAYS_INLINE(...) inline __VA_ARGS__ // Don't force inline in debug mode. #else -#if defined(_MSC_VER) +#if defined(_MSC_VER) && !defined(__clang__) #define KJ_ALWAYS_INLINE(...) __forceinline __VA_ARGS__ #else #define KJ_ALWAYS_INLINE(...) inline __VA_ARGS__ __attribute__((always_inline)) @@ -160,7 +225,7 @@ typedef unsigned char byte; // Force a function to always be inlined. Apply only to the prototype, not to the definition. #endif -#if defined(_MSC_VER) +#if defined(_MSC_VER) && !defined(__clang__) #define KJ_NOINLINE __declspec(noinline) #else #define KJ_NOINLINE __attribute__((noinline)) @@ -179,6 +244,29 @@ typedef unsigned char byte; #define KJ_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) #endif +#if KJ_HAS_CPP_ATTRIBUTE(clang::lifetimebound) +// If this is generating too many false-positives, the user is responsible for disabling the +// problematic warning at the compiler switch level or by suppressing the place where the +// false-positive is reported through compiler-specific pragmas if available. +#define KJ_LIFETIMEBOUND [[clang::lifetimebound]] +#else +#define KJ_LIFETIMEBOUND +#endif +// Annotation that indicates the returned value is referencing a resource owned by this type (e.g. +// cStr() on a std::string). Unfortunately this lifetime can only be superficial currently & cannot +// track further. For example, there's no way to get `array.asPtr().slice(5, 6))` to warn if the +// last slice exceeds the lifetime of `array`. That's because in the general case `ArrayPtr::slice` +// can't have the lifetime bound annotation since it's not wrong to do something like: +// ArrayPtr doSomething(ArrayPtr foo) { +// ... +// return foo.slice(5, 6); +// } +// If `ArrayPtr::slice` had a lifetime bound then the compiler would warn about this perfectly +// legitimate method. Really there needs to be 2 more annotations. One to inherit the lifetime bound +// and another to inherit the lifetime bound from a parameter (which really could be the same thing +// by allowing a syntax like `[[clang::lifetimebound(*this)]]`. +// https://clang.llvm.org/docs/AttributeReference.html#lifetimebound + #if __clang__ #define KJ_UNUSED_MEMBER __attribute__((unused)) // Inhibits "unused" warning for member variables. Only Clang produces such a warning, while GCC @@ -187,6 +275,20 @@ typedef unsigned char byte; #define KJ_UNUSED_MEMBER #endif +#if __cplusplus > 201703L || (__clang__ && __clang_major__ >= 9 && __cplusplus >= 201103L) +// Technically this was only added to C++20 but Clang allows it for >= C++11 and spelunking the +// attributes manual indicates it first came in with Clang 9. +#define KJ_NO_UNIQUE_ADDRESS [[no_unique_address]] +#else +#define KJ_NO_UNIQUE_ADDRESS +#endif + +#if KJ_HAS_COMPILER_FEATURE(thread_sanitizer) || defined(__SANITIZE_THREAD__) +#define KJ_DISABLE_TSAN __attribute__((no_sanitize("thread"), noinline)) +#else +#define KJ_DISABLE_TSAN +#endif + #if __clang__ #define KJ_DEPRECATED(reason) \ __attribute__((deprecated(reason))) @@ -195,13 +297,22 @@ typedef unsigned char byte; #elif __GNUC__ #define KJ_DEPRECATED(reason) \ __attribute__((deprecated)) -#define KJ_UNAVAILABLE(reason) +#define KJ_UNAVAILABLE(reason) = delete +// If the `unavailable` attribute is not supproted, just mark the method deleted, which at least +// makes it a compile-time error to try to call it. Note that on Clang, marking a method deleted +// *and* unavailable unfortunately defeats the purpose of the unavailable annotation, as the +// generic "deleted" error is reported instead. #else #define KJ_DEPRECATED(reason) -#define KJ_UNAVAILABLE(reason) +#define KJ_UNAVAILABLE(reason) = delete // TODO(msvc): Again, here, MSVC prefers a prefix, __declspec(deprecated). #endif +#if KJ_TESTING_KJ // defined in KJ's own unit tests; others should not define this +#undef KJ_DEPRECATED +#define KJ_DEPRECATED(reason) +#endif + namespace _ { // private KJ_NORETURN(void inlineRequireFailure( @@ -212,8 +323,12 @@ KJ_NORETURN(void unreachable()); } // namespace _ (private) +#if _MSC_VER && !defined(__clang__) && (!defined(_MSVC_TRADITIONAL) || _MSVC_TRADITIONAL) +#define KJ_MSVC_TRADITIONAL_CPP 1 +#endif + #ifdef KJ_DEBUG -#if _MSC_VER +#if KJ_MSVC_TRADITIONAL_CPP #define KJ_IREQUIRE(condition, ...) \ if (KJ_LIKELY(condition)); else ::kj::_::inlineRequireFailure( \ __FILE__, __LINE__, #condition, "" #__VA_ARGS__, __VA_ARGS__) @@ -246,6 +361,26 @@ KJ_NORETURN(void unreachable()); #define KJ_CLANG_KNOWS_THIS_IS_UNREACHABLE_BUT_GCC_DOESNT KJ_UNREACHABLE #endif +#if __clang__ +#define KJ_KNOWN_UNREACHABLE(code) \ + do { \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wunreachable-code\"") \ + code; \ + _Pragma("clang diagnostic pop") \ + } while (false) +// Suppress "unreachable code" warnings on intentionally unreachable code. +#else +// TODO(someday): Add support for non-clang compilers. +#define KJ_KNOWN_UNREACHABLE(code) do {code;} while(false) +#endif + +#if KJ_HAS_CPP_ATTRIBUTE(fallthrough) +#define KJ_FALLTHROUGH [[fallthrough]] +#else +#define KJ_FALLTHROUGH +#endif + // #define KJ_STACK_ARRAY(type, name, size, minStack, maxStack) // // Allocate an array, preferably on the stack, unless it is too big. On GCC this will use @@ -256,7 +391,7 @@ KJ_NORETURN(void unreachable()); #define KJ_STACK_ARRAY(type, name, size, minStack, maxStack) \ size_t name##_size = (size); \ bool name##_isOnStack = name##_size <= (maxStack); \ - type name##_stack[name##_isOnStack ? size : 0]; \ + type name##_stack[kj::max(1, name##_isOnStack ? name##_size : 0)]; \ ::kj::Array name##_heap = name##_isOnStack ? \ nullptr : kj::heapArray(name##_size); \ ::kj::ArrayPtr name = name##_isOnStack ? \ @@ -278,7 +413,7 @@ KJ_NORETURN(void unreachable()); // Create a unique identifier name. We use concatenate __LINE__ rather than __COUNTER__ so that // the name can be used multiple times in the same macro. -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #define KJ_CONSTEXPR(...) __VA_ARGS__ // Use in cases where MSVC barfs on constexpr. A replacement keyword (e.g. "const") can be @@ -307,6 +442,15 @@ KJ_NORETURN(void unreachable()); // ======================================================================================= // Template metaprogramming helpers. +#define KJ_HAS_TRIVIAL_CONSTRUCTOR __is_trivially_constructible +#if __GNUC__ && !__clang__ +#define KJ_HAS_NOTHROW_CONSTRUCTOR __has_nothrow_constructor +#define KJ_HAS_TRIVIAL_DESTRUCTOR __has_trivial_destructor +#else +#define KJ_HAS_NOTHROW_CONSTRUCTOR __is_nothrow_constructible +#define KJ_HAS_TRIVIAL_DESTRUCTOR __is_trivially_destructible +#endif + template struct NoInfer_ { typedef T Type; }; template using NoInfer = typename NoInfer_::Type; // Use NoInfer::Type in place of T for a template function parameter to prevent inference of @@ -337,7 +481,7 @@ template <> struct EnableIf_ { typedef void Type; }; template using EnableIf = typename EnableIf_::Type; // Use like: // -// template ()> +// template ()>> // void func(T&& t); template struct VoidSfinae_ { using Type = void; }; @@ -384,7 +528,7 @@ struct DisallowConstCopy { #endif }; -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #define KJ_CPCAP(obj) obj=::kj::cp(obj) // TODO(msvc): MSVC refuses to invoke non-const versions of copy constructors in by-value lambda @@ -455,10 +599,27 @@ T refIfLvalue(T&&); // KJ_DECLTYPE_REF(i) i3(i); // i3 has type int&. // KJ_DECLTYPE_REF(kj::mv(i)) i4(kj::mv(i)); // i4 has type int. +template struct IsSameType_ { static constexpr bool value = false; }; +template struct IsSameType_ { static constexpr bool value = true; }; +template constexpr bool isSameType() { return IsSameType_::value; } + +template constexpr bool isIntegral() { return false; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } +template <> constexpr bool isIntegral() { return true; } + template struct CanConvert_ { static int sfinae(T); - static bool sfinae(...); + static char sfinae(...); }; template @@ -492,6 +653,35 @@ constexpr bool canMemcpy() { static_assert(kj::canMemcpy(), "this code expects this type to be memcpy()-able"); #endif +template +class Badge { + // A pattern for marking individual methods such that they can only be called from a specific + // caller class: Make the method public but give it a parameter of type `Badge`. Only + // `Caller` can construct one, so only `Caller` can call the method. + // + // // We only allow calls from the class `Bar`. + // void foo(Badge) + // + // The call site looks like: + // + // foo({}); + // + // This pattern also works well for declaring private constructors, but still being able to use + // them with `kj::heap()`, etc. + // + // Idea from: https://awesomekling.github.io/Serenity-C++-patterns-The-Badge/ + // + // Note that some forms of this idea make the copy constructor private as well, in order to + // prohibit `Badge(*(Badge*)nullptr)`. However, that would prevent badges from + // being passed through forwarding functions like `kj::heap()`, which would ruin one of the main + // use cases for this pattern in KJ. In any case, dereferencing a null pointer is UB; there are + // plenty of other ways to get access to private members if you're willing to go UB. For one-off + // debugging purposes, you might as well use `#define private public` at the top of the file. +private: + Badge() {} + friend T; +}; + // ======================================================================================= // Equivalents to std::move() and std::forward(), since these are very commonly needed and the // std header pulls in lots of other stuff. @@ -623,29 +813,13 @@ struct ThrowOverflow { // Functor which throws an exception complaining about integer overflow. Usually this is used // with the interfaces in units.h, but is defined here because Cap'n Proto wants to avoid // including units.h when not using CAPNP_DEBUG_TYPES. - void operator()() const; + [[noreturn]] void operator()() const; }; -#if __GNUC__ || __clang__ +#if __GNUC__ || __clang__ || _MSC_VER inline constexpr float inf() { return __builtin_huge_valf(); } inline constexpr float nan() { return __builtin_nanf(""); } -#elif _MSC_VER - -// Do what MSVC math.h does -#pragma warning(push) -#pragma warning(disable: 4756) // "overflow in constant arithmetic" -inline constexpr float inf() { return (float)(1e300 * 1e300); } -#pragma warning(pop) - -float nan(); -// Unfortunatley, inf() * 0.0f produces a NaN with the sign bit set, whereas our preferred -// canonical NaN should not have the sign bit set. std::numeric_limits::quiet_NaN() -// returns the correct NaN, but we don't want to #include that here. So, we give up and make -// this out-of-line on MSVC. -// -// TODO(msvc): Can we do better? - #else #error "Not sure how to support your compiler." #endif @@ -654,7 +828,7 @@ inline constexpr bool isNaN(float f) { return f != f; } inline constexpr bool isNaN(double f) { return f != f; } inline int popCount(unsigned int x) { -#if defined(_MSC_VER) +#if defined(_MSC_VER) && !defined(__clang__) return __popcnt(x); // Note: __popcnt returns unsigned int, but the value is clearly guaranteed to fit into an int #else @@ -792,6 +966,68 @@ inline constexpr Repeat> repeat(T&& value, size_t count) { return Repeat>(value, count); } +template +class MappedIterator: private Mapping { + // An iterator that wraps some other iterator and maps the values through a mapping function. + // The type `Mapping` must define a method `map()` which performs this mapping. + +public: + template + MappedIterator(Inner inner, Params&&... params) + : Mapping(kj::fwd(params)...), inner(inner) {} + + inline auto operator->() const { return &Mapping::map(*inner); } + inline decltype(auto) operator* () const { return Mapping::map(*inner); } + inline decltype(auto) operator[](size_t index) const { return Mapping::map(inner[index]); } + inline MappedIterator& operator++() { ++inner; return *this; } + inline MappedIterator operator++(int) { return MappedIterator(inner++, *this); } + inline MappedIterator& operator--() { --inner; return *this; } + inline MappedIterator operator--(int) { return MappedIterator(inner--, *this); } + inline MappedIterator& operator+=(ptrdiff_t amount) { inner += amount; return *this; } + inline MappedIterator& operator-=(ptrdiff_t amount) { inner -= amount; return *this; } + inline MappedIterator operator+ (ptrdiff_t amount) const { + return MappedIterator(inner + amount, *this); + } + inline MappedIterator operator- (ptrdiff_t amount) const { + return MappedIterator(inner - amount, *this); + } + inline ptrdiff_t operator- (const MappedIterator& other) const { return inner - other.inner; } + + inline bool operator==(const MappedIterator& other) const { return inner == other.inner; } + inline bool operator!=(const MappedIterator& other) const { return inner != other.inner; } + inline bool operator<=(const MappedIterator& other) const { return inner <= other.inner; } + inline bool operator>=(const MappedIterator& other) const { return inner >= other.inner; } + inline bool operator< (const MappedIterator& other) const { return inner < other.inner; } + inline bool operator> (const MappedIterator& other) const { return inner > other.inner; } + +private: + Inner inner; +}; + +template +class MappedIterable: private Mapping { + // An iterable that wraps some other iterable and maps the values through a mapping function. + // The type `Mapping` must define a method `map()` which performs this mapping. + +public: + template + MappedIterable(Inner inner, Params&&... params) + : Mapping(kj::fwd(params)...), inner(inner) {} + + typedef Decay().begin())> InnerIterator; + typedef MappedIterator Iterator; + typedef Decay().begin())> InnerConstIterator; + typedef MappedIterator ConstIterator; + + inline Iterator begin() { return { inner.begin(), (Mapping&)*this }; } + inline Iterator end() { return { inner.end(), (Mapping&)*this }; } + inline ConstIterator begin() const { return { inner.begin(), (const Mapping&)*this }; } + inline ConstIterator end() const { return { inner.end(), (const Mapping&)*this }; } + +private: + Inner inner; +}; + // ======================================================================================= // Manually invoking constructors and destructors // @@ -831,8 +1067,7 @@ inline void dtor(T& location) { // forces the caller to handle the null case in order to satisfy the compiler, thus reliably // preventing null pointer dereferences at runtime. // -// Maybe can be implicitly constructed from T and from nullptr. Additionally, it can be -// implicitly constructed from T*, in which case the pointer is checked for nullness at runtime. +// Maybe can be implicitly constructed from T and from nullptr. // To read the value of a Maybe, do: // // KJ_IF_MAYBE(value, someFuncReturningMaybe()) { @@ -859,7 +1094,7 @@ class NullableValue { // boolean flag indicating nullness. public: - inline NullableValue(NullableValue&& other) noexcept(noexcept(T(instance()))) + inline NullableValue(NullableValue&& other) : isSet(other.isSet) { if (isSet) { ctor(value, kj::mv(other.value)); @@ -878,7 +1113,7 @@ class NullableValue { } } inline ~NullableValue() -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) // TODO(msvc): MSVC has a hard time with noexcept specifier expressions that are more complex // than `true` or `false`. We had a workaround for VS2015, but VS2017 regressed. noexcept(false) @@ -911,9 +1146,8 @@ class NullableValue { return value; } -private: // internal interface used by friends only - inline NullableValue() noexcept: isSet(false) {} - inline NullableValue(T&& t) noexcept(noexcept(T(instance()))) + inline NullableValue(): isSet(false) {} + inline NullableValue(T&& t) : isSet(true) { ctor(value, kj::mv(t)); } @@ -925,12 +1159,8 @@ class NullableValue { : isSet(true) { ctor(value, t); } - inline NullableValue(const T* t) - : isSet(t != nullptr) { - if (isSet) ctor(value, *t); - } template - inline NullableValue(NullableValue&& other) noexcept(noexcept(T(instance()))) + inline NullableValue(NullableValue&& other) : isSet(other.isSet) { if (isSet) { ctor(value, kj::mv(other.value)); @@ -997,13 +1227,56 @@ class NullableValue { return *this; } + inline NullableValue& operator=(T&& other) { emplace(kj::mv(other)); return *this; } + inline NullableValue& operator=(T& other) { emplace(other); return *this; } + inline NullableValue& operator=(const T& other) { emplace(other); return *this; } + template + inline NullableValue& operator=(NullableValue&& other) { + if (other.isSet) { + emplace(kj::mv(other.value)); + } else { + *this = nullptr; + } + return *this; + } + template + inline NullableValue& operator=(const NullableValue& other) { + if (other.isSet) { + emplace(other.value); + } else { + *this = nullptr; + } + return *this; + } + template + inline NullableValue& operator=(const NullableValue& other) { + if (other.isSet) { + emplace(other.value); + } else { + *this = nullptr; + } + return *this; + } + inline NullableValue& operator=(decltype(nullptr)) { + if (isSet) { + isSet = false; + dtor(value); + } + return *this; + } + inline bool operator==(decltype(nullptr)) const { return !isSet; } inline bool operator!=(decltype(nullptr)) const { return isSet; } + NullableValue(const T* t) = delete; + NullableValue& operator=(const T* other) = delete; + // We used to permit assigning a Maybe directly from a T*, and the assignment would check for + // nullness. This turned out never to be useful, and sometimes to be dangerous. + private: bool isSet; -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #pragma warning(push) #pragma warning(disable: 4624) // Warns that the anonymous union has a deleted destructor when T is non-trivial. This warning @@ -1014,7 +1287,7 @@ class NullableValue { T value; }; -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #pragma warning(pop) #endif @@ -1042,6 +1315,63 @@ inline T* readMaybe(T* ptr) { return ptr; } #define KJ_IF_MAYBE(name, exp) if (auto name = ::kj::_::readMaybe(exp)) +#if __GNUC__ || __clang__ +// These two macros provide a friendly syntax to extract the value of a Maybe or return early. +// +// Use KJ_UNWRAP_OR_RETURN if you just want to return a simple value when the Maybe is null: +// +// int foo(Maybe maybe) { +// int value = KJ_UNWRAP_OR_RETURN(maybe, -1); +// // ... use value ... +// } +// +// For functions returning void, omit the second parameter to KJ_UNWRAP_OR_RETURN: +// +// void foo(Maybe maybe) { +// int value = KJ_UNWRAP_OR_RETURN(maybe); +// // ... use value ... +// } +// +// Use KJ_UNWRAP_OR if you want to execute a block with multiple statements. +// +// int foo(Maybe maybe) { +// int value = KJ_UNWRAP_OR(maybe, { +// KJ_LOG(ERROR, "problem!!!"); +// return -1; +// }); +// // ... use value ... +// } +// +// The block MUST return at the end or you will get a compiler error +// +// Unfortunately, these macros seem impossible to express without using GCC's non-standard +// "statement expressions" extension. IIFEs don't do the trick here because a lambda cannot +// return out of the parent scope. These macros should therefore only be used in projects that +// target GCC or GCC-compatible compilers. +// +// `__GNUC__` is not defined when using LLVM's MSVC-compatible compiler driver `clang-cl` (even +// though clang supports the required extension), hence the additional `|| __clang__`. + +#define KJ_UNWRAP_OR_RETURN(value, ...) \ + (*({ \ + auto _kj_result = ::kj::_::readMaybe(value); \ + if (!_kj_result) { \ + return __VA_ARGS__; \ + } \ + kj::mv(_kj_result); \ + })) + +#define KJ_UNWRAP_OR(value, block) \ + (*({ \ + auto _kj_result = ::kj::_::readMaybe(value); \ + if (!_kj_result) { \ + block; \ + asm("KJ_UNWRAP_OR_block_is_missing_return_statement\n"); \ + } \ + kj::mv(_kj_result); \ + })) +#endif + template class Maybe { // A T, or nullptr. @@ -1050,18 +1380,25 @@ class Maybe { public: Maybe(): ptr(nullptr) {} - Maybe(T&& t) noexcept(noexcept(T(instance()))): ptr(kj::mv(t)) {} + Maybe(T&& t): ptr(kj::mv(t)) {} Maybe(T& t): ptr(t) {} Maybe(const T& t): ptr(t) {} - Maybe(const T* t) noexcept: ptr(t) {} - Maybe(Maybe&& other) noexcept(noexcept(T(instance()))): ptr(kj::mv(other.ptr)) {} + Maybe(Maybe&& other): ptr(kj::mv(other.ptr)) { other = nullptr; } Maybe(const Maybe& other): ptr(other.ptr) {} Maybe(Maybe& other): ptr(other.ptr) {} template - Maybe(Maybe&& other) noexcept(noexcept(T(instance()))) { + Maybe(Maybe&& other) { KJ_IF_MAYBE(val, kj::mv(other)) { ptr.emplace(kj::mv(*val)); + other = nullptr; + } + } + template + Maybe(Maybe&& other) { + KJ_IF_MAYBE(val, other) { + ptr.emplace(*val); + other = nullptr; } } template @@ -1071,38 +1408,132 @@ class Maybe { } } - Maybe(decltype(nullptr)) noexcept: ptr(nullptr) {} + Maybe(decltype(nullptr)): ptr(nullptr) {} template inline T& emplace(Params&&... params) { - // Replace this Maybe's content with a new value constructed by passing the given parametrs to + // Replace this Maybe's content with a new value constructed by passing the given parameters to // T's constructor. This can be used to initialize a Maybe without copying or even moving a T. // Returns a reference to the newly-constructed value. return ptr.emplace(kj::fwd(params)...); } - inline Maybe& operator=(Maybe&& other) { ptr = kj::mv(other.ptr); return *this; } + inline Maybe& operator=(T&& other) { ptr = kj::mv(other); return *this; } + inline Maybe& operator=(T& other) { ptr = other; return *this; } + inline Maybe& operator=(const T& other) { ptr = other; return *this; } + + inline Maybe& operator=(Maybe&& other) { ptr = kj::mv(other.ptr); other = nullptr; return *this; } inline Maybe& operator=(Maybe& other) { ptr = other.ptr; return *this; } inline Maybe& operator=(const Maybe& other) { ptr = other.ptr; return *this; } + template + Maybe& operator=(Maybe&& other) { + KJ_IF_MAYBE(val, kj::mv(other)) { + ptr.emplace(kj::mv(*val)); + other = nullptr; + } else { + ptr = nullptr; + } + return *this; + } + template + Maybe& operator=(const Maybe& other) { + KJ_IF_MAYBE(val, other) { + ptr.emplace(*val); + } else { + ptr = nullptr; + } + return *this; + } + + inline Maybe& operator=(decltype(nullptr)) { ptr = nullptr; return *this; } + inline bool operator==(decltype(nullptr)) const { return ptr == nullptr; } inline bool operator!=(decltype(nullptr)) const { return ptr != nullptr; } - T& orDefault(T& defaultValue) { + inline bool operator==(const Maybe& other) const { + if (ptr == nullptr) { + return other == nullptr; + } else { + return other.ptr != nullptr && *ptr == *other.ptr; + } + } + inline bool operator!=(const Maybe& other) const { return !(*this == other); } + + Maybe(const T* t) = delete; + Maybe& operator=(const T* other) = delete; + // We used to permit assigning a Maybe directly from a T*, and the assignment would check for + // nullness. This turned out never to be useful, and sometimes to be dangerous. + + T& orDefault(T& defaultValue) & { if (ptr == nullptr) { return defaultValue; } else { return *ptr; } } - const T& orDefault(const T& defaultValue) const { + const T& orDefault(const T& defaultValue) const & { if (ptr == nullptr) { return defaultValue; } else { return *ptr; } } + T&& orDefault(T&& defaultValue) && { + if (ptr == nullptr) { + return kj::mv(defaultValue); + } else { + return kj::mv(*ptr); + } + } + const T&& orDefault(const T&& defaultValue) const && { + if (ptr == nullptr) { + return kj::mv(defaultValue); + } else { + return kj::mv(*ptr); + } + } + + template () ? instance() : instance()())> + Result orDefault(F&& lazyDefaultValue) & { + if (ptr == nullptr) { + return lazyDefaultValue(); + } else { + return *ptr; + } + } + + template () ? instance() : instance()())> + Result orDefault(F&& lazyDefaultValue) const & { + if (ptr == nullptr) { + return lazyDefaultValue(); + } else { + return *ptr; + } + } + + template () ? instance() : instance()())> + Result orDefault(F&& lazyDefaultValue) && { + if (ptr == nullptr) { + return lazyDefaultValue(); + } else { + return kj::mv(*ptr); + } + } + + template () ? instance() : instance()())> + Result orDefault(F&& lazyDefaultValue) const && { + if (ptr == nullptr) { + return lazyDefaultValue(); + } else { + return kj::mv(*ptr); + } + } template auto map(Func&& f) & -> Maybe()))> { @@ -1154,24 +1585,51 @@ class Maybe { }; template -class Maybe: public DisallowConstCopyIfNotConst { +class Maybe { public: - Maybe() noexcept: ptr(nullptr) {} - Maybe(T& t) noexcept: ptr(&t) {} - Maybe(T* t) noexcept: ptr(t) {} + constexpr Maybe(): ptr(nullptr) {} + constexpr Maybe(T& t): ptr(&t) {} + constexpr Maybe(T* t): ptr(t) {} + + inline constexpr Maybe(PropagateConst& other): ptr(other.ptr) {} + // Allow const copy only if `T` itself is const. Otherwise allow only non-const copy, to + // protect transitive constness. Clang is happy for this constructor to be declared `= default` + // since, after evaluation of `PropagateConst`, it does end up being a default-able constructor. + // But, GCC and MSVC both complain about that, claiming this constructor cannot be declared + // default. I don't know who is correct, but whatever, we'll write out an implementation, fine. + // + // Note that we can't solve this by inheriting DisallowConstCopyIfNotConst because we want + // to override the move constructor, and if we override the move constructor then we must define + // the copy constructor here. + + inline constexpr Maybe(Maybe&& other): ptr(other.ptr) { other.ptr = nullptr; } template - inline Maybe(Maybe& other) noexcept: ptr(other.ptr) {} + inline constexpr Maybe(Maybe& other): ptr(other.ptr) {} template - inline Maybe(const Maybe& other) noexcept: ptr(other.ptr) {} - inline Maybe(decltype(nullptr)) noexcept: ptr(nullptr) {} - - inline Maybe& operator=(T& other) noexcept { ptr = &other; return *this; } - inline Maybe& operator=(T* other) noexcept { ptr = other; return *this; } + inline constexpr Maybe(const Maybe& other): ptr(const_cast(other.ptr)) {} template - inline Maybe& operator=(Maybe& other) noexcept { ptr = other.ptr; return *this; } + inline constexpr Maybe(Maybe&& other): ptr(other.ptr) { other.ptr = nullptr; } template - inline Maybe& operator=(const Maybe& other) noexcept { ptr = other.ptr; return *this; } + inline constexpr Maybe(const Maybe&& other) = delete; + template ()>> + constexpr Maybe(Maybe& other): ptr(other.ptr.operator U*()) {} + template ()>> + constexpr Maybe(const Maybe& other): ptr(other.ptr.operator const U*()) {} + inline constexpr Maybe(decltype(nullptr)): ptr(nullptr) {} + + inline Maybe& operator=(T& other) { ptr = &other; return *this; } + inline Maybe& operator=(T* other) { ptr = other; return *this; } + inline Maybe& operator=(PropagateConst& other) { ptr = other.ptr; return *this; } + inline Maybe& operator=(Maybe&& other) { ptr = other.ptr; other.ptr = nullptr; return *this; } + template + inline Maybe& operator=(Maybe& other) { ptr = other.ptr; return *this; } + template + inline Maybe& operator=(const Maybe& other) { ptr = other.ptr; return *this; } + template + inline Maybe& operator=(Maybe&& other) { ptr = other.ptr; other.ptr = nullptr; return *this; } + template + inline Maybe& operator=(const Maybe&& other) = delete; inline bool operator==(decltype(nullptr)) const { return ptr == nullptr; } inline bool operator!=(decltype(nullptr)) const { return ptr != nullptr; } @@ -1200,6 +1658,16 @@ class Maybe: public DisallowConstCopyIfNotConst { } } + template + auto map(Func&& f) const -> Maybe()))> { + if (ptr == nullptr) { + return nullptr; + } else { + const T& ref = *ptr; + return f(ref); + } + } + private: T* ptr; @@ -1216,6 +1684,9 @@ class Maybe: public DisallowConstCopyIfNotConst { // // So common that we put it in common.h rather than array.h. +template +class Array; + template class ArrayPtr: public DisallowConstCopyIfNotConst { // A pointer to an array. Includes a size. Like any pointer, it doesn't own the target data, @@ -1224,14 +1695,71 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { public: inline constexpr ArrayPtr(): ptr(nullptr), size_(0) {} inline constexpr ArrayPtr(decltype(nullptr)): ptr(nullptr), size_(0) {} - inline constexpr ArrayPtr(T* ptr, size_t size): ptr(ptr), size_(size) {} - inline constexpr ArrayPtr(T* begin, T* end): ptr(begin), size_(end - begin) {} - inline KJ_CONSTEXPR() ArrayPtr(::std::initializer_list> init) + inline constexpr ArrayPtr(T* ptr KJ_LIFETIMEBOUND, size_t size): ptr(ptr), size_(size) {} + inline constexpr ArrayPtr(T* begin KJ_LIFETIMEBOUND, T* end KJ_LIFETIMEBOUND) + : ptr(begin), size_(end - begin) {} + ArrayPtr& operator=(Array&&) = delete; + ArrayPtr& operator=(decltype(nullptr)) { + ptr = nullptr; + size_ = 0; + return *this; + } + +#if __GNUC__ && !__clang__ && __GNUC__ >= 9 +// GCC 9 added a warning when we take an initializer_list as a constructor parameter and save a +// pointer to its content in a class member. GCC apparently imagines we're going to do something +// dumb like this: +// ArrayPtr ptr = { 1, 2, 3 }; +// foo(ptr[1]); // undefined behavior! +// Any KJ programmer should be able to recognize that this is UB, because an ArrayPtr does not own +// its content. That's not what this constructor is for, tohugh. This constructor is meant to allow +// code like this: +// int foo(ArrayPtr p); +// // ... later ... +// foo({1, 2, 3}); +// In this case, the initializer_list's backing array, like any temporary, lives until the end of +// the statement `foo({1, 2, 3});`. Therefore, it lives at least until the call to foo() has +// returned, which is exactly what we care about. This usage is fine! GCC is wrong to warn. +// +// Amusingly, Clang's implementation has a similar type that they call ArrayRef which apparently +// triggers this same GCC warning. My guess is that Clang will not introduce a similar warning +// given that it triggers on their own, legitimate code. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Winit-list-lifetime" +#endif + inline KJ_CONSTEXPR() ArrayPtr( + ::std::initializer_list> init KJ_LIFETIMEBOUND) : ptr(init.begin()), size_(init.size()) {} +#if __GNUC__ && !__clang__ && __GNUC__ >= 9 +#pragma GCC diagnostic pop +#endif template - inline constexpr ArrayPtr(T (&native)[size]): ptr(native), size_(size) {} - // Construct an ArrayPtr from a native C-style array. + inline constexpr ArrayPtr(KJ_LIFETIMEBOUND T (&native)[size]): ptr(native), size_(size) { + // Construct an ArrayPtr from a native C-style array. + // + // We disable this constructor for const char arrays because otherwise you would be able to + // implicitly convert a character literal to ArrayPtr, which sounds really great, + // except that the NUL terminator would be included, which probably isn't what you intended. + // + // TODO(someday): Maybe we should support character literals but explicitly chop off the NUL + // terminator. This could do the wrong thing if someone tries to construct an + // ArrayPtr from a non-NUL-terminated char array, but evidence suggests that all + // real use cases are in fact intending to remove the NUL terminator. It's convenient to be + // able to specify ArrayPtr as a parameter type and be able to accept strings + // as input in addition to arrays. Currently, you'll need overloading to support string + // literals in this case, but if you overload StringPtr, then you'll find that several + // conversions (e.g. from String and from a literal char array) become ambiguous! You end up + // having to overload for literal char arrays specifically which is cumbersome. + + static_assert(!isSameType(), + "Can't implicitly convert literal char array to ArrayPtr because we don't know if " + "you meant to include the NUL terminator. We may change this in the future to " + "automatically drop the NUL terminator. For now, try explicitly converting to StringPtr, " + "which can in turn implicitly convert to ArrayPtr."); + static_assert(!isSameType(), "see above"); + static_assert(!isSameType(), "see above"); + } inline operator ArrayPtr() const { return ArrayPtr(ptr, size_); @@ -1240,7 +1768,7 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { return ArrayPtr(ptr, size_); } - inline size_t size() const { return size_; } + inline constexpr size_t size() const { return size_; } inline const T& operator[](size_t index) const { KJ_IREQUIRE(index < size_, "Out-of-bounds ArrayPtr access."); return ptr[index]; @@ -1254,8 +1782,8 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { inline T* end() { return ptr + size_; } inline T& front() { return *ptr; } inline T& back() { return *(ptr + size_ - 1); } - inline const T* begin() const { return ptr; } - inline const T* end() const { return ptr + size_; } + inline constexpr const T* begin() const { return ptr; } + inline constexpr const T* end() const { return ptr + size_; } inline const T& front() const { return *ptr; } inline const T& back() const { return *(ptr + size_ - 1); } @@ -1267,6 +1795,29 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { KJ_IREQUIRE(start <= end && end <= size_, "Out-of-bounds ArrayPtr::slice()."); return ArrayPtr(ptr + start, end - start); } + inline bool startsWith(const ArrayPtr& other) const { + return other.size() <= size_ && slice(0, other.size()) == other; + } + inline bool endsWith(const ArrayPtr& other) const { + return other.size() <= size_ && slice(size_ - other.size(), size_) == other; + } + + inline Maybe findFirst(const T& match) const { + for (size_t i = 0; i < size_; i++) { + if (ptr[i] == match) { + return i; + } + } + return nullptr; + } + inline Maybe findLast(const T& match) const { + for (size_t i = size_; i--;) { + if (ptr[i] == match) { + return i; + } + } + return nullptr; + } inline ArrayPtr> asBytes() const { // Reinterpret the array as a byte array. This is explicitly legal under C++ aliasing @@ -1284,26 +1835,95 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { inline bool operator==(const ArrayPtr& other) const { if (size_ != other.size_) return false; + if (isIntegral>()) { + if (size_ == 0) return true; + return memcmp(ptr, other.ptr, size_ * sizeof(T)) == 0; + } for (size_t i = 0; i < size_; i++) { if (ptr[i] != other[i]) return false; } return true; } +#if !__cpp_impl_three_way_comparison inline bool operator!=(const ArrayPtr& other) const { return !(*this == other); } +#endif + + template + inline bool operator==(const ArrayPtr& other) const { + if (size_ != other.size()) return false; + for (size_t i = 0; i < size_; i++) { + if (ptr[i] != other[i]) return false; + } + return true; + } +#if !__cpp_impl_three_way_comparison + template + inline bool operator!=(const ArrayPtr& other) const { return !(*this == other); } +#endif + + template + Array attach(Attachments&&... attachments) const KJ_WARN_UNUSED_RESULT; + // Like Array::attach(), but also promotes an ArrayPtr to an Array. Generally the attachment + // should be an object that actually owns the array that the ArrayPtr is pointing at. + // + // You must include kj/array.h to call this. private: T* ptr; size_t size_; }; +template <> +inline Maybe ArrayPtr::findFirst(const char& c) const { + const char* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +template <> +inline Maybe ArrayPtr::findFirst(const char& c) const { + char* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +template <> +inline Maybe ArrayPtr::findFirst(const byte& c) const { + const byte* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +template <> +inline Maybe ArrayPtr::findFirst(const byte& c) const { + byte* pos = reinterpret_cast(memchr(ptr, c, size_)); + if (pos == nullptr) { + return nullptr; + } else { + return pos - ptr; + } +} + +// glibc has a memrchr() for reverse search but it's non-standard, so we don't bother optimizing +// findLast(), which isn't used much anyway. + template -inline constexpr ArrayPtr arrayPtr(T* ptr, size_t size) { +inline constexpr ArrayPtr arrayPtr(T* ptr KJ_LIFETIMEBOUND, size_t size) { // Use this function to construct ArrayPtrs without writing out the type name. return ArrayPtr(ptr, size); } template -inline constexpr ArrayPtr arrayPtr(T* begin, T* end) { +inline constexpr ArrayPtr arrayPtr(T* begin KJ_LIFETIMEBOUND, T* end KJ_LIFETIMEBOUND) { // Use this function to construct ArrayPtrs without writing out the type name. return ArrayPtr(begin, end); } @@ -1365,29 +1985,48 @@ namespace _ { // private template class Deferred { public: - inline Deferred(Func&& func): func(kj::fwd(func)), canceled(false) {} - inline ~Deferred() noexcept(false) { if (!canceled) func(); } + Deferred(Func&& func): maybeFunc(kj::fwd(func)) {} + ~Deferred() noexcept(false) { + run(); + } KJ_DISALLOW_COPY(Deferred); - // This move constructor is usually optimized away by the compiler. - inline Deferred(Deferred&& other): func(kj::mv(other.func)), canceled(false) { - other.canceled = true; + Deferred(Deferred&&) = default; + // Since we use a kj::Maybe, the default move constructor does exactly what we want it to do. + + void run() { + // Move `maybeFunc` to the local scope so that even if we throw, we destroy the functor we had. + auto maybeLocalFunc = kj::mv(maybeFunc); + KJ_IF_MAYBE(func, maybeLocalFunc) { + (*func)(); + } } + + void cancel() { + maybeFunc = nullptr; + } + private: - Func func; - bool canceled; + kj::Maybe maybeFunc; + // Note that `Func` may actually be an lvalue reference because `kj::defer` takes its argument via + // universal reference. `kj::Maybe` has specializations for lvalue reference types, so this works + // out. }; } // namespace _ (private) template _::Deferred defer(Func&& func) { - // Returns an object which will invoke the given functor in its destructor. The object is not - // copyable but is movable with the semantics you'd expect. Since the return type is private, - // you need to assign to an `auto` variable. + // Returns an object which will invoke the given functor in its destructor. The object is not + // copyable but is move-constructable with the semantics you'd expect. Since the return type is + // private, you need to assign to an `auto` variable. // // The KJ_DEFER macro provides slightly more convenient syntax for the common case where you // want some code to run at current scope exit. + // + // KJ_DEFER does not support move-assignment for its returned objects. If you need to reuse the + // variable for your deferred function object, then you will want to write your own class for that + // purpose. return _::Deferred(kj::fwd(func)); } @@ -1397,4 +2036,4 @@ _::Deferred defer(Func&& func) { } // namespace kj -#endif // KJ_COMMON_H_ +KJ_END_HEADER diff --git a/c++/src/kj/compat/BUILD.bazel b/c++/src/kj/compat/BUILD.bazel new file mode 100644 index 0000000000..f9e31ce6be --- /dev/null +++ b/c++/src/kj/compat/BUILD.bazel @@ -0,0 +1,155 @@ +exports_files(["gtest.h"]) + +cc_library( + name = "kj-tls", + srcs = [ + "readiness-io.c++", + "tls.c++", + ], + hdrs = [ + "readiness-io.h", + "tls.h", + ], + include_prefix = "kj/compat", + target_compatible_with = select({ + "//src/kj:use_openssl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + "@ssl", + ], +) + +cc_library( + name = "kj-http", + srcs = [ + "http.c++", + "url.c++", + ], + hdrs = [ + "http.h", + "url.h", + ], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + "@zlib", + ], +) + +cc_library( + name = "kj-gzip", + srcs = ["gzip.c++"], + hdrs = ["gzip.h"], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + deps = [ + "//src/kj:kj-async", + "@zlib", + ], +) + +cc_library( + name = "kj-brotli", + srcs = ["brotli.c++"], + hdrs = ["brotli.h"], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + target_compatible_with = select({ + "//src/kj:use_brotli": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + "//src/kj:kj-async", + "@brotli//:brotlienc", + "@brotli//:brotlidec", + ], +) + +cc_library( + name = "gtest", + hdrs = ["gtest.h"], + include_prefix = "kj/compat", + visibility = ["//visibility:public"], + deps = ["//src/kj"], +) + +kj_tests = [ + "http-test.c++", + "url-test.c++", +] + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + deps = [ + ":kj-http", + "//src/kj:kj-test", + ], +) for f in kj_tests] + +cc_library( + name = "http-socketpair-test-base", + hdrs = ["http-test.c++"], +) + +cc_test( + name = "http-socketpair-test", + srcs = ["http-socketpair-test.c++"], + deps = [ + ":http-socketpair-test-base", + ":kj-http", + "//src/kj:kj-test", + ], + target_compatible_with = [ + "@platforms//os:linux", # TODO: Investigate why this fails on macOS + ], +) + +kj_tls_tests = [ + "tls-test.c++", + "readiness-io-test.c++", +] + +[cc_test( + name = f.removesuffix(".c++"), + srcs = [f], + target_compatible_with = select({ + "//src/kj:use_openssl": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-tls", + ":kj-http", + "//src/kj:kj-test", + ], +) for f in kj_tls_tests] + +cc_test( + name = "gzip-test", + srcs = ["gzip-test.c++"], + target_compatible_with = select({ + "//src/kj:use_zlib": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-gzip", + "//src/kj:kj-test", + ], +) + +cc_test( + name = "brotli-test", + srcs = ["brotli-test.c++"], + target_compatible_with = select({ + "//src/kj:use_brotli": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":kj-brotli", + "//src/kj:kj-test", + ], +) diff --git a/c++/src/kj/compat/brotli-test.c++ b/c++/src/kj/compat/brotli-test.c++ new file mode 100644 index 0000000000..f0e00d1e08 --- /dev/null +++ b/c++/src/kj/compat/brotli-test.c++ @@ -0,0 +1,410 @@ +// Copyright (c) 2023 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_BROTLI + +#include "brotli.h" +#include +#include +#include + +namespace kj { +namespace { + +static const byte FOOBAR_BR[] = { + 0x83, 0x02, 0x80, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, 0x03, +}; + +// brotli stream with 24 window bits, i.e. the max window size. If KJ_BROTLI_MAX_DEC_WBITS is less +// than 24, the stream will be rejected by default. This approach should be acceptable in a web +// context, where few files benefit from larger windows and memory usage matters for +// concurrent transfers. +static const byte FOOBAR_BR_LARGE_WIN[] = { + 0x8f, 0x02, 0x80, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, 0x03, +}; + +class MockInputStream: public InputStream { +public: + MockInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +class MockAsyncInputStream: public AsyncInputStream { +public: + MockAsyncInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +class MockOutputStream: public OutputStream { +public: + kj::Vector bytes; + + kj::String decompress() { + MockInputStream rawInput(bytes, kj::maxValue); + BrotliInputStream brotli(rawInput); + return brotli.readAllText(); + } + + void write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast(buffer), size)); + } + void write(ArrayPtr> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + } +}; + +class MockAsyncOutputStream: public AsyncOutputStream { +public: + kj::Vector bytes; + + kj::String decompress(WaitScope& ws) { + MockAsyncInputStream rawInput(bytes, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + return brotli.readAllText().wait(ws); + } + + Promise write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast(buffer), size)); + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + return kj::READY_NOW; + } + + Promise whenWriteDisconnected() override { KJ_UNIMPLEMENTED("not used"); } +}; + +KJ_TEST("brotli decompression") { + // Normal read. + { + MockInputStream rawInput(FOOBAR_BR, kj::maxValue); + BrotliInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText() == "foobar"); + } + + // Force read one byte at a time. + { + MockInputStream rawInput(FOOBAR_BR, 1); + BrotliInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText() == "foobar"); + } + + // Read truncated input. + { + MockInputStream rawInput(kj::arrayPtr(FOOBAR_BR, sizeof(FOOBAR_BR) / 2), kj::maxValue); + BrotliInputStream brotli(rawInput); + + char text[16]; + size_t n = brotli.tryRead(text, 1, sizeof(text)); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("brotli compressed stream ended prematurely", + brotli.tryRead(text, 1, sizeof(text))); + } + + // Check that stream with high window size is rejected. Conversely, check that it is accepted if + // configured to accept the full window size. + { + MockInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliInputStream brotli(rawInput, BROTLI_DEFAULT_WINDOW); + KJ_EXPECT_THROW_MESSAGE("brotli window size too big", brotli.readAllText()); + } + + { + MockInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliInputStream brotli(rawInput, BROTLI_MAX_WINDOW_BITS); + KJ_EXPECT(brotli.readAllText() == "foobar"); + } + + // Check that invalid stream is rejected. + { + MockInputStream rawInput(kj::arrayPtr(FOOBAR_BR + 3, sizeof(FOOBAR_BR) - 3), kj::maxValue); + BrotliInputStream brotli(rawInput); + KJ_EXPECT_THROW_MESSAGE("brotli decompression failed", brotli.readAllText()); + } + + // Read concatenated input. + { + Vector bytes; + bytes.addAll(ArrayPtr(FOOBAR_BR)); + bytes.addAll(ArrayPtr(FOOBAR_BR)); + MockInputStream rawInput(bytes, kj::maxValue); + BrotliInputStream brotli(rawInput); + + KJ_EXPECT(brotli.readAllText() == "foobarfoobar"); + } +} + +KJ_TEST("async brotli decompression") { + auto io = setupAsyncIo(); + + // Normal read. + { + MockAsyncInputStream rawInput(FOOBAR_BR, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobar"); + } + + // Force read one byte at a time. + { + MockAsyncInputStream rawInput(FOOBAR_BR, 1); + BrotliAsyncInputStream brotli(rawInput); + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobar"); + } + + // Read truncated input. + { + MockAsyncInputStream rawInput(kj::arrayPtr(FOOBAR_BR, sizeof(FOOBAR_BR) / 2), kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + + char text[16]; + size_t n = brotli.tryRead(text, 1, sizeof(text)).wait(io.waitScope); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("brotli compressed stream ended prematurely", + brotli.tryRead(text, 1, sizeof(text)).wait(io.waitScope)); + } + + // Check that stream with high window size is rejected. Conversely, check that it is accepted if + // configured to accept the full window size. + { + MockAsyncInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput, BROTLI_DEFAULT_WINDOW); + KJ_EXPECT_THROW_MESSAGE("brotli window size too big", + brotli.readAllText().wait(io.waitScope)); + } + + { + MockAsyncInputStream rawInput(FOOBAR_BR_LARGE_WIN, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput, BROTLI_MAX_WINDOW_BITS); + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobar"); + } + + // Read concatenated input. + { + Vector bytes; + bytes.addAll(ArrayPtr(FOOBAR_BR)); + bytes.addAll(ArrayPtr(FOOBAR_BR)); + MockAsyncInputStream rawInput(bytes, kj::maxValue); + BrotliAsyncInputStream brotli(rawInput); + + KJ_EXPECT(brotli.readAllText().wait(io.waitScope) == "foobarfoobar"); + } + + // Decompress using an output stream. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput, BrotliAsyncOutputStream::DECOMPRESS); + + auto mid = sizeof(FOOBAR_BR) / 2; + brotli.write(FOOBAR_BR, mid).wait(io.waitScope); + auto str1 = kj::heapString(rawOutput.bytes.asPtr().asChars()); + KJ_EXPECT(str1 == "fo", str1); + + brotli.write(FOOBAR_BR + mid, sizeof(FOOBAR_BR) - mid).wait(io.waitScope); + auto str2 = kj::heapString(rawOutput.bytes.asPtr().asChars()); + KJ_EXPECT(str2 == "foobar", str2); + + brotli.end().wait(io.waitScope); + } +} + +KJ_TEST("brotli compression") { + // Normal write. + { + MockOutputStream rawOutput; + { + BrotliOutputStream brotli(rawOutput); + brotli.write("foobar", 6); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Multi-part write. + { + MockOutputStream rawOutput; + { + BrotliOutputStream brotli(rawOutput); + brotli.write("foo", 3); + brotli.write("bar", 3); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Array-of-arrays write. + { + MockOutputStream rawOutput; + + { + BrotliOutputStream brotli(rawOutput); + + ArrayPtr pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + brotli.write(pieces); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } +} + +KJ_TEST("brotli huge round trip") { + auto bytes = heapArray(96*1024); + for (auto& b: bytes) { + b = rand(); + } + + MockOutputStream rawOutput; + { + BrotliOutputStream brotliOut(rawOutput); + brotliOut.write(bytes.begin(), bytes.size()); + } + + MockInputStream rawInput(rawOutput.bytes, kj::maxValue); + BrotliInputStream brotliIn(rawInput); + auto decompressed = brotliIn.readAllBytes(); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +KJ_TEST("async brotli compression") { + auto io = setupAsyncIo(); + // Normal write. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput); + brotli.write("foobar", 6).wait(io.waitScope); + brotli.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } + + // Multi-part write. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput); + + brotli.write("foo", 3).wait(io.waitScope); + auto prevSize = rawOutput.bytes.size(); + + brotli.write("bar", 3).wait(io.waitScope); + auto curSize = rawOutput.bytes.size(); + KJ_EXPECT(prevSize == curSize, prevSize, curSize); + + brotli.flush().wait(io.waitScope); + curSize = rawOutput.bytes.size(); + KJ_EXPECT(prevSize < curSize, prevSize, curSize); + + brotli.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } + + // Array-of-arrays write. + { + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotli(rawOutput); + + ArrayPtr pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + brotli.write(pieces).wait(io.waitScope); + brotli.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } +} + +KJ_TEST("async brotli huge round trip") { + auto io = setupAsyncIo(); + + auto bytes = heapArray(65536); + for (auto& b: bytes) { + b = rand(); + } + + MockAsyncOutputStream rawOutput; + BrotliAsyncOutputStream brotliOut(rawOutput); + brotliOut.write(bytes.begin(), bytes.size()).wait(io.waitScope); + brotliOut.end().wait(io.waitScope); + + MockAsyncInputStream rawInput(rawOutput.bytes, kj::maxValue); + BrotliAsyncInputStream brotliIn(rawInput); + auto decompressed = brotliIn.readAllBytes().wait(io.waitScope); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +} // namespace +} // namespace kj + +#endif // KJ_HAS_BROTLI diff --git a/c++/src/kj/compat/brotli.c++ b/c++/src/kj/compat/brotli.c++ new file mode 100644 index 0000000000..08efc8abfa --- /dev/null +++ b/c++/src/kj/compat/brotli.c++ @@ -0,0 +1,369 @@ +// Copyright (c) 2023 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_BROTLI + +#include "brotli.h" +#include + +namespace kj { + +namespace { + +int getBrotliWindowBits(kj::byte peek) { + // Check number of window bits used by the stream, see RFC 7932 + // (https://www.rfc-editor.org/rfc/rfc7932.html#section-9.1) for the specification. + // Adapted from an internal Cloudflare codebase. + if ((peek & 0x01) == 0) { + return 16; + } + + if (((peek >> 1) & 0x07) != 0) { + return 17 + (peek >> 1 & 0x07); + } + + if (((peek >> 4) & 0x07) == 0) { + return 17; + } + + if (((peek >> 4) & 0x07) == 1) { + // Large window brotli, not part of RFC 7932 and not supported in web contexts + return BROTLI_MAX_WINDOW_BITS + 1; + } + + return 8 + ((peek >> 4) & 0x07); +} + +} // namespace + +namespace _ { // private + +BrotliOutputContext::BrotliOutputContext(kj::Maybe compressionLevel, + kj::Maybe windowBitsParam) + : nextIn(nullptr), availableIn(0) { + KJ_IF_MAYBE(level, compressionLevel) { + // Emulate zlib's behavior of using -1 to signify the default quality + if (*level == -1) {*level = KJ_BROTLI_DEFAULT_QUALITY;} + KJ_REQUIRE(*level >= BROTLI_MIN_QUALITY && *level <= BROTLI_MAX_QUALITY, + "invalid brotli compression level", *level); + windowBits = windowBitsParam.orDefault(_::KJ_BROTLI_DEFAULT_WBITS); + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + BrotliEncoderState* cctx = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(cctx, "brotli state allocation failed"); + KJ_ASSERT(BrotliEncoderSetParameter(cctx, BROTLI_PARAM_QUALITY, *level) == BROTLI_TRUE); + KJ_ASSERT(BrotliEncoderSetParameter(cctx, BROTLI_PARAM_LGWIN, windowBits) == BROTLI_TRUE); + ctx = cctx; + } else { + // In the decoder, we manually check that the stream does not have a higher window size than + // requested and reject it otherwise, no way to automate this step. + // By default, we accept streams with a window size up to (1 << KJ_BROTLI_MAX_DEC_WBITS), + // this is more than the default window size for compression (i.e. KJ_BROTLI_DEFAULT_WBITS) + windowBits = windowBitsParam.orDefault(_::KJ_BROTLI_MAX_DEC_WBITS); + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + BrotliDecoderState* dctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(dctx, "brotli state allocation failed"); + ctx = dctx; + } +} + +BrotliOutputContext::~BrotliOutputContext() noexcept(false) { + KJ_SWITCH_ONEOF(ctx) { + KJ_CASE_ONEOF(cctx, BrotliEncoderState*) { + BrotliEncoderDestroyInstance(cctx); + } + KJ_CASE_ONEOF(dctx, BrotliDecoderState*) { + BrotliDecoderDestroyInstance(dctx); + } + } +} + +void BrotliOutputContext::setInput(const void* in, size_t size) { + nextIn = reinterpret_cast(in); + availableIn = size; +} + +kj::Tuple> BrotliOutputContext::pumpOnce( + BrotliEncoderOperation flush) { + byte* nextOut = buffer; + size_t availableOut = sizeof(buffer); + // Brotli does not accept a null input pointer; make sure there is a valid pointer even if we are + // not actually reading from it. + if (!nextIn) { + KJ_ASSERT(availableIn == 0); + nextIn = buffer; + } + + KJ_SWITCH_ONEOF(ctx) { + KJ_CASE_ONEOF(dctx, BrotliDecoderState*) { + // Check window bits + if (firstInput && availableIn) { + firstInput = false; + int streamWbits = getBrotliWindowBits(nextIn[0]); + KJ_REQUIRE(streamWbits <= windowBits, "brotli window size too big", (1 << streamWbits)); + } + BrotliDecoderResult result = BrotliDecoderDecompressStream(dctx, &availableIn, &nextIn, + &availableOut, &nextOut, nullptr); + if (result == BROTLI_DECODER_RESULT_ERROR) { + // Note: Unlike BrotliInputStream, this will implicitly reject trailing data during + // decompression, matching the behavior for gzip. + KJ_FAIL_REQUIRE("brotli decompression failed", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(dctx))); + } + // The 'ok' parameter represented by the first parameter of the tuple indicates that + // pumpOnce() should be called again as more output data can be produced. This is the case + // when the stream is not finished and there is either pending output data (that didn't fit + // into the buffer) or input that has not been processed yet. + return kj::tuple(BrotliDecoderHasMoreOutput(dctx), + kj::arrayPtr(buffer, sizeof(buffer) - availableOut)); + } + KJ_CASE_ONEOF(cctx, BrotliEncoderState*) { + BROTLI_BOOL result = BrotliEncoderCompressStream(cctx, flush, &availableIn, &nextIn, + &availableOut, &nextOut, nullptr); + KJ_REQUIRE(result == BROTLI_TRUE, "brotli compression failed"); + + return kj::tuple(BrotliEncoderHasMoreOutput(cctx), + kj::arrayPtr(buffer, sizeof(buffer) - availableOut)); + } + } + KJ_UNREACHABLE; +} + +} // namespace _ (private) + +// ======================================================================================= + +BrotliInputStream::BrotliInputStream(InputStream& inner, kj::Maybe windowBitsParam) + : inner(inner), windowBits(windowBitsParam.orDefault(_::KJ_BROTLI_MAX_DEC_WBITS)), + nextIn(nullptr), availableIn(0) { + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); +} + +BrotliInputStream::~BrotliInputStream() noexcept(false) { + BrotliDecoderDestroyInstance(ctx); +} + +size_t BrotliInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return size_t(0); + + return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); +} + +size_t BrotliInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + // Ask for more input unless there is pending output + if (availableIn == 0 && !BrotliDecoderHasMoreOutput(ctx)) { + size_t amount = inner.tryRead(buffer, 1, sizeof(buffer)); + if (amount == 0) { + KJ_REQUIRE(atValidEndpoint, "brotli compressed stream ended prematurely"); + return alreadyRead; + } else { + nextIn = buffer; + availableIn = amount; + } + } + + byte* nextOut = out; + size_t availableOut = maxBytes; + // Check window bits + if (firstInput && availableIn) { + firstInput = false; + int streamWbits = getBrotliWindowBits(nextIn[0]); + KJ_REQUIRE(streamWbits <= windowBits, + "brotli window size too big", (1 << streamWbits)); + } + BrotliDecoderResult result = BrotliDecoderDecompressStream( + ctx, &availableIn, &nextIn, &availableOut, &nextOut, nullptr); + KJ_REQUIRE(result != BROTLI_DECODER_RESULT_ERROR, "brotli decompression failed", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(ctx))); + + atValidEndpoint = result == BROTLI_DECODER_RESULT_SUCCESS; + if (atValidEndpoint && availableIn > 0) { + // There's more data available. Assume start of new content. + // Not sure if we actually want this, but there is limited potential for breakage as arbitrary + // trailing data should still be rejected. Unfortunately this is kind of clunky as brotli does + // not support resetting an instance. + BrotliDecoderDestroyInstance(ctx); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); + firstInput = true; + } + + size_t n = maxBytes - availableOut; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } +} + +BrotliOutputStream::BrotliOutputStream(OutputStream& inner, int compressionLevel, int windowBits) + : inner(inner), ctx(compressionLevel, windowBits) {} + +BrotliOutputStream::BrotliOutputStream(OutputStream& inner, decltype(DECOMPRESS), int windowBits) + : inner(inner), ctx(nullptr, windowBits) {} + +BrotliOutputStream::~BrotliOutputStream() noexcept(false) { + pump(BROTLI_OPERATION_FINISH); +} + +void BrotliOutputStream::write(const void* in, size_t size) { + ctx.setInput(in, size); + pump(BROTLI_OPERATION_PROCESS); +} + +void BrotliOutputStream::pump(BrotliEncoderOperation flush) { + bool ok; + do { + auto result = ctx.pumpOnce(flush); + ok = get<0>(result); + auto chunk = get<1>(result); + if (chunk.size() > 0) { + inner.write(chunk.begin(), chunk.size()); + } + } while (ok); +} + +// ======================================================================================= + +BrotliAsyncInputStream::BrotliAsyncInputStream(AsyncInputStream& inner, + kj::Maybe windowBitsParam) + : inner(inner), windowBits(windowBitsParam.orDefault(_::KJ_BROTLI_MAX_DEC_WBITS)), + nextIn(nullptr), availableIn(0) { + KJ_REQUIRE(windowBits >= BROTLI_MIN_WINDOW_BITS && windowBits <= BROTLI_MAX_WINDOW_BITS, + "invalid brotli window size", windowBits); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); +} + +BrotliAsyncInputStream::~BrotliAsyncInputStream() noexcept(false) { + BrotliDecoderDestroyInstance(ctx); +} + +Promise BrotliAsyncInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return constPromise(); + + return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); +} + +Promise BrotliAsyncInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + // Ask for more input unless there is pending output + if (availableIn == 0 && !BrotliDecoderHasMoreOutput(ctx)) { + return inner.tryRead(buffer, 1, sizeof(buffer)) + .then([this,out,minBytes,maxBytes,alreadyRead](size_t amount) -> Promise { + if (amount == 0) { + if (!atValidEndpoint) { + return KJ_EXCEPTION(DISCONNECTED, "brotli compressed stream ended prematurely"); + } + return alreadyRead; + } else { + nextIn = buffer; + availableIn = amount; + return readImpl(out, minBytes, maxBytes, alreadyRead); + } + }); + } + + byte* nextOut = out; + size_t availableOut = maxBytes; + // Check window bits + if (firstInput && availableIn) { + firstInput = false; + int streamWbits = getBrotliWindowBits(nextIn[0]); + KJ_REQUIRE(streamWbits <= windowBits, + "brotli window size too big", (1 << streamWbits)); + } + BrotliDecoderResult result = BrotliDecoderDecompressStream( + ctx, &availableIn, &nextIn, &availableOut, &nextOut, nullptr); + KJ_REQUIRE(result != BROTLI_DECODER_RESULT_ERROR, "brotli decompression failed", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(ctx))); + + atValidEndpoint = result == BROTLI_DECODER_RESULT_SUCCESS; + if (atValidEndpoint && availableIn > 0) { + // There's more data available. Assume start of new content. + // Not sure if we actually want this, but there is limited potential for breakage as arbitrary + // trailing data should still be rejected. Unfortunately this is kind of clunky as brotli does + // not support resetting an instance. + BrotliDecoderDestroyInstance(ctx); + ctx = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + KJ_REQUIRE(ctx, "brotli state allocation failed"); + firstInput = true; + } + + size_t n = maxBytes - availableOut; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } +} + +// ======================================================================================= + +BrotliAsyncOutputStream::BrotliAsyncOutputStream(AsyncOutputStream& inner, int compressionLevel, + int windowBits) + : inner(inner), ctx(compressionLevel, windowBits) {} + +BrotliAsyncOutputStream::BrotliAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS), + int windowBits) + : inner(inner), ctx(nullptr, windowBits) {} + +Promise BrotliAsyncOutputStream::write(const void* in, size_t size) { + ctx.setInput(in, size); + return pump(BROTLI_OPERATION_PROCESS); +} + +Promise BrotliAsyncOutputStream::write(ArrayPtr> pieces) { + if (pieces.size() == 0) return kj::READY_NOW; + return write(pieces[0].begin(), pieces[0].size()) + .then([this,pieces]() { + return write(pieces.slice(1, pieces.size())); + }); +} + +kj::Promise BrotliAsyncOutputStream::pump(BrotliEncoderOperation flush) { + auto result = ctx.pumpOnce(flush); + auto ok = get<0>(result); + auto chunk = get<1>(result); + + if (chunk.size() == 0) { + if (ok) { + return pump(flush); + } else { + return kj::READY_NOW; + } + } else { + auto promise = inner.write(chunk.begin(), chunk.size()); + if (ok) { + promise = promise.then([this, flush]() { return pump(flush); }); + } + return promise; + } +} + +} // namespace kj + +#endif // KJ_HAS_BROTLI diff --git a/c++/src/kj/compat/brotli.h b/c++/src/kj/compat/brotli.h new file mode 100644 index 0000000000..3fd2181b5c --- /dev/null +++ b/c++/src/kj/compat/brotli.h @@ -0,0 +1,190 @@ +// Copyright (c) 2023 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include +#include +#include + +KJ_BEGIN_HEADER + +namespace kj { + +// level 5 should offer a good default tradeoff based on concerns about being slower than gzip at +// e.g. level 6 and about compressing worse than gzip at lower levels. Note that +// BROTLI_DEFAULT_QUALITY is set to the maximum level of 11 – way too slow for on-the-fly +// compression. +constexpr size_t KJ_BROTLI_DEFAULT_QUALITY = 5; + +namespace _ { // private +// Use a window size of (1 << 19) = 512K by default. Higher values improve compression on longer +// streams but increase memory usage. +constexpr size_t KJ_BROTLI_DEFAULT_WBITS = 19; + +// Maximum window size for streams to be decompressed, streams with larger windows are rejected. +// This is currently set to the maximum window size of 16MB, so all RFC 7932-compliant brotli +// streams will be accepted. For applications where memory usage is a concern, using +// BROTLI_DEFAULT_WINDOW (equivalent to 4MB window) is recommended instead as larger window sizes +// are rarely useful in a web context. +constexpr size_t KJ_BROTLI_MAX_DEC_WBITS = BROTLI_MAX_WINDOW_BITS; + +// Use an output buffer size of 8K, larger sizes did not seem to significantly improve performance, +// perhaps due to brotli's internal output buffer. +constexpr size_t KJ_BROTLI_BUF_SIZE = 8192; + +class BrotliOutputContext final { +public: + BrotliOutputContext(kj::Maybe compressionLevel, kj::Maybe windowBits = nullptr); + ~BrotliOutputContext() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliOutputContext); + + void setInput(const void* in, size_t size); + kj::Tuple> pumpOnce(BrotliEncoderOperation flush); + // Flush the stream. Parameter is ignored for decoding as brotli only uses an operation parameter + // during encoding. + +private: + int windowBits; + const byte* nextIn; + size_t availableIn; + bool firstInput = true; + + kj::OneOf ctx; + byte buffer[_::KJ_BROTLI_BUF_SIZE]; +}; + +} // namespace _ (private) + +class BrotliInputStream final: public InputStream { +public: + BrotliInputStream(InputStream& inner, kj::Maybe windowBits = nullptr); + ~BrotliInputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliInputStream); + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + InputStream& inner; + BrotliDecoderState* ctx; + int windowBits; + bool atValidEndpoint = false; + + byte buffer[_::KJ_BROTLI_BUF_SIZE]; + + const byte* nextIn; + size_t availableIn; + bool firstInput = true; + + size_t readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class BrotliOutputStream final: public OutputStream { +public: + enum { DECOMPRESS }; + + // Order of arguments is not ideal, but allows us to specify the window size if needed while + // remaining compatible with the gzip API. + BrotliOutputStream(OutputStream& inner, int compressionLevel = KJ_BROTLI_DEFAULT_QUALITY, + int windowBits = _::KJ_BROTLI_DEFAULT_WBITS); + BrotliOutputStream(OutputStream& inner, decltype(DECOMPRESS), + int windowBits = _::KJ_BROTLI_MAX_DEC_WBITS); + ~BrotliOutputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliOutputStream); + + void write(const void* buffer, size_t size) override; + using OutputStream::write; + + inline void flush() { + // brotli decoder does not use this parameter, but automatically flushes as much as it can. + pump(BROTLI_OPERATION_FLUSH); + } + +private: + OutputStream& inner; + _::BrotliOutputContext ctx; + + void pump(BrotliEncoderOperation flush); +}; + +class BrotliAsyncInputStream final: public AsyncInputStream { +public: + BrotliAsyncInputStream(AsyncInputStream& inner, kj::Maybe windowBits = nullptr); + ~BrotliAsyncInputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(BrotliAsyncInputStream); + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + AsyncInputStream& inner; + BrotliDecoderState* ctx; + int windowBits; + bool atValidEndpoint = false; + + byte buffer[_::KJ_BROTLI_BUF_SIZE]; + const byte* nextIn; + size_t availableIn; + bool firstInput = true; + + Promise readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class BrotliAsyncOutputStream final: public AsyncOutputStream { +public: + enum { DECOMPRESS }; + + BrotliAsyncOutputStream(AsyncOutputStream& inner, + int compressionLevel = KJ_BROTLI_DEFAULT_QUALITY, + int windowBits = _::KJ_BROTLI_DEFAULT_WBITS); + BrotliAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS), + int windowBits = _::KJ_BROTLI_MAX_DEC_WBITS); + KJ_DISALLOW_COPY_AND_MOVE(BrotliAsyncOutputStream); + + Promise write(const void* buffer, size_t size) override; + Promise write(ArrayPtr> pieces) override; + + Promise whenWriteDisconnected() override { return inner.whenWriteDisconnected(); } + + inline Promise flush() { + // brotli decoder does not use this parameter, but automatically flushes as much as it can. + return pump(BROTLI_OPERATION_FLUSH); + } + // Call if you need to flush a stream at an arbitrary data point. + + Promise end() { + return pump(BROTLI_OPERATION_FINISH); + } + // Must call to flush and finish the stream, since some data may be buffered. + // + // TODO(cleanup): This should be a virtual method on AsyncOutputStream. + +private: + AsyncOutputStream& inner; + _::BrotliOutputContext ctx; + + kj::Promise pump(BrotliEncoderOperation flush); +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/compat/gtest.h b/c++/src/kj/compat/gtest.h index 016dbdfac3..4db0535c35 100644 --- a/c++/src/kj/compat/gtest.h +++ b/c++/src/kj/compat/gtest.h @@ -19,8 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_COMPAT_GTEST_H_ -#define KJ_COMPAT_GTEST_H_ +#pragma once // This file defines compatibility macros converting Google Test tests into KJ tests. // // This is only intended to cover the most common functionality. Many tests will likely need @@ -30,7 +29,10 @@ // - Test fixtures are not supported. Allocate your "test fixture" on the stack instead. Do setup // in the constructor, teardown in the destructor. -#include "../test.h" +#include +#include // work-around macro conflict with `ERROR` + +KJ_BEGIN_HEADER namespace kj { @@ -119,4 +121,4 @@ class AddFailureAdapter { } // namespace kj -#endif // KJ_COMPAT_GTEST_H_ +KJ_END_HEADER diff --git a/c++/src/kj/compat/gzip-test.c++ b/c++/src/kj/compat/gzip-test.c++ new file mode 100644 index 0000000000..09bb0c05be --- /dev/null +++ b/c++/src/kj/compat/gzip-test.c++ @@ -0,0 +1,370 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_ZLIB + +#include "gzip.h" +#include +#include +#include + +namespace kj { +namespace { + +static const byte FOOBAR_GZIP[] = { + 0x1F, 0x8B, 0x08, 0x00, 0xF9, 0x05, 0xB7, 0x59, + 0x00, 0x03, 0x4B, 0xCB, 0xCF, 0x4F, 0x4A, 0x2C, + 0x02, 0x00, 0x95, 0x1F, 0xF6, 0x9E, 0x06, 0x00, + 0x00, 0x00, +}; + +class MockInputStream: public InputStream { +public: + MockInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +class MockAsyncInputStream: public AsyncInputStream { +public: + MockAsyncInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +class MockOutputStream: public OutputStream { +public: + kj::Vector bytes; + + kj::String decompress() { + MockInputStream rawInput(bytes, kj::maxValue); + GzipInputStream gzip(rawInput); + return gzip.readAllText(); + } + + void write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast(buffer), size)); + } + void write(ArrayPtr> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + } +}; + +class MockAsyncOutputStream: public AsyncOutputStream { +public: + kj::Vector bytes; + + kj::String decompress(WaitScope& ws) { + MockAsyncInputStream rawInput(bytes, kj::maxValue); + GzipAsyncInputStream gzip(rawInput); + return gzip.readAllText().wait(ws); + } + + Promise write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast(buffer), size)); + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + return kj::READY_NOW; + } + + Promise whenWriteDisconnected() override { KJ_UNIMPLEMENTED("not used"); } +}; + +KJ_TEST("gzip decompression") { + // Normal read. + { + MockInputStream rawInput(FOOBAR_GZIP, kj::maxValue); + GzipInputStream gzip(rawInput); + KJ_EXPECT(gzip.readAllText() == "foobar"); + } + + // Force read one byte at a time. + { + MockInputStream rawInput(FOOBAR_GZIP, 1); + GzipInputStream gzip(rawInput); + KJ_EXPECT(gzip.readAllText() == "foobar"); + } + + // Read truncated input. + { + MockInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue); + GzipInputStream gzip(rawInput); + + char text[16]; + size_t n = gzip.tryRead(text, 1, sizeof(text)); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("gzip compressed stream ended prematurely", + gzip.tryRead(text, 1, sizeof(text))); + } + + // Read concatenated input. + { + Vector bytes; + bytes.addAll(ArrayPtr(FOOBAR_GZIP)); + bytes.addAll(ArrayPtr(FOOBAR_GZIP)); + MockInputStream rawInput(bytes, kj::maxValue); + GzipInputStream gzip(rawInput); + + KJ_EXPECT(gzip.readAllText() == "foobarfoobar"); + } +} + +KJ_TEST("async gzip decompression") { + auto io = setupAsyncIo(); + + // Normal read. + { + MockAsyncInputStream rawInput(FOOBAR_GZIP, kj::maxValue); + GzipAsyncInputStream gzip(rawInput); + KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobar"); + } + + // Force read one byte at a time. + { + MockAsyncInputStream rawInput(FOOBAR_GZIP, 1); + GzipAsyncInputStream gzip(rawInput); + KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobar"); + } + + // Read truncated input. + { + MockAsyncInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue); + GzipAsyncInputStream gzip(rawInput); + + char text[16]; + size_t n = gzip.tryRead(text, 1, sizeof(text)).wait(io.waitScope); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("gzip compressed stream ended prematurely", + gzip.tryRead(text, 1, sizeof(text)).wait(io.waitScope)); + } + + // Read concatenated input. + { + Vector bytes; + bytes.addAll(ArrayPtr(FOOBAR_GZIP)); + bytes.addAll(ArrayPtr(FOOBAR_GZIP)); + MockAsyncInputStream rawInput(bytes, kj::maxValue); + GzipAsyncInputStream gzip(rawInput); + + KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobarfoobar"); + } + + // Decompress using an output stream. + { + MockAsyncOutputStream rawOutput; + GzipAsyncOutputStream gzip(rawOutput, GzipAsyncOutputStream::DECOMPRESS); + + auto mid = sizeof(FOOBAR_GZIP) / 2; + gzip.write(FOOBAR_GZIP, mid).wait(io.waitScope); + auto str1 = kj::heapString(rawOutput.bytes.asPtr().asChars()); + KJ_EXPECT(str1 == "fo", str1); + + gzip.write(FOOBAR_GZIP + mid, sizeof(FOOBAR_GZIP) - mid).wait(io.waitScope); + auto str2 = kj::heapString(rawOutput.bytes.asPtr().asChars()); + KJ_EXPECT(str2 == "foobar", str2); + + gzip.end().wait(io.waitScope); + } +} + +KJ_TEST("gzip compression") { + // Normal write. + { + MockOutputStream rawOutput; + { + GzipOutputStream gzip(rawOutput); + gzip.write("foobar", 6); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Multi-part write. + { + MockOutputStream rawOutput; + { + GzipOutputStream gzip(rawOutput); + gzip.write("foo", 3); + gzip.write("bar", 3); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Array-of-arrays write. + { + MockOutputStream rawOutput; + + { + GzipOutputStream gzip(rawOutput); + + ArrayPtr pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + gzip.write(pieces); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } +} + +KJ_TEST("gzip huge round trip") { + auto bytes = heapArray(65536); + for (auto& b: bytes) { + b = rand(); + } + + MockOutputStream rawOutput; + { + GzipOutputStream gzipOut(rawOutput); + gzipOut.write(bytes.begin(), bytes.size()); + } + + MockInputStream rawInput(rawOutput.bytes, kj::maxValue); + GzipInputStream gzipIn(rawInput); + auto decompressed = gzipIn.readAllBytes(); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +KJ_TEST("async gzip compression") { + auto io = setupAsyncIo(); + + // Normal write. + { + MockAsyncOutputStream rawOutput; + GzipAsyncOutputStream gzip(rawOutput); + gzip.write("foobar", 6).wait(io.waitScope); + gzip.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } + + // Multi-part write. + { + MockAsyncOutputStream rawOutput; + GzipAsyncOutputStream gzip(rawOutput); + + gzip.write("foo", 3).wait(io.waitScope); + auto prevSize = rawOutput.bytes.size(); + + gzip.write("bar", 3).wait(io.waitScope); + auto curSize = rawOutput.bytes.size(); + KJ_EXPECT(prevSize == curSize, prevSize, curSize); + + gzip.flush().wait(io.waitScope); + curSize = rawOutput.bytes.size(); + KJ_EXPECT(prevSize < curSize, prevSize, curSize); + + gzip.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } + + // Array-of-arrays write. + { + MockAsyncOutputStream rawOutput; + GzipAsyncOutputStream gzip(rawOutput); + + ArrayPtr pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + gzip.write(pieces).wait(io.waitScope); + gzip.end().wait(io.waitScope); + + KJ_EXPECT(rawOutput.decompress(io.waitScope) == "foobar"); + } +} + +KJ_TEST("async gzip huge round trip") { + auto io = setupAsyncIo(); + + auto bytes = heapArray(65536); + for (auto& b: bytes) { + b = rand(); + } + + MockAsyncOutputStream rawOutput; + GzipAsyncOutputStream gzipOut(rawOutput); + gzipOut.write(bytes.begin(), bytes.size()).wait(io.waitScope); + gzipOut.end().wait(io.waitScope); + + MockAsyncInputStream rawInput(rawOutput.bytes, kj::maxValue); + GzipAsyncInputStream gzipIn(rawInput); + auto decompressed = gzipIn.readAllBytes().wait(io.waitScope); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +} // namespace +} // namespace kj + +#endif // KJ_HAS_ZLIB diff --git a/c++/src/kj/compat/gzip.c++ b/c++/src/kj/compat/gzip.c++ new file mode 100644 index 0000000000..a36cde774f --- /dev/null +++ b/c++/src/kj/compat/gzip.c++ @@ -0,0 +1,283 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_ZLIB + +#include "gzip.h" +#include + +namespace kj { + +namespace _ { // private + +GzipOutputContext::GzipOutputContext(kj::Maybe compressionLevel) { + int initResult; + + KJ_IF_MAYBE(level, compressionLevel) { + compressing = true; + initResult = + deflateInit2(&ctx, *level, Z_DEFLATED, + 15 + 16, // windowBits = 15 (maximum) + magic value 16 to ask for gzip. + 8, // memLevel = 8 (the default) + Z_DEFAULT_STRATEGY); + } else { + compressing = false; + initResult = inflateInit2(&ctx, 15 + 16); + } + + if (initResult != Z_OK) { + fail(initResult); + } +} + +GzipOutputContext::~GzipOutputContext() noexcept(false) { + compressing ? deflateEnd(&ctx) : inflateEnd(&ctx); +} + +void GzipOutputContext::setInput(const void* in, size_t size) { + ctx.next_in = const_cast(reinterpret_cast(in)); + ctx.avail_in = size; +} + +kj::Tuple> GzipOutputContext::pumpOnce(int flush) { + ctx.next_out = buffer; + ctx.avail_out = sizeof(buffer); + + auto result = compressing ? deflate(&ctx, flush) : inflate(&ctx, flush); + if (result != Z_OK && result != Z_BUF_ERROR && result != Z_STREAM_END) { + fail(result); + } + + // - Z_STREAM_END means we have finished the stream successfully. + // - Z_BUF_ERROR means we didn't have any more input to process + // (but still have to make a call to write to potentially flush data). + return kj::tuple(result == Z_OK, kj::arrayPtr(buffer, sizeof(buffer) - ctx.avail_out)); +} + +void GzipOutputContext::fail(int result) { + auto header = compressing ? "gzip compression failed" : "gzip decompression failed"; + if (ctx.msg == nullptr) { + KJ_FAIL_REQUIRE(header, result); + } else { + KJ_FAIL_REQUIRE(header, ctx.msg); + } +} + +} // namespace _ (private) + +GzipInputStream::GzipInputStream(InputStream& inner) + : inner(inner) { + // windowBits = 15 (maximum) + magic value 16 to ask for gzip. + KJ_ASSERT(inflateInit2(&ctx, 15 + 16) == Z_OK); +} + +GzipInputStream::~GzipInputStream() noexcept(false) { + inflateEnd(&ctx); +} + +size_t GzipInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return size_t(0); + + return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); +} + +size_t GzipInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + if (ctx.avail_in == 0) { + size_t amount = inner.tryRead(buffer, 1, sizeof(buffer)); + // Note: This check would reject valid streams with a high compression ratio if zlib were to + // read in the entire input data, getting more decompressed data than fits in the out buffer + // and subsequently fill the output buffer and internally store some pending data. It turns + // out that zlib does not maintain pending output during decompression and this is not + // possible, but this may be a concern when implementing support for other algorithms as e.g. + // brotli's reference implementation maintains a decompression output buffer. + if (amount == 0) { + if (!atValidEndpoint) { + KJ_FAIL_REQUIRE("gzip compressed stream ended prematurely"); + } + return alreadyRead; + } else { + ctx.next_in = buffer; + ctx.avail_in = amount; + } + } + + ctx.next_out = out; + ctx.avail_out = maxBytes; + + auto inflateResult = inflate(&ctx, Z_NO_FLUSH); + atValidEndpoint = inflateResult == Z_STREAM_END; + if (inflateResult == Z_OK || inflateResult == Z_STREAM_END) { + if (atValidEndpoint && ctx.avail_in > 0) { + // There's more data available. Assume start of new content. + KJ_ASSERT(inflateReset(&ctx) == Z_OK); + } + + size_t n = maxBytes - ctx.avail_out; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } + } else { + if (ctx.msg == nullptr) { + KJ_FAIL_REQUIRE("gzip decompression failed", inflateResult); + } else { + KJ_FAIL_REQUIRE("gzip decompression failed", ctx.msg); + } + } +} + +// ======================================================================================= + +GzipOutputStream::GzipOutputStream(OutputStream& inner, int compressionLevel) + : inner(inner), ctx(compressionLevel) {} + +GzipOutputStream::GzipOutputStream(OutputStream& inner, decltype(DECOMPRESS)) + : inner(inner), ctx(nullptr) {} + +GzipOutputStream::~GzipOutputStream() noexcept(false) { + pump(Z_FINISH); +} + +void GzipOutputStream::write(const void* in, size_t size) { + ctx.setInput(in, size); + pump(Z_NO_FLUSH); +} + +void GzipOutputStream::pump(int flush) { + bool ok; + do { + auto result = ctx.pumpOnce(flush); + ok = get<0>(result); + auto chunk = get<1>(result); + if (chunk.size() > 0) { + inner.write(chunk.begin(), chunk.size()); + } + } while (ok); +} + +// ======================================================================================= + +GzipAsyncInputStream::GzipAsyncInputStream(AsyncInputStream& inner) + : inner(inner) { + // windowBits = 15 (maximum) + magic value 16 to ask for gzip. + KJ_ASSERT(inflateInit2(&ctx, 15 + 16) == Z_OK); +} + +GzipAsyncInputStream::~GzipAsyncInputStream() noexcept(false) { + inflateEnd(&ctx); +} + +Promise GzipAsyncInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return constPromise(); + + return readImpl(reinterpret_cast(out), minBytes, maxBytes, 0); +} + +Promise GzipAsyncInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + if (ctx.avail_in == 0) { + return inner.tryRead(buffer, 1, sizeof(buffer)) + .then([this,out,minBytes,maxBytes,alreadyRead](size_t amount) -> Promise { + if (amount == 0) { + if (!atValidEndpoint) { + return KJ_EXCEPTION(DISCONNECTED, "gzip compressed stream ended prematurely"); + } + return alreadyRead; + } else { + ctx.next_in = buffer; + ctx.avail_in = amount; + return readImpl(out, minBytes, maxBytes, alreadyRead); + } + }); + } + + ctx.next_out = out; + ctx.avail_out = maxBytes; + + auto inflateResult = inflate(&ctx, Z_NO_FLUSH); + atValidEndpoint = inflateResult == Z_STREAM_END; + if (inflateResult == Z_OK || inflateResult == Z_STREAM_END) { + if (atValidEndpoint && ctx.avail_in > 0) { + // There's more data available. Assume start of new content. + KJ_ASSERT(inflateReset(&ctx) == Z_OK); + } + + size_t n = maxBytes - ctx.avail_out; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } + } else { + if (ctx.msg == nullptr) { + KJ_FAIL_REQUIRE("gzip decompression failed", inflateResult); + } else { + KJ_FAIL_REQUIRE("gzip decompression failed", ctx.msg); + } + } +} + +// ======================================================================================= + +GzipAsyncOutputStream::GzipAsyncOutputStream(AsyncOutputStream& inner, int compressionLevel) + : inner(inner), ctx(compressionLevel) {} + +GzipAsyncOutputStream::GzipAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS)) + : inner(inner), ctx(nullptr) {} + +Promise GzipAsyncOutputStream::write(const void* in, size_t size) { + ctx.setInput(in, size); + return pump(Z_NO_FLUSH); +} + +Promise GzipAsyncOutputStream::write(ArrayPtr> pieces) { + if (pieces.size() == 0) return kj::READY_NOW; + return write(pieces[0].begin(), pieces[0].size()) + .then([this,pieces]() { + return write(pieces.slice(1, pieces.size())); + }); +} + +kj::Promise GzipAsyncOutputStream::pump(int flush) { + auto result = ctx.pumpOnce(flush); + auto ok = get<0>(result); + auto chunk = get<1>(result); + + if (chunk.size() == 0) { + if (ok) { + return pump(flush); + } else { + return kj::READY_NOW; + } + } else { + auto promise = inner.write(chunk.begin(), chunk.size()); + if (ok) { + promise = promise.then([this, flush]() { return pump(flush); }); + } + return promise; + } +} + +} // namespace kj + +#endif // KJ_HAS_ZLIB diff --git a/c++/src/kj/compat/gzip.h b/c++/src/kj/compat/gzip.h new file mode 100644 index 0000000000..37b4961fed --- /dev/null +++ b/c++/src/kj/compat/gzip.h @@ -0,0 +1,148 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include + +KJ_BEGIN_HEADER + +namespace kj { + +namespace _ { // private + +constexpr size_t KJ_GZ_BUF_SIZE = 4096; + +class GzipOutputContext final { +public: + GzipOutputContext(kj::Maybe compressionLevel); + ~GzipOutputContext() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(GzipOutputContext); + + void setInput(const void* in, size_t size); + kj::Tuple> pumpOnce(int flush); + +private: + bool compressing; + z_stream ctx = {}; + byte buffer[_::KJ_GZ_BUF_SIZE]; + + [[noreturn]] void fail(int result); +}; + +} // namespace _ (private) + +class GzipInputStream final: public InputStream { +public: + GzipInputStream(InputStream& inner); + ~GzipInputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(GzipInputStream); + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + InputStream& inner; + z_stream ctx = {}; + bool atValidEndpoint = false; + + byte buffer[_::KJ_GZ_BUF_SIZE]; + + size_t readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class GzipOutputStream final: public OutputStream { +public: + enum { DECOMPRESS }; + + GzipOutputStream(OutputStream& inner, int compressionLevel = Z_DEFAULT_COMPRESSION); + GzipOutputStream(OutputStream& inner, decltype(DECOMPRESS)); + ~GzipOutputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(GzipOutputStream); + + void write(const void* buffer, size_t size) override; + using OutputStream::write; + + inline void flush() { + pump(Z_SYNC_FLUSH); + } + +private: + OutputStream& inner; + _::GzipOutputContext ctx; + + void pump(int flush); +}; + +class GzipAsyncInputStream final: public AsyncInputStream { +public: + GzipAsyncInputStream(AsyncInputStream& inner); + ~GzipAsyncInputStream() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(GzipAsyncInputStream); + + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + AsyncInputStream& inner; + z_stream ctx = {}; + bool atValidEndpoint = false; + + byte buffer[_::KJ_GZ_BUF_SIZE]; + + Promise readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class GzipAsyncOutputStream final: public AsyncOutputStream { +public: + enum { DECOMPRESS }; + + GzipAsyncOutputStream(AsyncOutputStream& inner, int compressionLevel = Z_DEFAULT_COMPRESSION); + GzipAsyncOutputStream(AsyncOutputStream& inner, decltype(DECOMPRESS)); + KJ_DISALLOW_COPY_AND_MOVE(GzipAsyncOutputStream); + + Promise write(const void* buffer, size_t size) override; + Promise write(ArrayPtr> pieces) override; + + Promise whenWriteDisconnected() override { return inner.whenWriteDisconnected(); } + + inline Promise flush() { + return pump(Z_SYNC_FLUSH); + } + // Call if you need to flush a stream at an arbitrary data point. + + Promise end() { + return pump(Z_FINISH); + } + // Must call to flush and finish the stream, since some data may be buffered. + // + // TODO(cleanup): This should be a virtual method on AsyncOutputStream. + +private: + AsyncOutputStream& inner; + _::GzipOutputContext ctx; + + kj::Promise pump(int flush); +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/compat/http-socketpair-test.c++ b/c++/src/kj/compat/http-socketpair-test.c++ new file mode 100644 index 0000000000..67c53b79b2 --- /dev/null +++ b/c++/src/kj/compat/http-socketpair-test.c++ @@ -0,0 +1,25 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Run http-test, but use real OS socketpairs to connect rather than using in-process pipes. +// This is essentially an integration test between KJ HTTP and KJ OS socket handling. +#define KJ_HTTP_TEST_USE_OS_PIPE 1 +#include "http-test.c++" diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 6fd75232bb..f10ff8d156 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -19,19 +19,55 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#define KJ_TESTING_KJ 1 + #include "http.h" #include #include +#include +#include #include +#if KJ_HTTP_TEST_USE_OS_PIPE +// Run the test using OS-level socketpairs. (See http-socketpair-test.c++.) +#define KJ_HTTP_TEST_SETUP_IO \ + auto io = kj::setupAsyncIo(); \ + auto& waitScope KJ_UNUSED = io.waitScope +#define KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR \ + auto listener = io.provider->getNetwork().parseAddress("localhost", 0) \ + .wait(waitScope)->listen(); \ + auto addr = io.provider->getNetwork().parseAddress("localhost", listener->getPort()) \ + .wait(waitScope) +#define KJ_HTTP_TEST_CREATE_2PIPE \ + io.provider->newTwoWayPipe() +#else +// Run the test using in-process two-way pipes. +#define KJ_HTTP_TEST_SETUP_IO \ + kj::EventLoop eventLoop; \ + kj::WaitScope waitScope(eventLoop) +#define KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR \ + auto capPipe = newCapabilityPipe(); \ + auto listener = kj::heap(*capPipe.ends[0]); \ + auto addr = kj::heap(nullptr, *capPipe.ends[1]) +#define KJ_HTTP_TEST_CREATE_2PIPE \ + kj::newTwoWayPipe() +#endif + namespace kj { namespace { KJ_TEST("HttpMethod parse / stringify") { #define TRY(name) \ KJ_EXPECT(kj::str(HttpMethod::name) == #name); \ - KJ_IF_MAYBE(parsed, tryParseHttpMethod(#name)) { \ - KJ_EXPECT(*parsed == HttpMethod::name); \ + KJ_IF_MAYBE(parsed, tryParseHttpMethodAllowingConnect(#name)) { \ + KJ_SWITCH_ONEOF(*parsed) { \ + KJ_CASE_ONEOF(method, HttpMethod) { \ + KJ_EXPECT(method == HttpMethod::name); \ + } \ + KJ_CASE_ONEOF(method, HttpConnectMethod) { \ + KJ_FAIL_EXPECT("http method parsed as CONNECT", #name); \ + } \ + } \ } else { \ KJ_FAIL_EXPECT("couldn't parse \"" #name "\" as HttpMethod"); \ } @@ -45,6 +81,10 @@ KJ_TEST("HttpMethod parse / stringify") { KJ_EXPECT(tryParseHttpMethod("GE") == nullptr); KJ_EXPECT(tryParseHttpMethod("GET ") == nullptr); KJ_EXPECT(tryParseHttpMethod("get") == nullptr); + + KJ_EXPECT(KJ_ASSERT_NONNULL(tryParseHttpMethodAllowingConnect("CONNECT")) + .is()); + KJ_EXPECT(tryParseHttpMethod("connect") == nullptr); } KJ_TEST("HttpHeaderTable") { @@ -107,8 +147,9 @@ KJ_TEST("HttpHeaders::parseRequest") { "Content-Length: 123\r\n" "DATE: early\r\n" "other-Header: yep\r\n" + "with.dots: sure\r\n" "\r\n"); - auto result = KJ_ASSERT_NONNULL(headers.tryParseRequest(text.asArray())); + auto result = headers.tryParseRequest(text.asArray()).get(); KJ_EXPECT(result.method == HttpMethod::POST); KJ_EXPECT(result.url == "/some/path"); @@ -117,26 +158,29 @@ KJ_TEST("HttpHeaders::parseRequest") { KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz"); KJ_EXPECT(headers.get(bazQux) == nullptr); KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr); - KJ_EXPECT(result.connectionHeaders.contentLength == "123"); - KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123"); + KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr); std::map unpackedHeaders; headers.forEach([&](kj::StringPtr name, kj::StringPtr value) { KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second); }); - KJ_EXPECT(unpackedHeaders.size() == 4); + KJ_EXPECT(unpackedHeaders.size() == 6); + KJ_EXPECT(unpackedHeaders["Content-Length"] == "123"); KJ_EXPECT(unpackedHeaders["Host"] == "example.com"); KJ_EXPECT(unpackedHeaders["Date"] == "early"); KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz"); KJ_EXPECT(unpackedHeaders["other-Header"] == "yep"); + KJ_EXPECT(unpackedHeaders["with.dots"] == "sure"); - KJ_EXPECT(headers.serializeRequest(result.method, result.url, result.connectionHeaders) == + KJ_EXPECT(headers.serializeRequest(result.method, result.url) == "POST /some/path HTTP/1.1\r\n" "Content-Length: 123\r\n" "Host: example.com\r\n" "Date: early\r\n" "Foo-Bar: Baz\r\n" "other-Header: yep\r\n" + "with.dots: sure\r\n" "\r\n"); } @@ -157,7 +201,7 @@ KJ_TEST("HttpHeaders::parseResponse") { "DATE: early\r\n" "other-Header: yep\r\n" "\r\n"); - auto result = KJ_ASSERT_NONNULL(headers.tryParseResponse(text.asArray())); + auto result = headers.tryParseResponse(text.asArray()).get(); KJ_EXPECT(result.statusCode == 418); KJ_EXPECT(result.statusText == "I'm a teapot"); @@ -166,21 +210,22 @@ KJ_TEST("HttpHeaders::parseResponse") { KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz"); KJ_EXPECT(headers.get(bazQux) == nullptr); KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr); - KJ_EXPECT(result.connectionHeaders.contentLength == "123"); - KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123"); + KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr); std::map unpackedHeaders; headers.forEach([&](kj::StringPtr name, kj::StringPtr value) { KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second); }); - KJ_EXPECT(unpackedHeaders.size() == 4); + KJ_EXPECT(unpackedHeaders.size() == 5); + KJ_EXPECT(unpackedHeaders["Content-Length"] == "123"); KJ_EXPECT(unpackedHeaders["Host"] == "example.com"); KJ_EXPECT(unpackedHeaders["Date"] == "early"); KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz"); KJ_EXPECT(unpackedHeaders["other-Header"] == "yep"); KJ_EXPECT(headers.serializeResponse( - result.statusCode, result.statusText, result.connectionHeaders) == + result.statusCode, result.statusText) == "HTTP/1.1 418 I'm a teapot\r\n" "Content-Length: 123\r\n" "Host: example.com\r\n" @@ -195,40 +240,160 @@ KJ_TEST("HttpHeaders parse invalid") { HttpHeaders headers(*table); // NUL byte in request. - KJ_EXPECT(headers.tryParseRequest(kj::heapString( - "POST \0 /some/path \t HTTP/1.1\r\n" - "Foo-BaR: Baz\r\n" - "Host: example.com\r\n" - "DATE: early\r\n" - "other-Header: yep\r\n" - "\r\n")) == nullptr); + { + auto input = kj::heapString( + "POST \0 /some/path \t HTTP/1.1\r\n" + "Foo-BaR: Baz\r\n" + "Host: example.com\r\n" + "DATE: early\r\n" + "other-Header: yep\r\n" + "\r\n"); + + auto protocolError = headers.tryParseRequest(input).get(); + + KJ_EXPECT(protocolError.description == "Request headers have no terminal newline.", + protocolError.description); + KJ_EXPECT(protocolError.rawContent.asChars() == input); + } // Control character in header name. - KJ_EXPECT(headers.tryParseRequest(kj::heapString( - "POST /some/path \t HTTP/1.1\r\n" - "Foo-BaR: Baz\r\n" - "Cont\001ent-Length: 123\r\n" - "DATE: early\r\n" - "other-Header: yep\r\n" - "\r\n")) == nullptr); + { + auto input = kj::heapString( + "POST /some/path \t HTTP/1.1\r\n" + "Foo-BaR: Baz\r\n" + "Cont\001ent-Length: 123\r\n" + "DATE: early\r\n" + "other-Header: yep\r\n" + "\r\n"); + + auto protocolError = headers.tryParseRequest(input).get(); + + KJ_EXPECT(protocolError.description == "The headers sent by your client are not valid.", + protocolError.description); + KJ_EXPECT(protocolError.rawContent.asChars() == input); + } // Separator character in header name. - KJ_EXPECT(headers.tryParseRequest(kj::heapString( - "POST /some/path \t HTTP/1.1\r\n" - "Foo-BaR: Baz\r\n" - "Host: example.com\r\n" - "DATE/: early\r\n" - "other-Header: yep\r\n" - "\r\n")) == nullptr); + { + auto input = kj::heapString( + "POST /some/path \t HTTP/1.1\r\n" + "Foo-BaR: Baz\r\n" + "Host: example.com\r\n" + "DATE/: early\r\n" + "other-Header: yep\r\n" + "\r\n"); + + auto protocolError = headers.tryParseRequest(input).get(); + + KJ_EXPECT(protocolError.description == "The headers sent by your client are not valid.", + protocolError.description); + KJ_EXPECT(protocolError.rawContent.asChars() == input); + } // Response status code not numeric. - KJ_EXPECT(headers.tryParseResponse(kj::heapString( + { + auto input = kj::heapString( "HTTP/1.1\t\t abc\t I'm a teapot\r\n" "Foo-BaR: Baz\r\n" "Host: example.com\r\n" "DATE: early\r\n" "other-Header: yep\r\n" - "\r\n")) == nullptr); + "\r\n"); + + auto protocolError = headers.tryParseRequest(input).get(); + + KJ_EXPECT(protocolError.description == "Unrecognized request method.", + protocolError.description); + KJ_EXPECT(protocolError.rawContent.asChars() == input); + } +} + +KJ_TEST("HttpHeaders require valid HttpHeaderTable") { + const auto ERROR_MESSAGE = + "HttpHeaders object was constructed from HttpHeaderTable " + "that wasn't fully built yet at the time of construction"_kj; + + { + // A tabula rasa is valid. + HttpHeaderTable table; + KJ_REQUIRE(table.isReady()); + + HttpHeaders headers(table); + } + + { + // A future table is not valid. + HttpHeaderTable::Builder builder; + + auto& futureTable = builder.getFutureTable(); + KJ_REQUIRE(!futureTable.isReady()); + + auto makeHeadersThenBuild = [&]() { + HttpHeaders headers(futureTable); + auto table = builder.build(); + }; + KJ_EXPECT_THROW_MESSAGE(ERROR_MESSAGE, makeHeadersThenBuild()); + } + + { + // A well built table is valid. + HttpHeaderTable::Builder builder; + + auto& futureTable = builder.getFutureTable(); + KJ_REQUIRE(!futureTable.isReady()); + + auto ownedTable = builder.build(); + KJ_REQUIRE(futureTable.isReady()); + KJ_REQUIRE(ownedTable->isReady()); + + HttpHeaders headers(futureTable); + } +} + +KJ_TEST("HttpHeaders validation") { + auto table = HttpHeaderTable::Builder().build(); + HttpHeaders headers(*table); + + headers.add("Valid-Name", "valid value"); + + // The HTTP RFC prohibits control characters, but browsers only prohibit \0, \r, and \n. KJ goes + // with the browsers for compatibility. + headers.add("Valid-Name", "valid\x01value"); + + // The HTTP RFC does not permit non-ASCII values. + // KJ chooses to interpret them as UTF-8, to avoid the need for any expensive conversion. + // Browsers apparently interpret them as LATIN-1. Applications can reinterpet these strings as + // LATIN-1 easily enough if they really need to. + headers.add("Valid-Name", u8"valid€value"); + + KJ_EXPECT_THROW_MESSAGE("invalid header name", headers.add("Invalid Name", "value")); + KJ_EXPECT_THROW_MESSAGE("invalid header name", headers.add("Invalid@Name", "value")); + + KJ_EXPECT_THROW_MESSAGE("invalid header value", headers.set(HttpHeaderId::HOST, "in\nvalid")); + KJ_EXPECT_THROW_MESSAGE("invalid header value", headers.add("Valid-Name", "in\nvalid")); +} + +KJ_TEST("HttpHeaders Set-Cookie handling") { + HttpHeaderTable::Builder builder; + auto hCookie = builder.add("Cookie"); + auto hSetCookie = builder.add("Set-Cookie"); + auto table = builder.build(); + + HttpHeaders headers(*table); + headers.set(hCookie, "Foo"); + headers.add("Cookie", "Bar"); + headers.add("Cookie", "Baz"); + headers.set(hSetCookie, "Foo"); + headers.add("Set-Cookie", "Bar"); + headers.add("Set-Cookie", "Baz"); + + auto text = headers.toString(); + KJ_EXPECT(text == + "Cookie: Foo, Bar, Baz\r\n" + "Set-Cookie: Foo\r\n" + "Set-Cookie: Bar\r\n" + "Set-Cookie: Baz\r\n" + "\r\n", text); } // ======================================================================================= @@ -261,6 +426,10 @@ public: return inner.tryPumpFrom(input, amount); } + Promise whenWriteDisconnected() override { + return inner.whenWriteDisconnected(); + } + void shutdownWrite() override { return inner.shutdownWrite(); } @@ -367,7 +536,7 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { auto buffer = kj::heapArray(expected.size()); auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); - return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array buffer, size_t amount) { + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { if (amount == 0) { KJ_FAIL_ASSERT("expected data never sent", expected); } @@ -378,11 +547,40 @@ kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { } return expectRead(in, expected.slice(amount)); - })); + }); +} + +kj::Promise expectRead(kj::AsyncInputStream& in, kj::ArrayPtr expected) { + if (expected.size() == 0) return kj::READY_NOW; + + auto buffer = kj::heapArray(expected.size()); + + auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { + if (amount == 0) { + KJ_FAIL_ASSERT("expected data never sent", expected); + } + + auto actual = buffer.slice(0, amount); + if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) { + KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual); + } + + return expectRead(in, expected.slice(amount, expected.size())); + }); +} + +kj::Promise expectEnd(kj::AsyncInputStream& in) { + static char buffer; + + auto promise = in.tryRead(&buffer, 1, 1); + return promise.then([](size_t amount) { + KJ_ASSERT(amount == 0, "expected EOF"); + }); } -void testHttpClientRequest(kj::AsyncIoContext& io, const HttpRequestTestCase& testCase) { - auto pipe = io.provider->newTwoWayPipe(); +void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase& testCase, + kj::TwoWayPipe pipe) { auto serverTask = expectRead(*pipe.ends[1], testCase.raw).then([&]() { static const char SIMPLE_RESPONSE[] = @@ -404,7 +602,7 @@ void testHttpClientRequest(kj::AsyncIoContext& io, const HttpRequestTestCase& te auto request = client->request(testCase.method, testCase.path, headers, testCase.requestBodySize); if (testCase.requestBodyParts.size() > 0) { - writeEach(*request.body, testCase.requestBodyParts).wait(io.waitScope); + writeEach(*request.body, testCase.requestBodyParts).wait(waitScope); } request.body = nullptr; auto clientTask = request.response @@ -413,17 +611,16 @@ void testHttpClientRequest(kj::AsyncIoContext& io, const HttpRequestTestCase& te return promise.attach(kj::mv(response.body)); }).ignoreResult(); - serverTask.exclusiveJoin(kj::mv(clientTask)).wait(io.waitScope); + serverTask.exclusiveJoin(kj::mv(clientTask)).wait(waitScope); // Verify no more data written by client. client = nullptr; pipe.ends[0]->shutdownWrite(); - KJ_EXPECT(pipe.ends[1]->readAllText().wait(io.waitScope) == ""); + KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); } -void testHttpClientResponse(kj::AsyncIoContext& io, const HttpResponseTestCase& testCase, - size_t readFragmentSize) { - auto pipe = io.provider->newTwoWayPipe(); +void testHttpClientResponse(kj::WaitScope& waitScope, const HttpResponseTestCase& testCase, + size_t readFragmentSize, kj::TwoWayPipe pipe) { ReadFragmenter fragmenter(*pipe.ends[0], readFragmentSize); auto expectedReqText = testCase.method == HttpMethod::GET || testCase.method == HttpMethod::HEAD @@ -457,12 +654,39 @@ void testHttpClientResponse(kj::AsyncIoContext& io, const HttpResponseTestCase& KJ_EXPECT(body == kj::strArray(testCase.responseBodyParts, ""), body); }); - serverTask.exclusiveJoin(kj::mv(clientTask)).wait(io.waitScope); + serverTask.exclusiveJoin(kj::mv(clientTask)).wait(waitScope); // Verify no more data written by client. client = nullptr; pipe.ends[0]->shutdownWrite(); - KJ_EXPECT(pipe.ends[1]->readAllText().wait(io.waitScope) == ""); + KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); +} + +void testHttpClient(kj::WaitScope& waitScope, HttpHeaderTable& table, + HttpClient& client, const HttpTestCase& testCase) { + KJ_CONTEXT(testCase.request.raw, testCase.response.raw); + + HttpHeaders headers(table); + for (auto& header: testCase.request.requestHeaders) { + headers.set(header.id, header.value); + } + + auto request = client.request( + testCase.request.method, testCase.request.path, headers, testCase.request.requestBodySize); + for (auto& part: testCase.request.requestBodyParts) { + request.body->write(part.begin(), part.size()).wait(waitScope); + } + request.body = nullptr; + + auto response = request.response.wait(waitScope); + + KJ_EXPECT(response.statusCode == testCase.response.statusCode); + auto body = response.body->readAllText().wait(waitScope); + if (testCase.request.method == HttpMethod::HEAD) { + KJ_EXPECT(body == ""); + } else { + KJ_EXPECT(body == kj::strArray(testCase.response.responseBodyParts, ""), body); + } } class TestHttpService final: public HttpService { @@ -534,23 +758,22 @@ private: uint requestCount = 0; }; -void testHttpServerRequest(kj::AsyncIoContext& io, +void testHttpServerRequest(kj::WaitScope& waitScope, kj::Timer& timer, const HttpRequestTestCase& requestCase, - const HttpResponseTestCase& responseCase) { - auto pipe = io.provider->newTwoWayPipe(); - + const HttpResponseTestCase& responseCase, + kj::TwoWayPipe pipe) { HttpHeaderTable table; TestHttpService service(requestCase, responseCase, table); - HttpServer server(io.provider->getTimer(), table, service); + HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - pipe.ends[1]->write(requestCase.raw.begin(), requestCase.raw.size()).wait(io.waitScope); + pipe.ends[1]->write(requestCase.raw.begin(), requestCase.raw.size()).wait(waitScope); pipe.ends[1]->shutdownWrite(); - expectRead(*pipe.ends[1], responseCase.raw).wait(io.waitScope); + expectRead(*pipe.ends[1], responseCase.raw).wait(waitScope); - listenTask.wait(io.waitScope); + listenTask.wait(waitScope); KJ_EXPECT(service.getRequestCount() == 1); } @@ -571,7 +794,7 @@ kj::ArrayPtr requestTestCases() { HttpMethod::GET, "/foo/bar", {{HttpHeaderId::HOST, "example.com"}}, - nullptr, {}, + uint64_t(0), {}, }, { @@ -582,7 +805,7 @@ kj::ArrayPtr requestTestCases() { HttpMethod::HEAD, "/foo/bar", {{HttpHeaderId::HOST, "example.com"}}, - nullptr, {}, + uint64_t(0), {}, }, { @@ -650,8 +873,40 @@ kj::ArrayPtr requestTestCases() { HttpMethod::GET, "/", {{HttpHeaderId::HOST, HUGE_STRING}}, - nullptr, {} + uint64_t(0), {} + }, + + { + "GET /foo/bar HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "Host: example.com\r\n" + "\r\n" + "foobar", + + HttpMethod::GET, + "/foo/bar", + {{HttpHeaderId::HOST, "example.com"}}, + uint64_t(6), { "foobar" }, }, + + { + "GET /foo/bar HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "Host: example.com\r\n" + "\r\n" + "3\r\n" + "foo\r\n" + "3\r\n" + "bar\r\n" + "0\r\n" + "\r\n", + + HttpMethod::GET, + "/foo/bar", + {{HttpHeaderId::HOST, "example.com"}, + {HttpHeaderId::TRANSFER_ENCODING, "chunked"}}, + nullptr, { "foo", "bar" }, + } }; // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents REQUEST_TEST_CASES from implicitly @@ -676,6 +931,36 @@ kj::ArrayPtr responseTestCases() { CLIENT_ONLY, // Server never sends connection: close }, + { + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/plain\r\n" + "Transfer-Encoding: identity\r\n" + "Content-Length: foobar\r\n" // intentionally wrong + "\r\n" + "baz qux", + + 200, "OK", + {{HttpHeaderId::CONTENT_TYPE, "text/plain"}}, + nullptr, {"baz qux"}, + + HttpMethod::GET, + CLIENT_ONLY, // Server never sends transfer-encoding: identity + }, + + { + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "baz qux", + + 200, "OK", + {{HttpHeaderId::CONTENT_TYPE, "text/plain"}}, + nullptr, {"baz qux"}, + + HttpMethod::GET, + CLIENT_ONLY, // Server never sends non-delimited message + }, + { "HTTP/1.1 200 OK\r\n" "Content-Length: 123\r\n" @@ -689,6 +974,60 @@ kj::ArrayPtr responseTestCases() { HttpMethod::HEAD, }, + { + "HTTP/1.1 200 OK\r\n" + "Content-Length: foobar\r\n" + "Content-Type: text/plain\r\n" + "\r\n", + + 200, "OK", + {{HttpHeaderId::CONTENT_TYPE, "text/plain"}, + {HttpHeaderId::CONTENT_LENGTH, "foobar"}}, + 123, {}, + + HttpMethod::HEAD, + }, + + // Zero-length expected size response to HEAD request has no Content-Length header. + { + "HTTP/1.1 200 OK\r\n" + "\r\n", + + 200, "OK", + {}, + uint64_t(0), {}, + + HttpMethod::HEAD, + }, + + { + "HTTP/1.1 204 No Content\r\n" + "\r\n", + + 204, "No Content", + {}, + uint64_t(0), {}, + }, + + { + "HTTP/1.1 205 Reset Content\r\n" + "Content-Length: 0\r\n" + "\r\n", + + 205, "Reset Content", + {}, + uint64_t(0), {}, + }, + + { + "HTTP/1.1 304 Not Modified\r\n" + "\r\n", + + 304, "Not Modified", + {}, + uint64_t(0), {}, + }, + { "HTTP/1.1 200 OK\r\n" "Content-Length: 8\r\n" @@ -725,26 +1064,145 @@ kj::ArrayPtr responseTestCases() { } KJ_TEST("HttpClient requests") { - auto io = kj::setupAsyncIo(); + KJ_HTTP_TEST_SETUP_IO; for (auto& testCase: requestTestCases()) { if (testCase.side == SERVER_ONLY) continue; KJ_CONTEXT(testCase.raw); - testHttpClientRequest(io, testCase); + testHttpClientRequest(waitScope, testCase, KJ_HTTP_TEST_CREATE_2PIPE); } } KJ_TEST("HttpClient responses") { - auto io = kj::setupAsyncIo(); + KJ_HTTP_TEST_SETUP_IO; size_t FRAGMENT_SIZES[] = { 1, 2, 3, 4, 5, 6, 7, 8, 16, 31, kj::maxValue }; for (auto& testCase: responseTestCases()) { if (testCase.side == SERVER_ONLY) continue; for (size_t fragmentSize: FRAGMENT_SIZES) { KJ_CONTEXT(testCase.raw, fragmentSize); - testHttpClientResponse(io, testCase, fragmentSize); + testHttpClientResponse(waitScope, testCase, fragmentSize, KJ_HTTP_TEST_CREATE_2PIPE); + } + } +} + +KJ_TEST("HttpClient canceled write") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto serverPromise = pipe.ends[1]->readAllText(); + + { + HttpHeaderTable table; + auto client = newHttpClient(table, *pipe.ends[0]); + + auto body = kj::heapArray(4096); + memset(body.begin(), 0xcf, body.size()); + + auto req = client->request(HttpMethod::POST, "/", HttpHeaders(table), uint64_t(4096)); + + // Note: This poll() forces the server to receive what was written so far. Otherwise, + // cancelling the write below may in fact cancel the previous write as well. + KJ_EXPECT(!serverPromise.poll(waitScope)); + + // Start a write and immediately cancel it. + { + auto ignore KJ_UNUSED = req.body->write(body.begin(), body.size()); + } + + KJ_EXPECT_THROW_MESSAGE("overwrote", req.body->write("foo", 3).wait(waitScope)); + req.body = nullptr; + + KJ_EXPECT(!serverPromise.poll(waitScope)); + + KJ_EXPECT_THROW_MESSAGE("can't start new request until previous request body", + client->request(HttpMethod::GET, "/", HttpHeaders(table)).response.wait(waitScope)); + } + + pipe.ends[0]->shutdownWrite(); + auto text = serverPromise.wait(waitScope); + KJ_EXPECT(text == "POST / HTTP/1.1\r\nContent-Length: 4096\r\n\r\n", text); +} + +KJ_TEST("HttpClient chunked body gather-write") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto serverPromise = pipe.ends[1]->readAllText(); + + { + HttpHeaderTable table; + auto client = newHttpClient(table, *pipe.ends[0]); + + auto req = client->request(HttpMethod::POST, "/", HttpHeaders(table)); + + kj::ArrayPtr bodyParts[] = { + "foo"_kj.asBytes(), " "_kj.asBytes(), + "bar"_kj.asBytes(), " "_kj.asBytes(), + "baz"_kj.asBytes() + }; + + req.body->write(kj::arrayPtr(bodyParts, kj::size(bodyParts))).wait(waitScope); + req.body = nullptr; + + // Wait for a response so the client has a chance to end the request body with a 0-chunk. + kj::StringPtr responseText = "HTTP/1.1 204 No Content\r\n\r\n"; + pipe.ends[1]->write(responseText.begin(), responseText.size()).wait(waitScope); + auto response = req.response.wait(waitScope); + } + + pipe.ends[0]->shutdownWrite(); + + auto text = serverPromise.wait(waitScope); + KJ_EXPECT(text == "POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + "b\r\nfoo bar baz\r\n0\r\n\r\n", text); +} + +KJ_TEST("HttpClient chunked body pump from fixed length stream") { + class FixedBodyStream final: public kj::AsyncInputStream { + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + auto n = kj::min(body.size(), maxBytes); + n = kj::max(n, minBytes); + n = kj::min(n, body.size()); + memcpy(buffer, body.begin(), n); + body = body.slice(n); + return n; } + + Maybe tryGetLength() override { return body.size(); } + + kj::StringPtr body = "foo bar baz"; + }; + + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto serverPromise = pipe.ends[1]->readAllText(); + + { + HttpHeaderTable table; + auto client = newHttpClient(table, *pipe.ends[0]); + + auto req = client->request(HttpMethod::POST, "/", HttpHeaders(table)); + + FixedBodyStream bodyStream; + bodyStream.pumpTo(*req.body).wait(waitScope); + req.body = nullptr; + + // Wait for a response so the client has a chance to end the request body with a 0-chunk. + kj::StringPtr responseText = "HTTP/1.1 204 No Content\r\n\r\n"; + pipe.ends[1]->write(responseText.begin(), responseText.size()).wait(waitScope); + auto response = req.response.wait(waitScope); } + + pipe.ends[0]->shutdownWrite(); + + auto text = serverPromise.wait(waitScope); + KJ_EXPECT(text == "POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + "b\r\nfoo bar baz\r\n0\r\n\r\n", text); } KJ_TEST("HttpServer requests") { @@ -769,13 +1227,15 @@ KJ_TEST("HttpServer requests") { 3, {"foo"} }; - auto io = kj::setupAsyncIo(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); for (auto& testCase: requestTestCases()) { if (testCase.side == CLIENT_ONLY) continue; KJ_CONTEXT(testCase.raw); - testHttpServerRequest(io, testCase, - testCase.method == HttpMethod::HEAD ? HEAD_RESPONSE : RESPONSE); + testHttpServerRequest(waitScope, timer, testCase, + testCase.method == HttpMethod::HEAD ? HEAD_RESPONSE : RESPONSE, + KJ_HTTP_TEST_CREATE_2PIPE); } } @@ -787,7 +1247,7 @@ KJ_TEST("HttpServer responses") { HttpMethod::GET, "/", {}, - nullptr, {}, + uint64_t(0), {}, }; HttpRequestTestCase HEAD_REQUEST = { @@ -797,16 +1257,18 @@ KJ_TEST("HttpServer responses") { HttpMethod::HEAD, "/", {}, - nullptr, {}, + uint64_t(0), {}, }; - auto io = kj::setupAsyncIo(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); for (auto& testCase: responseTestCases()) { if (testCase.side == CLIENT_ONLY) continue; KJ_CONTEXT(testCase.raw); - testHttpServerRequest(io, - testCase.method == HttpMethod::HEAD ? HEAD_REQUEST : REQUEST, testCase); + testHttpServerRequest(waitScope, timer, + testCase.method == HttpMethod::HEAD ? HEAD_REQUEST : REQUEST, testCase, + KJ_HTTP_TEST_CREATE_2PIPE); } } @@ -819,7 +1281,7 @@ kj::ArrayPtr pipelineTestCases() { "GET / HTTP/1.1\r\n" "\r\n", - HttpMethod::GET, "/", {}, nullptr, {}, + HttpMethod::GET, "/", {}, uint64_t(0), {}, }, { "HTTP/1.1 200 OK\r\n" @@ -850,50 +1312,107 @@ kj::ArrayPtr pipelineTestCases() { }, }, + // Throw a zero-size request/response into the pipeline to check for a bug that existed with + // them previously. { { - "POST /bar HTTP/1.1\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "6\r\n" - "garply\r\n" - "5\r\n" - "waldo\r\n" - "0\r\n" + "POST /foo HTTP/1.1\r\n" + "Content-Length: 0\r\n" "\r\n", - HttpMethod::POST, "/bar", {}, nullptr, { "garply", "waldo" }, + HttpMethod::POST, "/foo", {}, uint64_t(0), {}, }, { "HTTP/1.1 200 OK\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "4\r\n" - "fred\r\n" - "5\r\n" - "plugh\r\n" - "0\r\n" + "Content-Length: 0\r\n" "\r\n", - 200, "OK", {}, nullptr, { "fred", "plugh" } + 200, "OK", {}, uint64_t(0), {} }, }, + // Also a zero-size chunked request/response. { { - "HEAD / HTTP/1.1\r\n" + "POST /foo HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" "\r\n", - HttpMethod::HEAD, "/", {}, nullptr, {}, + HttpMethod::POST, "/foo", {}, nullptr, {}, }, { "HTTP/1.1 200 OK\r\n" - "Content-Length: 7\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n", + + 200, "OK", {}, nullptr, {} + }, + }, + + { + { + "POST /bar HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "garply\r\n" + "5\r\n" + "waldo\r\n" + "0\r\n" + "\r\n", + + HttpMethod::POST, "/bar", {}, nullptr, { "garply", "waldo" }, + }, + { + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "4\r\n" + "fred\r\n" + "5\r\n" + "plugh\r\n" + "0\r\n" + "\r\n", + + 200, "OK", {}, nullptr, { "fred", "plugh" } + }, + }, + + { + { + "HEAD / HTTP/1.1\r\n" + "\r\n", + + HttpMethod::HEAD, "/", {}, uint64_t(0), {}, + }, + { + "HTTP/1.1 200 OK\r\n" + "Content-Length: 7\r\n" "\r\n", 200, "OK", {}, 7, { "foo bar" } }, }, + + // Zero-length expected size response to HEAD request has no Content-Length header. + { + { + "HEAD / HTTP/1.1\r\n" + "\r\n", + + HttpMethod::HEAD, "/", {}, uint64_t(0), {}, + }, + { + "HTTP/1.1 200 OK\r\n" + "\r\n", + + 200, "OK", {}, uint64_t(0), {}, HttpMethod::HEAD, + }, + }, }; // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents RESPONSE_TEST_CASES from implicitly @@ -904,8 +1423,8 @@ kj::ArrayPtr pipelineTestCases() { KJ_TEST("HttpClient pipeline") { auto PIPELINE_TESTS = pipelineTestCases(); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; kj::Promise writeResponsesPromise = kj::READY_NOW; for (auto& testCase: PIPELINE_TESTS) { @@ -921,49 +1440,35 @@ KJ_TEST("HttpClient pipeline") { auto client = newHttpClient(table, *pipe.ends[0]); for (auto& testCase: PIPELINE_TESTS) { - KJ_CONTEXT(testCase.request.raw, testCase.response.raw); - - HttpHeaders headers(table); - for (auto& header: testCase.request.requestHeaders) { - headers.set(header.id, header.value); - } - - auto request = client->request( - testCase.request.method, testCase.request.path, headers, testCase.request.requestBodySize); - for (auto& part: testCase.request.requestBodyParts) { - request.body->write(part.begin(), part.size()).wait(io.waitScope); - } - request.body = nullptr; - - auto response = request.response.wait(io.waitScope); - - KJ_EXPECT(response.statusCode == testCase.response.statusCode); - auto body = response.body->readAllText().wait(io.waitScope); - if (testCase.request.method == HttpMethod::HEAD) { - KJ_EXPECT(body == ""); - } else { - KJ_EXPECT(body == kj::strArray(testCase.response.responseBodyParts, ""), body); - } + testHttpClient(waitScope, table, *client, testCase); } client = nullptr; pipe.ends[0]->shutdownWrite(); - writeResponsesPromise.wait(io.waitScope); + writeResponsesPromise.wait(waitScope); } KJ_TEST("HttpClient parallel pipeline") { auto PIPELINE_TESTS = pipelineTestCases(); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + kj::Promise readRequestsPromise = kj::READY_NOW; kj::Promise writeResponsesPromise = kj::READY_NOW; for (auto& testCase: PIPELINE_TESTS) { - writeResponsesPromise = writeResponsesPromise + auto forked = readRequestsPromise .then([&]() { return expectRead(*pipe.ends[1], testCase.request.raw); - }).then([&]() { + }).fork(); + readRequestsPromise = forked.addBranch(); + + // Don't write each response until the corresponding request is received. + auto promises = kj::heapArrayBuilder>(2); + promises.add(forked.addBranch()); + promises.add(kj::mv(writeResponsesPromise)); + writeResponsesPromise = kj::joinPromises(promises.finish()).then([&]() { return pipe.ends[1]->write(testCase.response.raw.begin(), testCase.response.raw.size()); }); } @@ -982,7 +1487,7 @@ KJ_TEST("HttpClient parallel pipeline") { auto request = client->request( testCase.request.method, testCase.request.path, headers, testCase.request.requestBodySize); for (auto& part: testCase.request.requestBodyParts) { - request.body->write(part.begin(), part.size()).wait(io.waitScope); + request.body->write(part.begin(), part.size()).wait(waitScope); } return kj::mv(request.response); @@ -990,10 +1495,10 @@ KJ_TEST("HttpClient parallel pipeline") { for (auto i: kj::indices(PIPELINE_TESTS)) { auto& testCase = PIPELINE_TESTS[i]; - auto response = responsePromises[i].wait(io.waitScope); + auto response = responsePromises[i].wait(waitScope); KJ_EXPECT(response.statusCode == testCase.response.statusCode); - auto body = response.body->readAllText().wait(io.waitScope); + auto body = response.body->readAllText().wait(waitScope); if (testCase.request.method == HttpMethod::HEAD) { KJ_EXPECT(body == ""); } else { @@ -1004,18 +1509,19 @@ KJ_TEST("HttpClient parallel pipeline") { client = nullptr; pipe.ends[0]->shutdownWrite(); - writeResponsesPromise.wait(io.waitScope); + writeResponsesPromise.wait(waitScope); } KJ_TEST("HttpServer pipeline") { auto PIPELINE_TESTS = pipelineTestCases(); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; TestHttpService service(PIPELINE_TESTS, table); - HttpServer server(io.provider->getTimer(), table, service); + HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); @@ -1023,13 +1529,13 @@ KJ_TEST("HttpServer pipeline") { KJ_CONTEXT(testCase.request.raw, testCase.response.raw); pipe.ends[1]->write(testCase.request.raw.begin(), testCase.request.raw.size()) - .wait(io.waitScope); + .wait(waitScope); - expectRead(*pipe.ends[1], testCase.response.raw).wait(io.waitScope); + expectRead(*pipe.ends[1], testCase.response.raw).wait(waitScope); } pipe.ends[1]->shutdownWrite(); - listenTask.wait(io.waitScope); + listenTask.wait(waitScope); KJ_EXPECT(service.getRequestCount() == kj::size(PIPELINE_TESTS)); } @@ -1037,8 +1543,9 @@ KJ_TEST("HttpServer pipeline") { KJ_TEST("HttpServer parallel pipeline") { auto PIPELINE_TESTS = pipelineTestCases(); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; auto allRequestText = kj::strArray(KJ_MAP(testCase, PIPELINE_TESTS) { return testCase.request.raw; }, ""); @@ -1047,17 +1554,17 @@ KJ_TEST("HttpServer parallel pipeline") { HttpHeaderTable table; TestHttpService service(PIPELINE_TESTS, table); - HttpServer server(io.provider->getTimer(), table, service); + HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - pipe.ends[1]->write(allRequestText.begin(), allRequestText.size()).wait(io.waitScope); + pipe.ends[1]->write(allRequestText.begin(), allRequestText.size()).wait(waitScope); pipe.ends[1]->shutdownWrite(); - auto rawResponse = pipe.ends[1]->readAllText().wait(io.waitScope); + auto rawResponse = pipe.ends[1]->readAllText().wait(waitScope); KJ_EXPECT(rawResponse == allResponseText, rawResponse); - listenTask.wait(io.waitScope); + listenTask.wait(waitScope); KJ_EXPECT(service.getRequestCount() == kj::size(PIPELINE_TESTS)); } @@ -1065,305 +1572,5607 @@ KJ_TEST("HttpServer parallel pipeline") { KJ_TEST("HttpClient <-> HttpServer") { auto PIPELINE_TESTS = pipelineTestCases(); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; HttpHeaderTable table; TestHttpService service(PIPELINE_TESTS, table); - HttpServer server(io.provider->getTimer(), table, service); + HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[1])); auto client = newHttpClient(table, *pipe.ends[0]); for (auto& testCase: PIPELINE_TESTS) { - KJ_CONTEXT(testCase.request.raw, testCase.response.raw); - - HttpHeaders headers(table); - for (auto& header: testCase.request.requestHeaders) { - headers.set(header.id, header.value); - } - - auto request = client->request( - testCase.request.method, testCase.request.path, headers, testCase.request.requestBodySize); - for (auto& part: testCase.request.requestBodyParts) { - request.body->write(part.begin(), part.size()).wait(io.waitScope); - } - request.body = nullptr; - - auto response = request.response.wait(io.waitScope); - - KJ_EXPECT(response.statusCode == testCase.response.statusCode); - auto body = response.body->readAllText().wait(io.waitScope); - if (testCase.request.method == HttpMethod::HEAD) { - KJ_EXPECT(body == ""); - } else { - KJ_EXPECT(body == kj::strArray(testCase.response.responseBodyParts, ""), body); - } + testHttpClient(waitScope, table, *client, testCase); } client = nullptr; pipe.ends[0]->shutdownWrite(); - listenTask.wait(io.waitScope); + listenTask.wait(waitScope); KJ_EXPECT(service.getRequestCount() == kj::size(PIPELINE_TESTS)); } // ----------------------------------------------------------------------------- -KJ_TEST("HttpServer request timeout") { - auto PIPELINE_TESTS = pipelineTestCases(); +KJ_TEST("HttpInputStream requests") { + KJ_HTTP_TEST_SETUP_IO; - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + kj::HttpHeaderTable table; - HttpHeaderTable table; - TestHttpService service(PIPELINE_TESTS, table); - HttpServerSettings settings; - settings.headerTimeout = 1 * kj::MILLISECONDS; - HttpServer server(io.provider->getTimer(), table, service, settings); + auto pipe = kj::newOneWayPipe(); + auto input = newHttpInputStream(*pipe.in, table); - // Shouldn't hang! Should time out. - server.listenHttp(kj::mv(pipe.ends[0])).wait(io.waitScope); + kj::Promise writeQueue = kj::READY_NOW; - // Sends back 408 Request Timeout. - KJ_EXPECT(pipe.ends[1]->readAllText().wait(io.waitScope) - .startsWith("HTTP/1.1 408 Request Timeout")); -} + for (auto& testCase: requestTestCases()) { + writeQueue = writeQueue.then([&]() { + return pipe.out->write(testCase.raw.begin(), testCase.raw.size()); + }); + } + writeQueue = writeQueue.then([&]() { + pipe.out = nullptr; + }); -KJ_TEST("HttpServer pipeline timeout") { - auto PIPELINE_TESTS = pipelineTestCases(); + for (auto& testCase: requestTestCases()) { + KJ_CONTEXT(testCase.raw); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + KJ_ASSERT(input->awaitNextMessage().wait(waitScope)); - HttpHeaderTable table; - TestHttpService service(PIPELINE_TESTS, table); - HttpServerSettings settings; - settings.pipelineTimeout = 1 * kj::MILLISECONDS; - HttpServer server(io.provider->getTimer(), table, service, settings); + auto req = input->readRequest().wait(waitScope); + KJ_EXPECT(req.method == testCase.method); + KJ_EXPECT(req.url == testCase.path); + for (auto& header: testCase.requestHeaders) { + KJ_EXPECT(KJ_ASSERT_NONNULL(req.headers.get(header.id)) == header.value); + } + auto body = req.body->readAllText().wait(waitScope); + KJ_EXPECT(body == kj::strArray(testCase.requestBodyParts, "")); + } - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + writeQueue.wait(waitScope); + KJ_EXPECT(!input->awaitNextMessage().wait(waitScope)); +} - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - expectRead(*pipe.ends[1], PIPELINE_TESTS[0].response.raw).wait(io.waitScope); +KJ_TEST("HttpInputStream responses") { + KJ_HTTP_TEST_SETUP_IO; - // Listen task should time out even though we didn't shutdown the socket. - listenTask.wait(io.waitScope); + kj::HttpHeaderTable table; - // In this case, no data is sent back. - KJ_EXPECT(pipe.ends[1]->readAllText().wait(io.waitScope) == ""); -} + auto pipe = kj::newOneWayPipe(); + auto input = newHttpInputStream(*pipe.in, table); -class BrokenHttpService final: public HttpService { - // HttpService that doesn't send a response. -public: - BrokenHttpService() = default; - explicit BrokenHttpService(kj::Exception&& exception): exception(kj::mv(exception)) {} + kj::Promise writeQueue = kj::READY_NOW; - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& responseSender) override { - return requestBody.readAllBytes().then([this](kj::Array&&) -> kj::Promise { - KJ_IF_MAYBE(e, exception) { - return kj::cp(*e); - } else { - return kj::READY_NOW; - } + for (auto& testCase: responseTestCases()) { + if (testCase.side == CLIENT_ONLY) continue; // skip Connection: close case. + writeQueue = writeQueue.then([&]() { + return pipe.out->write(testCase.raw.begin(), testCase.raw.size()); }); } + writeQueue = writeQueue.then([&]() { + pipe.out = nullptr; + }); -private: - kj::Maybe exception; -}; + for (auto& testCase: responseTestCases()) { + if (testCase.side == CLIENT_ONLY) continue; // skip Connection: close case. + KJ_CONTEXT(testCase.raw); -KJ_TEST("HttpServer no response") { - auto PIPELINE_TESTS = pipelineTestCases(); + KJ_ASSERT(input->awaitNextMessage().wait(waitScope)); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + auto resp = input->readResponse(testCase.method).wait(waitScope); + KJ_EXPECT(resp.statusCode == testCase.statusCode); + KJ_EXPECT(resp.statusText == testCase.statusText); + for (auto& header: testCase.responseHeaders) { + KJ_EXPECT(KJ_ASSERT_NONNULL(resp.headers.get(header.id)) == header.value); + } + auto body = resp.body->readAllText().wait(waitScope); + KJ_EXPECT(body == kj::strArray(testCase.responseBodyParts, "")); + } - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(io.provider->getTimer(), table, service); + writeQueue.wait(waitScope); + KJ_EXPECT(!input->awaitNextMessage().wait(waitScope)); +} - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); +KJ_TEST("HttpInputStream bare messages") { + KJ_HTTP_TEST_SETUP_IO; - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - auto text = pipe.ends[1]->readAllText().wait(io.waitScope); + kj::HttpHeaderTable table; - KJ_EXPECT(text == - "HTTP/1.1 500 Internal Server Error\r\n" - "Connection: close\r\n" - "Content-Length: 51\r\n" - "Content-Type: text/plain\r\n" + auto pipe = kj::newOneWayPipe(); + auto input = newHttpInputStream(*pipe.in, table); + + kj::StringPtr messages = + "Content-Length: 6\r\n" "\r\n" - "ERROR: The HttpService did not generate a response.", text); + "foobar" + "Content-Length: 11\r\n" + "Content-Type: some/type\r\n" + "\r\n" + "bazquxcorge" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "grault\r\n" + "b\r\n" + "garplywaldo\r\n" + "0\r\n" + "\r\n"_kj; + + kj::Promise writeTask = pipe.out->write(messages.begin(), messages.size()) + .then([&]() { pipe.out = nullptr; }); + + { + KJ_ASSERT(input->awaitNextMessage().wait(waitScope)); + auto message = input->readMessage().wait(waitScope); + KJ_EXPECT(KJ_ASSERT_NONNULL(message.headers.get(HttpHeaderId::CONTENT_LENGTH)) == "6"); + KJ_EXPECT(message.body->readAllText().wait(waitScope) == "foobar"); + } + { + KJ_ASSERT(input->awaitNextMessage().wait(waitScope)); + auto message = input->readMessage().wait(waitScope); + KJ_EXPECT(KJ_ASSERT_NONNULL(message.headers.get(HttpHeaderId::CONTENT_LENGTH)) == "11"); + KJ_EXPECT(KJ_ASSERT_NONNULL(message.headers.get(HttpHeaderId::CONTENT_TYPE)) == "some/type"); + KJ_EXPECT(message.body->readAllText().wait(waitScope) == "bazquxcorge"); + } + { + KJ_ASSERT(input->awaitNextMessage().wait(waitScope)); + auto message = input->readMessage().wait(waitScope); + KJ_EXPECT(KJ_ASSERT_NONNULL(message.headers.get(HttpHeaderId::TRANSFER_ENCODING)) == "chunked"); + KJ_EXPECT(message.body->readAllText().wait(waitScope) == "graultgarplywaldo"); + } + + writeTask.wait(waitScope); + KJ_EXPECT(!input->awaitNextMessage().wait(waitScope)); } -KJ_TEST("HttpServer disconnected") { - auto PIPELINE_TESTS = pipelineTestCases(); +// ----------------------------------------------------------------------------- - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); +KJ_TEST("WebSocket core protocol") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected")); - HttpServer server(io.provider->getTimer(), table, service); + auto client = newWebSocket(kj::mv(pipe.ends[0]), nullptr); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto mediumString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 30), ""); + auto bigString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 10000), ""); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - auto text = pipe.ends[1]->readAllText().wait(io.waitScope); + auto clientTask = client->send(kj::StringPtr("hello")) + .then([&]() { return client->send(mediumString); }) + .then([&]() { return client->send(bigString); }) + .then([&]() { return client->send(kj::StringPtr("world").asBytes()); }) + .then([&]() { return client->close(1234, "bored"); }) + .then([&]() { KJ_EXPECT(client->sentByteCount() == 90307)}); - KJ_EXPECT(text == "", text); -} + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "hello"); + } -KJ_TEST("HttpServer overloaded") { - auto PIPELINE_TESTS = pipelineTestCases(); + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == mediumString); + } - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == bigString); + } - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded")); - HttpServer server(io.provider->getTimer(), table, service); + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is>()); + KJ_EXPECT(kj::str(message.get>().asChars()) == "world"); + } - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 1234); + KJ_EXPECT(message.get().reason == "bored"); + KJ_EXPECT(server->receivedByteCount() == 90307); + } - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - auto text = pipe.ends[1]->readAllText().wait(io.waitScope); + auto serverTask = server->close(4321, "whatever"); - KJ_EXPECT(text.startsWith("HTTP/1.1 503 Service Unavailable"), text); + { + auto message = client->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 4321); + KJ_EXPECT(message.get().reason == "whatever"); + KJ_EXPECT(client->receivedByteCount() == 12); + } + + clientTask.wait(waitScope); + serverTask.wait(waitScope); } -KJ_TEST("HttpServer unimplemented") { - auto PIPELINE_TESTS = pipelineTestCases(); +KJ_TEST("WebSocket fragmented") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented")); - HttpServer server(io.provider->getTimer(), table, service); + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + 0x00, 0x03, 'w', 'o', 'r', - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - auto text = pipe.ends[1]->readAllText().wait(io.waitScope); + 0x80, 0x02, 'l', 'd', + }; - KJ_EXPECT(text.startsWith("HTTP/1.1 501 Not Implemented"), text); + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "hello world"); + } + + clientTask.wait(waitScope); } -KJ_TEST("HttpServer threw exception") { - auto PIPELINE_TESTS = pipelineTestCases(); +#if KJ_HAS_ZLIB +KJ_TEST("WebSocket compressed fragment") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, CompressionParameters{ + .outboundNoContextTakeover = false, + .inboundNoContextTakeover = false, + .outboundMaxWindowBits=15, + .inboundMaxWindowBits=15, + }); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + // The message is "Hello", sent in two fragments, see the fragmented example at the bottom of: + // https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + byte COMPRESSED_DATA[] = { + 0x41, 0x03, 0xf2, 0x48, 0xcd, - HttpHeaderTable table; - BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); - HttpServer server(io.provider->getTimer(), table, service); + 0x80, 0x04, 0xc9, 0xc9, 0x07, 0x00 + }; - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + auto clientTask = client->write(COMPRESSED_DATA, sizeof(COMPRESSED_DATA)); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - auto text = pipe.ends[1]->readAllText().wait(io.waitScope); + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } - KJ_EXPECT(text.startsWith("HTTP/1.1 500 Internal Server Error"), text); + clientTask.wait(waitScope); } +#endif // KJ_HAS_ZLIB -class PartialResponseService final: public HttpService { - // HttpService that sends a partial response then throws. +class FakeEntropySource final: public EntropySource { public: - kj::Promise request( - HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::AsyncInputStream& requestBody, Response& response) override { - return requestBody.readAllBytes() - .then([this,&response](kj::Array&&) -> kj::Promise { - HttpHeaders headers(table); - auto body = response.send(200, "OK", headers, 32); - auto promise = body->write("foo", 3); - return promise.attach(kj::mv(body)).then([]() -> kj::Promise { - return KJ_EXCEPTION(FAILED, "failed"); - }); - }); - } + void generate(kj::ArrayPtr buffer) override { + static constexpr byte DUMMY[4] = { 12, 34, 56, 78 }; + + for (auto i: kj::indices(buffer)) { + buffer[i] = DUMMY[i % sizeof(DUMMY)]; + } + } +}; + +KJ_TEST("WebSocket masked") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + FakeEntropySource maskGenerator; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), maskGenerator); + + byte DATA[] = { + 0x81, 0x86, 12, 34, 56, 78, 'h' ^ 12, 'e' ^ 34, 'l' ^ 56, 'l' ^ 78, 'o' ^ 12, ' ' ^ 34, + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + auto serverTask = server->send(kj::StringPtr("hello ")); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "hello "); + } + + expectRead(*client, DATA).wait(waitScope); + + clientTask.wait(waitScope); + serverTask.wait(waitScope); +} + +class WebSocketErrorCatcher : public WebSocketErrorHandler { +public: + kj::Vector errors; + + kj::Exception handleWebSocketProtocolError(kj::WebSocket::ProtocolError protocolError) { + errors.add(kj::mv(protocolError)); + return KJ_EXCEPTION(FAILED, protocolError.description); + } +}; + +KJ_TEST("WebSocket unexpected RSV bits") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', + + 0xF0, 0x05, 'w', 'o', 'r', 'l', 'd' // all RSV bits set, plus FIN + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket unexpected continuation frame") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x80, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', // Continuation frame with no start frame, plus FIN + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket missing continuation frame") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', // Start frame + 0x01, 0x06, 'w', 'o', 'r', 'l', 'd', '!', // Another start frame + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket fragmented control frame") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x09, 0x04, 'd', 'a', 't', 'a' // Fragmented ping frame + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket unknown opcode") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + byte DATA[] = { + 0x85, 0x04, 'd', 'a', 't', 'a' // 5 is a reserved opcode + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket unsolicited pong") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); + + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', + + 0x8A, 0x03, 'f', 'o', 'o', + + 0x80, 0x05, 'w', 'o', 'r', 'l', 'd', + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "hello world"); + } + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket ping") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); + + // Be extra-annoying by having the ping arrive between fragments. + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', + + 0x89, 0x03, 'f', 'o', 'o', + + 0x80, 0x05, 'w', 'o', 'r', 'l', 'd', + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "hello world"); + } + + auto serverTask = server->send(kj::StringPtr("bar")); + + byte EXPECTED[] = { + 0x8A, 0x03, 'f', 'o', 'o', // pong + 0x81, 0x03, 'b', 'a', 'r', // message + }; + + expectRead(*client, EXPECTED).wait(waitScope); + + clientTask.wait(waitScope); + serverTask.wait(waitScope); +} + +KJ_TEST("WebSocket ping mid-send") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); + + auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), ""); + auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr); + + byte DATA[] = { + 0x89, 0x03, 'f', 'o', 'o', // ping + 0x81, 0x03, 'b', 'a', 'r', // some other message + }; + + auto clientTask = client->write(DATA, sizeof(DATA)); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "bar"); + } + + byte EXPECTED1[] = { 0x81, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 }; + expectRead(*client, EXPECTED1).wait(waitScope); + expectRead(*client, bigString).wait(waitScope); + + byte EXPECTED2[] = { 0x8A, 0x03, 'f', 'o', 'o' }; + expectRead(*client, EXPECTED2).wait(waitScope); + + clientTask.wait(waitScope); + serverTask.wait(waitScope); +} + +class InputOutputPair final: public kj::AsyncIoStream { + // Creates an AsyncIoStream out of an AsyncInputStream and an AsyncOutputStream. + +public: + InputOutputPair(kj::Own in, kj::Own out) + : in(kj::mv(in)), out(kj::mv(out)) {} + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return in->read(buffer, minBytes, maxBytes); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return in->tryRead(buffer, minBytes, maxBytes); + } + + Maybe tryGetLength() override { + return in->tryGetLength(); + } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount = kj::maxValue) override { + return in->pumpTo(output, amount); + } + + kj::Promise write(const void* buffer, size_t size) override { + return out->write(buffer, size); + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + return out->write(pieces); + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return out->tryPumpFrom(input, amount); + } + + Promise whenWriteDisconnected() override { + return out->whenWriteDisconnected(); + } + + void shutdownWrite() override { + out = nullptr; + } + +private: + kj::Own in; + kj::Own out; +}; + +KJ_TEST("WebSocket double-ping mid-send") { + KJ_HTTP_TEST_SETUP_IO; + + auto upPipe = newOneWayPipe(); + auto downPipe = newOneWayPipe(); + InputOutputPair client(kj::mv(downPipe.in), kj::mv(upPipe.out)); + auto server = newWebSocket(kj::heap(kj::mv(upPipe.in), kj::mv(downPipe.out)), + nullptr); + + auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), ""); + auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr); + + byte DATA[] = { + 0x89, 0x03, 'f', 'o', 'o', // ping + 0x89, 0x03, 'q', 'u', 'x', // ping2 + 0x81, 0x03, 'b', 'a', 'r', // some other message + }; + + auto clientTask = client.write(DATA, sizeof(DATA)); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "bar"); + } + + byte EXPECTED1[] = { 0x81, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 }; + expectRead(client, EXPECTED1).wait(waitScope); + expectRead(client, bigString).wait(waitScope); + + byte EXPECTED2[] = { 0x8A, 0x03, 'q', 'u', 'x' }; + expectRead(client, EXPECTED2).wait(waitScope); + + clientTask.wait(waitScope); + serverTask.wait(waitScope); +} + +KJ_TEST("WebSocket ping received during pong send") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); + + // Send a very large ping so that sending the pong takes a while. Then send a second ping + // immediately after. + byte PREFIX[] = { 0x89, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 }; + auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), ""); + byte POSTFIX[] = { + 0x89, 0x03, 'f', 'o', 'o', + 0x81, 0x03, 'b', 'a', 'r', + }; + + kj::ArrayPtr parts[] = {PREFIX, bigString.asBytes(), POSTFIX}; + auto clientTask = client->write(parts); + + { + auto message = server->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "bar"); + } + + byte EXPECTED1[] = { 0x8A, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 }; + expectRead(*client, EXPECTED1).wait(waitScope); + expectRead(*client, bigString).wait(waitScope); + + byte EXPECTED2[] = { 0x8A, 0x03, 'f', 'o', 'o' }; + expectRead(*client, EXPECTED2).wait(waitScope); + + clientTask.wait(waitScope); +} + +KJ_TEST("WebSocket pump byte counting") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE; + auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE; + + FakeEntropySource maskGenerator; + auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); + auto client2 = newWebSocket(kj::mv(pipe2.ends[0]), maskGenerator); + auto server2 = newWebSocket(kj::mv(pipe2.ends[1]), nullptr); + + auto pumpTask = server1->pumpTo(*client2); + auto receiveTask = server2->receive(); + + // Client sends three bytes of a valid message then disconnects. + const char DATA[] = {0x01, 0x06, 'h'}; + pipe1.ends[0]->write(DATA, 3).wait(waitScope); + pipe1.ends[0] = nullptr; + + // The pump completes successfully, forwarding the disconnect. + pumpTask.wait(waitScope); + + // The eventual receiver gets a disconnect exception. + // (Note: We don't use KJ_EXPECT_THROW here because under -fno-exceptions it forks and we lose + // state.) + receiveTask.then([](auto) { + KJ_FAIL_EXPECT("expected exception"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getType() == kj::Exception::Type::DISCONNECTED); + }).wait(waitScope); + + KJ_EXPECT(server1->receivedByteCount() == 3); +#if KJ_NO_RTTI + // Optimized socket pump will be disabled, so only whole messages are counted by client2/server2. + KJ_EXPECT(client2->sentByteCount() == 0); + KJ_EXPECT(server2->receivedByteCount() == 0); +#else + KJ_EXPECT(client2->sentByteCount() == 3); + KJ_EXPECT(server2->receivedByteCount() == 3); +#endif +} + +KJ_TEST("WebSocket pump disconnect on send") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE; + auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE; + + FakeEntropySource maskGenerator; + auto client1 = newWebSocket(kj::mv(pipe1.ends[0]), maskGenerator); + auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); + auto client2 = newWebSocket(kj::mv(pipe2.ends[0]), maskGenerator); + + auto pumpTask = server1->pumpTo(*client2); + auto sendTask = client1->send("hello"_kj); + + // Endpoint reads three bytes and then disconnects. + char buffer[3]; + pipe2.ends[1]->read(buffer, 3).wait(waitScope); + pipe2.ends[1] = nullptr; + + // Pump throws disconnected. + KJ_EXPECT_THROW_RECOVERABLE(DISCONNECTED, pumpTask.wait(waitScope)); + + // client1 may or may not have been able to send its whole message depending on buffering. + sendTask.then([]() {}, [](kj::Exception&& e) { + KJ_EXPECT(e.getType() == kj::Exception::Type::DISCONNECTED); + }).wait(waitScope); +} + +KJ_TEST("WebSocket pump disconnect on receive") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE; + auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE; + + FakeEntropySource maskGenerator; + auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); + auto client2 = newWebSocket(kj::mv(pipe2.ends[0]), maskGenerator); + auto server2 = newWebSocket(kj::mv(pipe2.ends[1]), nullptr); + + auto pumpTask = server1->pumpTo(*client2); + auto receiveTask = server2->receive(); + + // Client sends three bytes of a valid message then disconnects. + const char DATA[] = {0x01, 0x06, 'h'}; + pipe1.ends[0]->write(DATA, 3).wait(waitScope); + pipe1.ends[0] = nullptr; + + // The pump completes successfully, forwarding the disconnect. + pumpTask.wait(waitScope); + + // The eventual receiver gets a disconnect exception. + KJ_EXPECT_THROW(DISCONNECTED, receiveTask.wait(waitScope)); +} + +KJ_TEST("WebSocket abort propagates through pipe") { + // Pumping one end of a WebSocket pipe into another WebSocket which later becomes aborted will + // cancel the pump promise with a DISCONNECTED exception. + + KJ_HTTP_TEST_SETUP_IO; + auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE; + + auto server = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); + auto client = newWebSocket(kj::mv(pipe1.ends[0]), nullptr); + + auto wsPipe = newWebSocketPipe(); + + auto downstreamPump = wsPipe.ends[0]->pumpTo(*server); + KJ_EXPECT(!downstreamPump.poll(waitScope)); + + client->abort(); + + KJ_EXPECT(downstreamPump.poll(waitScope)); + KJ_EXPECT_THROW_RECOVERABLE(DISCONNECTED, downstreamPump.wait(waitScope)); +} + +KJ_TEST("WebSocket maximum message size") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe =KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + FakeEntropySource maskGenerator; + auto client = newWebSocket(kj::mv(pipe.ends[0]), maskGenerator); + auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr, nullptr, errorCatcher); + + size_t maxSize = 100; + auto biggestAllowedString = kj::strArray(kj::repeat(kj::StringPtr("A"), maxSize), ""); + auto tooBigString = kj::strArray(kj::repeat(kj::StringPtr("B"), maxSize + 1), ""); + + auto clientTask = client->send(biggestAllowedString) + .then([&]() { return client->send(tooBigString); }) + .then([&]() { return client->close(1234, "done"); }); + + { + auto message = server->receive(maxSize).wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().size() == maxSize); + } + + { + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("too large", + server->receive(maxSize).ignoreResult().wait(waitScope)); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1009); + } +} + +class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler { +public: + TestWebSocketService(HttpHeaderTable& headerTable, HttpHeaderId hMyHeader) + : headerTable(headerTable), hMyHeader(hMyHeader), tasks(*this) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_ASSERT(headers.isWebSocket()); + + HttpHeaders responseHeaders(headerTable); + KJ_IF_MAYBE(h, headers.get(hMyHeader)) { + responseHeaders.set(hMyHeader, kj::str("respond-", *h)); + } + + if (url == "/return-error") { + response.send(404, "Not Found", responseHeaders, uint64_t(0)); + return kj::READY_NOW; + } else if (url == "/websocket") { + auto ws = response.acceptWebSocket(responseHeaders); + return doWebSocket(*ws, "start-inline").attach(kj::mv(ws)); + } else { + KJ_FAIL_ASSERT("unexpected path", url); + } + } + +private: + HttpHeaderTable& headerTable; + HttpHeaderId hMyHeader; + kj::TaskSet tasks; + + void taskFailed(kj::Exception&& exception) override { + KJ_LOG(ERROR, exception); + } + + static kj::Promise doWebSocket(WebSocket& ws, kj::StringPtr message) { + auto copy = kj::str(message); + return ws.send(copy).attach(kj::mv(copy)) + .then([&ws]() { + return ws.receive(); + }).then([&ws](WebSocket::Message&& message) { + KJ_SWITCH_ONEOF(message) { + KJ_CASE_ONEOF(str, kj::String) { + return doWebSocket(ws, kj::str("reply:", str)); + } + KJ_CASE_ONEOF(data, kj::Array) { + return doWebSocket(ws, kj::str("reply:", data)); + } + KJ_CASE_ONEOF(close, WebSocket::Close) { + auto reason = kj::str("close-reply:", close.reason); + return ws.close(close.code + 1, reason).attach(kj::mv(reason)); + } + } + KJ_UNREACHABLE; + }); + } +}; + +const char WEBSOCKET_REQUEST_HANDSHAKE[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "My-Header: foo\r\n" + "\r\n"; +const char WEBSOCKET_RESPONSE_HANDSHAKE[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "My-Header: respond-foo\r\n" + "\r\n"; +#if KJ_HAS_ZLIB +const char WEBSOCKET_COMPRESSION_HANDSHAKE[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n" + "\r\n"; +const char WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n" + "\r\n"; +const char WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_HANDSHAKE[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; " + "server_no_context_takeover\r\n" + "\r\n"; +const char WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_RESPONSE_HANDSHAKE[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; " + "server_no_context_takeover\r\n" + "\r\n"; +#endif // KJ_HAS_ZLIB +const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] = + "HTTP/1.1 404 Not Found\r\n" + "Content-Length: 0\r\n" + "My-Header: respond-foo\r\n" + "\r\n"; +const byte WEBSOCKET_FIRST_MESSAGE_INLINE[] = + { 0x81, 0x0c, 's','t','a','r','t','-','i','n','l','i','n','e' }; +const byte WEBSOCKET_SEND_MESSAGE[] = + { 0x81, 0x83, 12, 34, 56, 78, 'b'^12, 'a'^34, 'r'^56 }; +const byte WEBSOCKET_REPLY_MESSAGE[] = + { 0x81, 0x09, 'r','e','p','l','y',':','b','a','r' }; +const byte WEBSOCKET_SEND_CLOSE[] = + { 0x88, 0x85, 12, 34, 56, 78, 0x12^12, 0x34^34, 'q'^56, 'u'^78, 'x'^12 }; +const byte WEBSOCKET_REPLY_CLOSE[] = + { 0x88, 0x11, 0x12, 0x35, 'c','l','o','s','e','-','r','e','p','l','y',':','q','u','x' }; + +#if KJ_HAS_ZLIB +const byte WEBSOCKET_FIRST_COMPRESSED_MESSAGE[] = + { 0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.2 +const byte WEBSOCKET_SEND_COMPRESSED_MESSAGE[] = + { 0xc1, 0x87, 12, 34, 56, 78, 0xf2^12, 0x48^34, 0xcd^56, 0xc9^78, 0xc9^12, 0x07^34, 0x00^56 }; +const byte WEBSOCKET_SEND_COMPRESSED_MESSAGE_REUSE_CTX[] = + { 0xc1, 0x85, 12, 34, 56, 78, 0xf2^12, 0x00^34, 0x11^56, 0x00^78, 0x00^12}; +// See same compression example, but where `client_no_context_takeover` is used (saves 2 bytes). +const byte WEBSOCKET_DEFLATE_NO_COMPRESSION_MESSAGE[] = + { 0xc1, 0x0b, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.3 +// This uses a DEFLATE block with no compression. +const byte WEBSOCKET_BFINAL_SET_MESSAGE[] = + { 0xc1, 0x08, 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.4 +// This uses a DEFLATE block with BFINAL set to 1. +const byte WEBSOCKET_TWO_DEFLATE_BLOCKS_MESSAGE[] = + { 0xc1, 0x0d, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff, 0xca, 0xc9, 0xc9, 0x07, 0x00 }; +// See this example: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.5 +// This uses two DEFLATE blocks in a single message. +const byte WEBSOCKET_EMPTY_COMPRESSED_MESSAGE[] = + { 0xc1, 0x01, 0x00 }; +const byte WEBSOCKET_EMPTY_SEND_COMPRESSED_MESSAGE[] = + { 0xc1, 0x81, 12, 34, 56, 78, 0x00^12 }; +#endif // KJ_HAS_ZLIB + +template +kj::ArrayPtr asBytes(const char (&chars)[s]) { + return kj::ArrayPtr(chars, s - 1).asBytes(); +} + +void testWebSocketClient(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId hMyHeader, HttpClient& client) { + kj::HttpHeaders headers(headerTable); + headers.set(hMyHeader, "foo"); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); + + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "start-inline"); + } + + ws->send(kj::StringPtr("bar")).wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "reply:bar"); + } + + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } +} + +#if KJ_HAS_ZLIB +void testWebSocketTwoMessageCompression(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // In this test, the server will always use `server_no_context_takeover` (since we can just reuse + // the message). However, we will modify the client's compressor in different ways to see how the + // compressed message changes. + + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); + + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); + + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); + + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +void testWebSocketEmptyMessageCompression(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // Confirm that we can send empty messages when compression is enabled. + + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); + + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); + + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == ""); + } + ws->send(kj::StringPtr("")).wait(waitScope); + + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + ws->send(kj::StringPtr("Hello")).wait(waitScope); + + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +void testWebSocketOptimizePumpProxy(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // Suppose we are proxying a websocket conversation between a client and a server. + // This looks something like: CLIENT <--> (proxyServer <==PUMP==> proxyClient) <--> SERVER + // + // We want to enable optimizedPumping from the proxy's server (which communicates with the client), + // to the proxy's client (which communicates with the origin server). + // + // For this to work, proxyServer's inbound settings must map to proxyClient's outbound settings + // (and vice versa). In this case, `ws` is `proxyClient`, so we want to take `ws`'s compression + // configuration and pass it to `proxyServer` in a way that would allow for optimizedPumping. + + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); + + auto maybeExt = ws->getPreferredExtensions(kj::WebSocket::ExtensionsContext::REQUEST); + // Should be nullptr since we are asking `ws` (a client) to give us extensions that we can give to + // another client. Since clients cannot `optimizedPumpTo` each other, we must get null. + KJ_ASSERT(maybeExt == nullptr); + + maybeExt = ws->getPreferredExtensions(kj::WebSocket::ExtensionsContext::RESPONSE); + kj::StringPtr extStr = KJ_ASSERT_NONNULL(maybeExt); + KJ_ASSERT(extStr == "permessage-deflate; server_no_context_takeover"); + // We got back the string the client sent! + // We could then pass this string as a header to `acceptWebSocket` and ensure the `proxyServer`s + // inbound settings match the `proxyClient`s outbound settings. + + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } +} +#endif // KJ_HAS_ZLIB +#if KJ_HAS_ZLIB +void testWebSocketFourMessageCompression(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, + kj::HttpHeaderId extHeader, kj::StringPtr extensions, + HttpClient& client) { + // In this test, the server will always use `server_no_context_takeover` (since we can just reuse + // the message). We will receive three messages. + + kj::HttpHeaders headers(headerTable); + headers.set(extHeader, extensions); + auto response = client.openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 101); + KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(extHeader)).startsWith("permessage-deflate")); + KJ_ASSERT(response.webSocketOrBody.is>()); + auto ws = kj::mv(response.webSocketOrBody.get>()); + + for (size_t i = 0; i < 4; i++) { + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get() == "Hello"); + } + } + + ws->close(0x1234, "qux").wait(waitScope); + { + auto message = ws->receive().wait(waitScope); + KJ_ASSERT(message.is()); + KJ_EXPECT(message.get().code == 0x1235); + KJ_EXPECT(message.get().reason == "close-reply:qux"); + } +} +#endif // KJ_HAS_ZLIB + +inline kj::Promise writeA(kj::AsyncOutputStream& out, kj::ArrayPtr data) { + return out.write(data.begin(), data.size()); +} + +KJ_TEST("HttpClient WebSocket handshake") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + testWebSocketClient(waitScope, *headerTable, hMyHeader, *client); + + serverTask.wait(waitScope); +} + +KJ_TEST("WebSocket Compression String Parsing (splitNext)") { + // Test `splitNext()`. + // We want to assert that: + // If a delimiter is found: + // - `input` is updated to point to the rest of the string after the delimiter. + // - The text before the delimiter is returned. + // If no delimiter is found: + // - `input` is updated to an empty string. + // - The text that had been in `input` is returned. + + const auto s = "permessage-deflate; client_max_window_bits=10;server_no_context_takeover"_kj; + + const auto expectedPartOne = "permessage-deflate"_kj; + const auto expectedRemainingOne = "client_max_window_bits=10;server_no_context_takeover"_kj; + + auto cursor = s.asArray(); + auto actual = _::splitNext(cursor, ';'); + KJ_ASSERT(actual == expectedPartOne); + + _::stripLeadingAndTrailingSpace(cursor); + KJ_ASSERT(cursor == expectedRemainingOne.asArray()); + + const auto expectedPartTwo = "client_max_window_bits=10"_kj; + const auto expectedRemainingTwo = "server_no_context_takeover"_kj; + + actual = _::splitNext(cursor, ';'); + KJ_ASSERT(actual == expectedPartTwo); + KJ_ASSERT(cursor == expectedRemainingTwo); + + const auto expectedPartThree = "server_no_context_takeover"_kj; + const auto expectedRemainingThree = ""_kj; + actual = _::splitNext(cursor, ';'); + KJ_ASSERT(actual == expectedPartThree); + KJ_ASSERT(cursor == expectedRemainingThree); +} + +KJ_TEST("WebSocket Compression String Parsing (splitParts)") { + // Test `splitParts()`. + // We want to assert that we: + // 1. Correctly split by the delimiter. + // 2. Strip whitespace before/after the extracted part. + const auto permitted = "permessage-deflate"_kj; + + const auto s = "permessage-deflate; client_max_window_bits=10;server_no_context_takeover, " + " permessage-deflate; ; ," // strips leading whitespace + "permessage-deflate"_kj; + + // These are the expected values. + const auto extOne = "permessage-deflate; client_max_window_bits=10;server_no_context_takeover"_kj; + const auto extTwo = "permessage-deflate; ;"_kj; + const auto extThree = "permessage-deflate"_kj; + + auto actualExtensions = kj::_::splitParts(s, ','); + KJ_ASSERT(actualExtensions.size() == 3); + KJ_ASSERT(actualExtensions[0] == extOne); + KJ_ASSERT(actualExtensions[1] == extTwo); + KJ_ASSERT(actualExtensions[2] == extThree); + // Splitting by ',' was fine, now let's try splitting the parameters (split by ';'). + + const auto paramOne = "client_max_window_bits=10"_kj; + const auto paramTwo = "server_no_context_takeover"_kj; + + auto actualParamsFirstExt = kj::_::splitParts(actualExtensions[0], ';'); + KJ_ASSERT(actualParamsFirstExt.size() == 3); + KJ_ASSERT(actualParamsFirstExt[0] == permitted); + KJ_ASSERT(actualParamsFirstExt[1] == paramOne); + KJ_ASSERT(actualParamsFirstExt[2] == paramTwo); + + auto actualParamsSecondExt = kj::_::splitParts(actualExtensions[1], ';'); + KJ_ASSERT(actualParamsSecondExt.size() == 2); + KJ_ASSERT(actualParamsSecondExt[0] == permitted); + KJ_ASSERT(actualParamsSecondExt[1] == ""_kj); // Note that the whitespace was stripped. + + auto actualParamsThirdExt = kj::_::splitParts(actualExtensions[2], ';'); + // No parameters supplied in the third offer. We expect to only see the extension name. + KJ_ASSERT(actualParamsThirdExt.size() == 1); + KJ_ASSERT(actualParamsThirdExt[0] == permitted); +} + +KJ_TEST("WebSocket Compression String Parsing (toKeysAndVals)") { + // If an "=" is found, everything before the "=" goes into the `Key` and everything after goes + // into the `Value`. Otherwise, everything goes into the `Key` and the `Value` remains null. + const auto cleanParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=10"_kj; + auto parts = _::splitParts(cleanParameters, ';'); + auto keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(keysMaybeValues.size() == 3); + + auto firstKey = "client_no_context_takeover"_kj; + KJ_ASSERT(keysMaybeValues[0].key == firstKey.asArray()); + KJ_ASSERT(keysMaybeValues[0].val == nullptr); + + auto secondKey = "client_max_window_bits"_kj; + KJ_ASSERT(keysMaybeValues[1].key == secondKey.asArray()); + KJ_ASSERT(keysMaybeValues[1].val == nullptr); + + auto thirdKey = "server_max_window_bits"_kj; + auto thirdVal = "10"_kj; + KJ_ASSERT(keysMaybeValues[2].key == thirdKey.asArray()); + KJ_ASSERT(keysMaybeValues[2].val == thirdVal.asArray()); + + const auto weirdParameters = "= 14 ; client_max_window_bits= ; server_max_window_bits =hello"_kj; + // This is weird because: + // 1. Parameter 1 has no key. + // 2. Parameter 2 has an "=" but no subsequent value. + // 3. Parameter 3 has an "=" with an invalid value. + // That said, we don't mind if the parameters are weird when calling this function. The point + // is to create KeyMaybeVal pairs and process them later. + + parts = _::splitParts(weirdParameters, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(keysMaybeValues.size() == 3); + + firstKey = ""_kj; + auto firstVal = "14"_kj; + KJ_ASSERT(keysMaybeValues[0].key == firstKey.asArray()); + KJ_ASSERT(keysMaybeValues[0].val == firstVal.asArray()); + + secondKey = "client_max_window_bits"_kj; + auto secondVal = ""_kj; + KJ_ASSERT(keysMaybeValues[1].key == secondKey.asArray()); + KJ_ASSERT(keysMaybeValues[1].val == secondVal.asArray()); + + thirdKey = "server_max_window_bits"_kj; + thirdVal = "hello"_kj; + KJ_ASSERT(keysMaybeValues[2].key == thirdKey.asArray()); + KJ_ASSERT(keysMaybeValues[2].val == thirdVal.asArray()); +} + +KJ_TEST("WebSocket Compression String Parsing (populateUnverifiedConfig)") { + // First we'll cover cases where the `UnverifiedConfig` is successfully constructed, + // which indicates the offer was structured in a parseable way. Next, we'll cover cases where the + // offer is structured incorrectly. + const auto cleanParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=10"_kj; + auto parts = _::splitParts(cleanParameters, ';'); + auto keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + + auto unverified = _::populateUnverifiedConfig(keysMaybeValues); + auto config = KJ_ASSERT_NONNULL(unverified); + KJ_ASSERT(config.clientNoContextTakeover == true); + KJ_ASSERT(config.serverNoContextTakeover == false); + + auto clientBits = KJ_ASSERT_NONNULL(config.clientMaxWindowBits); + KJ_ASSERT(clientBits == ""_kj); + auto serverBits = KJ_ASSERT_NONNULL(config.serverMaxWindowBits); + KJ_ASSERT(serverBits == "10"_kj); + // Valid config can be populated succesfully. + + const auto weirdButValidParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=this_should_be_a_number"_kj; + parts = _::splitParts(weirdButValidParameters, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + + unverified = _::populateUnverifiedConfig(keysMaybeValues); + config = KJ_ASSERT_NONNULL(unverified); + KJ_ASSERT(config.clientNoContextTakeover == true); + KJ_ASSERT(config.serverNoContextTakeover == false); + + clientBits = KJ_ASSERT_NONNULL(config.clientMaxWindowBits); + KJ_ASSERT(clientBits == ""_kj); + serverBits = KJ_ASSERT_NONNULL(config.serverMaxWindowBits); + KJ_ASSERT(serverBits == "this_should_be_a_number"_kj); + // Note that while the value associated with `server_max_window_bits` is not a number, + // `populateUnverifiedConfig` succeeds because the parameter[=value] is generally structured + // correctly. + + // --- HANDLE INCORRECTLY STRUCTURED OFFERS --- + auto invalidKey = "somethingKey; client_max_window_bits;"_kj; + parts = _::splitParts(invalidKey, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to invalid key name + + auto invalidKeyTwo = "client_max_window_bitsJUNK; server_no_context_takeover"_kj; + parts = _::splitParts(invalidKeyTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to invalid key name (invalid characters after valid parameter name). + + auto repeatedKey = "client_no_context_takeover; client_no_context_takeover"_kj; + parts = _::splitParts(repeatedKey, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to repeated key name. + + auto unexpectedValue = "client_no_context_takeover="_kj; + parts = _::splitParts(unexpectedValue, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to value in `x_no_context_takeover` parameter (unexpected value). + + auto unexpectedValueTwo = "client_no_context_takeover= "_kj; + parts = _::splitParts(unexpectedValueTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to value in `x_no_context_takeover` parameter. + + auto emptyValue = "client_max_window_bits="_kj; + parts = _::splitParts(emptyValue, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to empty value in `x_max_window_bits` parameter. + // "Empty" in this case means an "=" was provided, but no subsequent value was provided. + + auto emptyValueTwo = "client_max_window_bits= "_kj; + parts = _::splitParts(emptyValueTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + KJ_ASSERT(_::populateUnverifiedConfig(keysMaybeValues) == nullptr); + // Fail to populate due to empty value in `x_max_window_bits` parameter. + // "Empty" in this case means an "=" was provided, but no subsequent value was provided. +} + +KJ_TEST("WebSocket Compression String Parsing (validateCompressionConfig)") { + // We've tested `toKeysAndVals()` and `populateUnverifiedConfig()`, so we only need to test + // correctly structured offers/agreements here. + const auto cleanParameters = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=10"_kj; + auto parts = _::splitParts(cleanParameters, ';'); + auto keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + auto maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + auto unverified = KJ_ASSERT_NONNULL(maybeUnverified); + auto maybeValid = _::validateCompressionConfig(kj::mv(unverified), false); // Validate as Server. + auto valid = KJ_ASSERT_NONNULL(maybeValid); + KJ_ASSERT(valid.inboundNoContextTakeover == true); + KJ_ASSERT(valid.outboundNoContextTakeover == false); + auto inboundBits = KJ_ASSERT_NONNULL(valid.inboundMaxWindowBits); + KJ_ASSERT(inboundBits == 15); // `client_max_window_bits` can be empty in an offer. + auto outboundBits = KJ_ASSERT_NONNULL(valid.outboundMaxWindowBits); + KJ_ASSERT(outboundBits == 10); + // Valid config successfully constructed. + + const auto correctStructureButInvalid = "client_no_context_takeover; client_max_window_bits; " + "server_max_window_bits=this_should_be_a_number"_kj; + parts = _::splitParts(correctStructureButInvalid, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + unverified = KJ_ASSERT_NONNULL(maybeUnverified); + maybeValid = _::validateCompressionConfig(kj::mv(unverified), false); // Validate as Server. + KJ_ASSERT(maybeValid == nullptr); + // The config "looks" correct, but the `server_max_window_bits` parameter has an invalid value. + + const auto invalidRange = "client_max_window_bits; server_max_window_bits=18;"_kj; + // `server_max_window_bits` is out of range, decline. + parts = _::splitParts(invalidRange, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), false); + KJ_ASSERT(maybeValid == nullptr); + + const auto invalidRangeTwo = "client_max_window_bits=4"_kj; + // `server_max_window_bits` is out of range, decline. + parts = _::splitParts(invalidRangeTwo, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), false); + KJ_ASSERT(maybeValid == nullptr); + + const auto invalidRequest = "server_max_window_bits"_kj; + // `sever_max_window_bits` must have a value in a request AND a response. + parts = _::splitParts(invalidRequest, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), false); + KJ_ASSERT(maybeValid == nullptr); + + const auto invalidResponse = "client_max_window_bits"_kj; + // `client_max_window_bits` must have a value in a response. + parts = _::splitParts(invalidResponse, ';'); + keysMaybeValues = _::toKeysAndVals(parts.asPtr()); + maybeUnverified = _::populateUnverifiedConfig(keysMaybeValues); + maybeValid = _::validateCompressionConfig(kj::mv(KJ_REQUIRE_NONNULL(maybeUnverified)), true); + KJ_ASSERT(maybeValid == nullptr); +} + +KJ_TEST("WebSocket Compression String Parsing (findValidExtensionOffers)") { + // Test that we can extract only the valid extensions from a string of offers. + constexpr auto extensions = "permessage-deflate; " // Valid offer. + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Another valid offer. + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-invalid; " // Invalid ext name. + "client_no_context_takeover, " + "permessage-deflate; " // Invalid parmeter. + "invalid_parameter; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Invalid parmeter value. + "server_max_window_bits=should_be_a_number, " + "permessage-deflate; " // Unexpected parmeter value. + "client_max_window_bits=true, " + "permessage-deflate; " // Missing expected parmeter value. + "server_max_window_bits, " + "permessage-deflate; " // Invalid parameter value (too high). + "client_max_window_bits=99, " + "permessage-deflate; " // Invalid parameter value (too low). + "client_max_window_bits=4, " + "permessage-deflate; " // Invalid parameter (repeated). + "client_max_window_bits; " + "client_max_window_bits, " + "permessage-deflate"_kj; // Valid offer (no parameters). + + auto validOffers = _::findValidExtensionOffers(extensions); + KJ_ASSERT(validOffers.size() == 3); + KJ_ASSERT(validOffers[0].outboundNoContextTakeover == true); + KJ_ASSERT(validOffers[0].inboundNoContextTakeover == false); + KJ_ASSERT(validOffers[0].outboundMaxWindowBits == 15); + KJ_ASSERT(validOffers[0].inboundMaxWindowBits == 10); + + KJ_ASSERT(validOffers[1].outboundNoContextTakeover == true); + KJ_ASSERT(validOffers[1].inboundNoContextTakeover == false); + KJ_ASSERT(validOffers[1].outboundMaxWindowBits == 15); + KJ_ASSERT(validOffers[1].inboundMaxWindowBits == nullptr); + + KJ_ASSERT(validOffers[2].outboundNoContextTakeover == false); + KJ_ASSERT(validOffers[2].inboundNoContextTakeover == false); + KJ_ASSERT(validOffers[2].outboundMaxWindowBits == nullptr); + KJ_ASSERT(validOffers[2].inboundMaxWindowBits == nullptr); +} + +KJ_TEST("WebSocket Compression String Parsing (generateExtensionRequest)") { + // Test that we can extract only the valid extensions from a string of offers. + constexpr auto extensions = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits=10; " + "client_max_window_bits, " + "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-deflate"_kj; + constexpr auto EXPECTED = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits=15; " + "server_max_window_bits=10, " + "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits=15, " + "permessage-deflate"_kj; + auto validOffers = _::findValidExtensionOffers(extensions); + auto extensionRequest = _::generateExtensionRequest(validOffers); + KJ_ASSERT(extensionRequest == EXPECTED); +} + +KJ_TEST("WebSocket Compression String Parsing (tryParseExtensionOffers)") { + // Test that we can accept a valid offer from string of offers. + constexpr auto extensions = "permessage-invalid; " // Invalid ext name. + "client_no_context_takeover, " + "permessage-deflate; " // Invalid parmeter. + "invalid_parameter; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Invalid parmeter value. + "server_max_window_bits=should_be_a_number, " + "permessage-deflate; " // Unexpected parmeter value. + "client_max_window_bits=true, " + "permessage-deflate; " // Missing expected parmeter value. + "server_max_window_bits, " + "permessage-deflate; " // Invalid parameter value (too high). + "client_max_window_bits=99, " + "permessage-deflate; " // Invalid parameter value (too low). + "client_max_window_bits=4, " + "permessage-deflate; " // Invalid parameter (repeated). + "client_max_window_bits; " + "client_max_window_bits, " + "permessage-deflate; " // Valid offer. + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=10, " + "permessage-deflate; " // Another valid offer. + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-deflate"_kj; // Valid offer (no parameters). + + auto maybeAccepted = _::tryParseExtensionOffers(extensions); + auto accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == true); + KJ_ASSERT(accepted.outboundMaxWindowBits == 10); + KJ_ASSERT(accepted.inboundMaxWindowBits == 15); + + // Try the second valid offer from the big list above. + auto offerTwo = "permessage-deflate; client_no_context_takeover; client_max_window_bits"_kj; + maybeAccepted = _::tryParseExtensionOffers(offerTwo); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == true); + KJ_ASSERT(accepted.outboundMaxWindowBits == nullptr); + KJ_ASSERT(accepted.inboundMaxWindowBits == 15); + + auto offerThree = "permessage-deflate"_kj; // The third valid offer. + maybeAccepted = _::tryParseExtensionOffers(offerThree); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == nullptr); + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + + auto invalid = "invalid"_kj; // Any of the invalid offers we saw above would return NULL. + maybeAccepted = _::tryParseExtensionOffers(invalid); + KJ_ASSERT(maybeAccepted == nullptr); +} + +KJ_TEST("WebSocket Compression String Parsing (tryParseAllExtensionOffers)") { + // We want to test the following: + // 1. We reject all if we don't find an offer we can accept. + // 2. We accept one after iterating over offers that we have to reject. + // 3. We accept an offer with a `server_max_window_bits` parameter if the manual config allows + // it, and choose the smaller "number of bits" (from clients request). + // 4. We accept an offer with a `server_no_context_takeover` parameter if the manual config + // allows it, and choose the smaller "number of bits" (from manual config) from + // `server_max_window_bits`. + constexpr auto serverOnly = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14"_kj; + + constexpr auto acceptLast = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_no_context_takeover, " + "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits = 14, " + "permessage-deflate; " // accept this + "client_no_context_takeover"_kj; + + const auto defaultConfig = CompressionParameters(); + // Our default config is equivalent to `permessage-deflate` with no parameters. + + auto maybeAccepted = _::tryParseAllExtensionOffers(serverOnly, defaultConfig); + KJ_ASSERT(maybeAccepted == nullptr); + // Asserts that we rejected all the offers with `server_x` parameters. + + maybeAccepted = _::tryParseAllExtensionOffers(acceptLast, defaultConfig); + auto accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == nullptr); + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + // Asserts that we accepted the only offer that did not have a `server_x` parameter. + + const auto allowServerBits = CompressionParameters { + false, + false, + 15, // server_max_window_bits = 15 + nullptr + }; + maybeAccepted = _::tryParseAllExtensionOffers(serverOnly, allowServerBits); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == false); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == 14); // Note that we chose the lower of (14, 15). + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + // Asserts that we accepted an offer that allowed for `server_max_window_bits` AND we chose the + // lower number of bits (in this case, the clients offer of 14). + + const auto allowServerTakeoverAndBits = CompressionParameters { + true, // server_no_context_takeover = true + false, + 13, // server_max_window_bits = 13 + nullptr + }; + + maybeAccepted = _::tryParseAllExtensionOffers(serverOnly, allowServerTakeoverAndBits); + accepted = KJ_ASSERT_NONNULL(maybeAccepted); + KJ_ASSERT(accepted.outboundNoContextTakeover == true); + KJ_ASSERT(accepted.inboundNoContextTakeover == false); + KJ_ASSERT(accepted.outboundMaxWindowBits == 13); // Note that we chose the lower of (14, 15). + KJ_ASSERT(accepted.inboundMaxWindowBits == nullptr); + // Asserts that we accepted an offer that allowed for `server_no_context_takeover` AND we chose + // the lower number of bits (in this case, the manual config's choice of 13). +} + +KJ_TEST("WebSocket Compression String Parsing (generateExtensionResponse)") { + // Test that we can extract only the valid extensions from a string of offers. + constexpr auto extensions = "permessage-deflate; " + "client_no_context_takeover; " + "server_max_window_bits=10; " + "client_max_window_bits, " + "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits, " + "permessage-deflate"_kj; + constexpr auto EXPECTED = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits=15; " + "server_max_window_bits=10"_kj; + auto accepted = _::tryParseExtensionOffers(extensions); + auto extensionResponse = _::generateExtensionResponse(KJ_ASSERT_NONNULL(accepted)); + KJ_ASSERT(extensionResponse == EXPECTED); +} + +KJ_TEST("WebSocket Compression String Parsing (tryParseExtensionAgreement)") { + constexpr auto didNotOffer = "Server failed WebSocket handshake: " + "added Sec-WebSocket-Extensions when client did not offer any."_kj; + constexpr auto tooMany = "Server failed WebSocket handshake: " + "expected exactly one extension (permessage-deflate) but received more than one."_kj; + constexpr auto badExt = "Server failed WebSocket handshake: " + "response included a Sec-WebSocket-Extensions value that was not permessage-deflate."_kj; + constexpr auto badVal = "Server failed WebSocket handshake: " + "the Sec-WebSocket-Extensions header in the Response included an invalid value."_kj; + + constexpr auto tooManyExtensions = "permessage-deflate; client_no_context_takeover; " + "client_max_window_bits; server_max_window_bits=10, " + "permessage-deflate; client_no_context_takeover; client_max_window_bits;"_kj; + + auto maybeAccepted = _::tryParseExtensionAgreement(nullptr, tooManyExtensions); + KJ_ASSERT( + KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == didNotOffer); + + Maybe defaultConfig = CompressionParameters{}; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, tooManyExtensions); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == tooMany); + + constexpr auto invalidExt = "permessage-invalid; " + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=10;"; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, invalidExt); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == badExt); + + constexpr auto invalidVal = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits; " + "server_max_window_bits=100;"; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, invalidVal); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == badVal); + + constexpr auto missingVal = "permessage-deflate; " + "client_no_context_takeover; " + "client_max_window_bits; " // This must have a value in a Response! + "server_max_window_bits=10;"; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, missingVal); + KJ_ASSERT(KJ_ASSERT_NONNULL(maybeAccepted.tryGet()).getDescription() == badVal); + + constexpr auto valid = "permessage-deflate; client_no_context_takeover; " + "client_max_window_bits=15; server_max_window_bits=10"_kj; + maybeAccepted = _::tryParseExtensionAgreement(defaultConfig, valid); + auto config = KJ_ASSERT_NONNULL(maybeAccepted.tryGet()); + KJ_ASSERT(config.outboundNoContextTakeover == true); + KJ_ASSERT(config.inboundNoContextTakeover == false); + KJ_ASSERT(config.outboundMaxWindowBits == 15); + KJ_ASSERT(config.inboundMaxWindowBits == 10); + + auto client = CompressionParameters{ true, false, 15, 10 }; + // If the server ignores our `client_no_context_takeover` parameter, we (the client) still use it. + constexpr auto serverIgnores = "permessage-deflate; client_max_window_bits=15; " + "server_max_window_bits=10"_kj; + maybeAccepted = _::tryParseExtensionAgreement(client, serverIgnores); + config = KJ_ASSERT_NONNULL(maybeAccepted.tryGet()); + KJ_ASSERT(config.outboundNoContextTakeover == true); // Note that this is missing in the response. + KJ_ASSERT(config.inboundNoContextTakeover == false); + KJ_ASSERT(config.outboundMaxWindowBits == 15); + KJ_ASSERT(config.inboundMaxWindowBits == 10); +} + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Empty Message Compression") { + // We'll try to send and receive "Hello", then "", followed by "Hello" again. + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_EMPTY_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_EMPTY_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE_REUSE_CTX); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = "permessage-deflate; server_no_context_takeover"_kj; + testWebSocketEmptyMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Default Compression") { + // We'll try to send and receive "Hello" twice. The second time we receive "Hello", the compressed + // message will be smaller as a result of the client reusing the lookback window. + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE_REUSE_CTX); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = "permessage-deflate; server_no_context_takeover"_kj; + testWebSocketTwoMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Extract Extensions") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_COMPRESSION_RESPONSE_HANDSHAKE)); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = "permessage-deflate; server_no_context_takeover"_kj; + testWebSocketOptimizePumpProxy(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Compression (Client Discards Compression Context)") { + // We'll try to send and receive "Hello" twice. The second time we receive "Hello", the compressed + // message will be the same size as the first time, since the client discards the lookback window. + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], + asBytes(WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_FIRST_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = + "permessage-deflate; client_no_context_takeover; server_no_context_takeover"_kj; + testWebSocketTwoMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +#if KJ_HAS_ZLIB +KJ_TEST("HttpClient WebSocket Compression (Different DEFLATE blocks)") { + // In this test, we'll try to use the following DEFLATE blocks: + // - Two DEFLATE blocks in 1 message. + // - A block with no compression. + // - A block with BFINAL set to 1. + // Then, we'll try to send a normal compressed message following the BFINAL message to ensure we + // can still process messages after receiving BFINAL. + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], + asBytes(WEBSOCKET_COMPRESSION_CLIENT_DISCARDS_CTX_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_TWO_DEFLATE_BLOCKS_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_DEFLATE_NO_COMPRESSION_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_BFINAL_SET_MESSAGE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_SEND_COMPRESSED_MESSAGE); }) + .then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId extHeader = tableBuilder.add("Sec-WebSocket-Extensions"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.webSocketCompressionMode = HttpClientSettings::MANUAL_COMPRESSION; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + constexpr auto extensions = + "permessage-deflate; client_no_context_takeover; server_no_context_takeover"_kj; + testWebSocketFourMessageCompression(waitScope, *headerTable, extHeader, extensions, *client); + + serverTask.wait(waitScope); +} +#endif // KJ_HAS_ZLIB + +KJ_TEST("HttpClient WebSocket error") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + + auto serverTask = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)); }) + .then([&]() { return expectRead(*pipe.ends[1], request); }) + .then([&]() { return writeA(*pipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + + auto client = newHttpClient(*headerTable, *pipe.ends[0], clientSettings); + + kj::HttpHeaders headers(*headerTable); + headers.set(hMyHeader, "foo"); + + { + auto response = client->openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 404); + KJ_EXPECT(response.statusText == "Not Found", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); + KJ_ASSERT(response.webSocketOrBody.is>()); + } + + { + auto response = client->openWebSocket("/websocket", headers).wait(waitScope); + + KJ_EXPECT(response.statusCode == 404); + KJ_EXPECT(response.statusText == "Not Found", response.statusText); + KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo"); + KJ_ASSERT(response.webSocketOrBody.is>()); + } + + serverTask.wait(waitScope); +} + +KJ_TEST("HttpServer WebSocket handshake") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + HttpServer server(timer, *headerTable, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + + expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); + writeA(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(waitScope); + writeA(*pipe.ends[1], WEBSOCKET_SEND_CLOSE).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(waitScope); + + listenTask.wait(waitScope); +} + +KJ_TEST("HttpServer WebSocket handshake error") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + HttpServer server(timer, *headerTable, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto request = kj::str("GET /return-error", WEBSOCKET_REQUEST_HANDSHAKE); + writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(waitScope); + + // Can send more requests! + writeA(*pipe.ends[1], request.asBytes()).wait(waitScope); + expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(waitScope); + + pipe.ends[1]->shutdownWrite(); + + listenTask.wait(waitScope); +} + +void testBadWebSocketHandshake( + WaitScope& waitScope, Timer& timer, StringPtr request, StringPtr response, TwoWayPipe pipe) { + // Write an invalid WebSocket GET request, and expect a particular error response. + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + + class ErrorHandler final: public HttpServerErrorHandler { + Promise handleApplicationError( + Exception exception, Maybe response) override { + // When I first wrote this, I expected this function to be called, because + // `TestWebSocketService::request()` definitely throws. However, the exception it throws comes + // from `HttpService::Response::acceptWebSocket()`, which stores the fact which it threw a + // WebSocket error. This prevents the HttpServer's listen loop from propagating the exception + // to our HttpServerErrorHandler (i.e., this function), because it assumes the exception is + // related to the WebSocket error response. See `HttpServer::Connection::startLoop()` for + // details. + bool responseWasSent = response == nullptr; + KJ_FAIL_EXPECT("Unexpected application error", responseWasSent, exception); + return READY_NOW; + } + }; + + ErrorHandler errorHandler; + + HttpServerSettings serverSettings; + serverSettings.errorHandler = errorHandler; + + HttpServer server(timer, *headerTable, service, serverSettings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], response).wait(waitScope); + + listenTask.wait(waitScope); +} + +KJ_TEST("HttpServer WebSocket handshake with unsupported Sec-WebSocket-Version") { + static constexpr auto REQUEST = + "GET /websocket HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 1\r\n" + "My-Header: foo\r\n" + "\r\n"_kj; + + static constexpr auto RESPONSE = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 56\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The requested WebSocket version is not supported."_kj; + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + testBadWebSocketHandshake(waitScope, timer, REQUEST, RESPONSE, KJ_HTTP_TEST_CREATE_2PIPE); +} + +KJ_TEST("HttpServer WebSocket handshake with missing Sec-WebSocket-Key") { + static constexpr auto REQUEST = + "GET /websocket HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "My-Header: foo\r\n" + "\r\n"_kj; + + static constexpr auto RESPONSE = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 32\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Missing Sec-WebSocket-Key"_kj; + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + testBadWebSocketHandshake(waitScope, timer, REQUEST, RESPONSE, KJ_HTTP_TEST_CREATE_2PIPE); +} + +KJ_TEST("HttpServer WebSocket with application error after accept") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + class WebSocketApplicationErrorService: public HttpService, public HttpServerErrorHandler { + // Accepts a WebSocket, receives a message, and throws an exception (application error). + + public: + Promise request( + HttpMethod method, kj::StringPtr, const HttpHeaders&, + AsyncInputStream&, Response& response) override { + KJ_ASSERT(method == HttpMethod::GET); + HttpHeaderTable headerTable; + HttpHeaders responseHeaders(headerTable); + auto webSocket = response.acceptWebSocket(responseHeaders); + return webSocket->receive().then([](WebSocket::Message) { + throwRecoverableException(KJ_EXCEPTION(FAILED, "test exception")); + }).attach(kj::mv(webSocket)); + } + + Promise handleApplicationError(Exception exception, Maybe response) override { + // We accepted the WebSocket, so the response was already sent. At one time, we _did_ expose a + // useless Response reference here, so this is a regression test. + bool responseWasSent = response == nullptr; + KJ_EXPECT(responseWasSent); + KJ_EXPECT(exception.getDescription() == "test exception"_kj); + return READY_NOW; + } + }; + + // Set up the HTTP service. + + WebSocketApplicationErrorService service; + + HttpServerSettings serverSettings; + serverSettings.errorHandler = service; + + HttpHeaderTable headerTable; + HttpServer server(timer, headerTable, service, serverSettings); + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Make a client and open a WebSocket to the service. + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto client = newHttpClient( + headerTable, *pipe.ends[1], clientSettings); + + HttpHeaders headers(headerTable); + auto webSocketResponse = client->openWebSocket("/websocket"_kj, headers) + .wait(waitScope); + + KJ_ASSERT(webSocketResponse.statusCode == 101); + auto webSocket = kj::mv(KJ_ASSERT_NONNULL(webSocketResponse.webSocketOrBody.tryGet>())); + + webSocket->send("ignored"_kj).wait(waitScope); + + listenTask.wait(waitScope); +} + +// ----------------------------------------------------------------------------- + +KJ_TEST("HttpServer request timeout") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + TestHttpService service(PIPELINE_TESTS, table); + HttpServerSettings settings; + settings.headerTimeout = 1 * kj::MILLISECONDS; + HttpServer server(timer, table, service, settings); + + // Shouldn't hang! Should time out. + auto promise = server.listenHttp(kj::mv(pipe.ends[0])); + KJ_EXPECT(!promise.poll(waitScope)); + timer.advanceTo(timer.now() + settings.headerTimeout / 2); + KJ_EXPECT(!promise.poll(waitScope)); + timer.advanceTo(timer.now() + settings.headerTimeout); + promise.wait(waitScope); + + // Closes the connection without sending anything. + KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); +} + +KJ_TEST("HttpServer pipeline timeout") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + TestHttpService service(PIPELINE_TESTS, table); + HttpServerSettings settings; + settings.pipelineTimeout = 1 * kj::MILLISECONDS; + HttpServer server(timer, table, service, settings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + expectRead(*pipe.ends[1], PIPELINE_TESTS[0].response.raw).wait(waitScope); + + // Listen task should time out even though we didn't shutdown the socket. + KJ_EXPECT(!listenTask.poll(waitScope)); + timer.advanceTo(timer.now() + settings.pipelineTimeout / 2); + KJ_EXPECT(!listenTask.poll(waitScope)); + timer.advanceTo(timer.now() + settings.pipelineTimeout); + listenTask.wait(waitScope); + + // In this case, no data is sent back. + KJ_EXPECT(pipe.ends[1]->readAllText().wait(waitScope) == ""); +} + +class BrokenHttpService final: public HttpService { + // HttpService that doesn't send a response. +public: + BrokenHttpService() = default; + explicit BrokenHttpService(kj::Exception&& exception): exception(kj::mv(exception)) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + return requestBody.readAllBytes().then([this](kj::Array&&) -> kj::Promise { + KJ_IF_MAYBE(e, exception) { + return kj::cp(*e); + } else { + return kj::READY_NOW; + } + }); + } + +private: + kj::Maybe exception; +}; + +KJ_TEST("HttpServer no response") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 51\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The HttpService did not generate a response.", text); +} + +KJ_TEST("HttpServer disconnected") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected")); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == "", text); +} + +KJ_TEST("HttpServer overloaded") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded")); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text.startsWith("HTTP/1.1 503 Service Unavailable"), text); +} + +KJ_TEST("HttpServer unimplemented") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented")); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text.startsWith("HTTP/1.1 501 Not Implemented"), text); +} + +KJ_TEST("HttpServer threw exception") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text.startsWith("HTTP/1.1 500 Internal Server Error"), text); +} + +KJ_TEST("HttpServer bad request") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + static constexpr auto request = "GET / HTTP/1.1\r\nbad request\r\n\r\n"_kj; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + auto response = pipe.ends[1]->readAllText().wait(waitScope); + KJ_EXPECT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); + + static constexpr auto expectedResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 53\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The headers sent by your client are not valid."_kj; + + KJ_EXPECT(expectedResponse == response, expectedResponse, response); +} + +KJ_TEST("HttpServer invalid method") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + static constexpr auto request = "bad request\r\n\r\n"_kj; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + auto response = pipe.ends[1]->readAllText().wait(waitScope); + KJ_EXPECT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); + + static constexpr auto expectedResponse = + "HTTP/1.1 501 Not Implemented\r\n" + "Connection: close\r\n" + "Content-Length: 35\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Unrecognized request method."_kj; + + KJ_EXPECT(expectedResponse == response, expectedResponse, response); +} + +// Ensure that HttpServerSettings can continue to be constexpr. +KJ_UNUSED static constexpr HttpServerSettings STATIC_CONSTEXPR_SETTINGS {}; + +class TestErrorHandler: public HttpServerErrorHandler { +public: + kj::Promise handleClientProtocolError( + HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) override { + // In a real error handler, you should redact `protocolError.rawContent`. + auto message = kj::str("Saw protocol error: ", protocolError.description, "; rawContent = ", + encodeCEscape(protocolError.rawContent)); + return sendError(400, "Bad Request", kj::mv(message), response); + } + + kj::Promise handleApplicationError( + kj::Exception exception, kj::Maybe response) override { + return sendError(500, "Internal Server Error", + kj::str("Saw application error: ", exception.getDescription()), response); + } + + kj::Promise handleNoResponse(kj::HttpService::Response& response) override { + return sendError(500, "Internal Server Error", kj::str("Saw no response."), response); + } + + static TestErrorHandler instance; + +private: + kj::Promise sendError(uint statusCode, kj::StringPtr statusText, String message, + Maybe response) { + KJ_IF_MAYBE(r, response) { + HttpHeaderTable headerTable; + HttpHeaders headers(headerTable); + auto body = r->send(statusCode, statusText, headers, message.size()); + return body->write(message.begin(), message.size()).attach(kj::mv(body), kj::mv(message)); + } else { + KJ_LOG(ERROR, "Saw an error but too late to report to client."); + return kj::READY_NOW; + } + } +}; + +TestErrorHandler TestErrorHandler::instance {}; + +KJ_TEST("HttpServer no response, custom error handler") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpServerSettings settings {}; + settings.errorHandler = TestErrorHandler::instance; + + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service, settings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 16\r\n" + "\r\n" + "Saw no response.", text); +} + +KJ_TEST("HttpServer threw exception, custom error handler") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpServerSettings settings {}; + settings.errorHandler = TestErrorHandler::instance; + + HttpHeaderTable table; + BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); + HttpServer server(timer, table, service, settings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 29\r\n" + "\r\n" + "Saw application error: failed", text); +} + +KJ_TEST("HttpServer bad request, custom error handler") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpServerSettings settings {}; + settings.errorHandler = TestErrorHandler::instance; + + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service, settings); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + static constexpr auto request = "bad request\r\n\r\n"_kj; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + auto response = pipe.ends[1]->readAllText().wait(waitScope); + KJ_EXPECT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); + + static constexpr auto expectedResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 80\r\n" + "\r\n" + "Saw protocol error: Unrecognized request method.; " + "rawContent = bad request\\000\\n"_kj; + + KJ_EXPECT(expectedResponse == response, expectedResponse, response); +} + +class PartialResponseService final: public HttpService { + // HttpService that sends a partial response then throws. +public: + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + return requestBody.readAllBytes() + .then([this,&response](kj::Array&&) -> kj::Promise { + HttpHeaders headers(table); + auto body = response.send(200, "OK", headers, 32); + auto promise = body->write("foo", 3); + return promise.attach(kj::mv(body)).then([]() -> kj::Promise { + return KJ_EXCEPTION(FAILED, "failed"); + }); + }); + } + +private: + kj::Maybe exception; + HttpHeaderTable table; +}; + +KJ_TEST("HttpServer threw exception after starting response") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + PartialResponseService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + KJ_EXPECT_LOG(ERROR, "HttpService threw exception after generating a partial response"); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 200 OK\r\n" + "Content-Length: 32\r\n" + "\r\n" + "foo", text); +} + +class PartialResponseNoThrowService final: public HttpService { + // HttpService that sends a partial response then returns without throwing. +public: + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + return requestBody.readAllBytes() + .then([this,&response](kj::Array&&) -> kj::Promise { + HttpHeaders headers(table); + auto body = response.send(200, "OK", headers, 32); + auto promise = body->write("foo", 3); + return promise.attach(kj::mv(body)); + }); + } + +private: + kj::Maybe exception; + HttpHeaderTable table; +}; + +KJ_TEST("HttpServer failed to write complete response but didn't throw") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + PartialResponseNoThrowService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 200 OK\r\n" + "Content-Length: 32\r\n" + "\r\n" + "foo", text); +} + +class SimpleInputStream final: public kj::AsyncInputStream { + // An InputStream that returns bytes out of a static string. + +public: + SimpleInputStream(kj::StringPtr text) + : unread(text.asBytes()) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + size_t amount = kj::min(maxBytes, unread.size()); + memcpy(buffer, unread.begin(), amount); + unread = unread.slice(amount, unread.size()); + return amount; + } + +private: + kj::ArrayPtr unread; +}; + +class PumpResponseService final: public HttpService { + // HttpService that uses pumpTo() to write a response, without carefully specifying how much to + // pump, but the stream happens to be the right size. +public: + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + return requestBody.readAllBytes() + .then([this,&response](kj::Array&&) -> kj::Promise { + HttpHeaders headers(table); + kj::StringPtr text = "Hello, World!"; + auto body = response.send(200, "OK", headers, text.size()); + + auto stream = kj::heap(text); + auto promise = stream->pumpTo(*body); + return promise.attach(kj::mv(body), kj::mv(stream)) + .then([text](uint64_t amount) { + KJ_EXPECT(amount == text.size()); + }); + }); + } + +private: + kj::Maybe exception; + HttpHeaderTable table; +}; + +KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + PumpResponseService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Do one request. + pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) + .wait(waitScope); + pipe.ends[1]->shutdownWrite(); + auto text = pipe.ends[1]->readAllText().wait(waitScope); + + KJ_EXPECT(text == + "HTTP/1.1 200 OK\r\n" + "Content-Length: 13\r\n" + "\r\n" + "Hello, World!", text); +} + +class HangingHttpService final: public HttpService { + // HttpService that hangs forever. +public: + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + kj::Promise result = kj::NEVER_DONE; + ++inFlight; + return result.attach(kj::defer([this]() { + if (--inFlight == 0) { + KJ_IF_MAYBE(f, onCancelFulfiller) { + f->get()->fulfill(); + } + } + })); + } + + kj::Promise onCancel() { + auto paf = kj::newPromiseAndFulfiller(); + onCancelFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + uint inFlight = 0; + +private: + kj::Maybe exception; + kj::Maybe>> onCancelFulfiller; +}; + +KJ_TEST("HttpServer cancels request when client disconnects") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + HangingHttpService service; + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + KJ_EXPECT(service.inFlight == 0); + + static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj; + pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); + + auto cancelPromise = service.onCancel(); + KJ_EXPECT(!cancelPromise.poll(waitScope)); + KJ_EXPECT(service.inFlight == 1); + + // Disconnect client and verify server cancels. + pipe.ends[1] = nullptr; + KJ_ASSERT(cancelPromise.poll(waitScope)); + KJ_EXPECT(service.inFlight == 0); + cancelPromise.wait(waitScope); +} + +class SuspendAfter: private HttpService { + // A SuspendableHttpServiceFactory which responds to the first `n` requests with 200 OK, then + // suspends all subsequent requests until its counter is reset. + +public: + void suspendAfter(uint countdownParam) { countdown = countdownParam; } + + kj::Maybe> operator()(HttpServer::SuspendableRequest& sr) { + if (countdown == 0) { + suspendedRequest = sr.suspend(); + return nullptr; + } + --countdown; + return kj::Own(static_cast(this), kj::NullDisposer::instance); + } + + kj::Maybe getSuspended() { + KJ_DEFER(suspendedRequest = nullptr); + return kj::mv(suspendedRequest); + } + +private: + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + HttpHeaders responseHeaders(table); + response.send(200, "OK", responseHeaders); + return requestBody.readAllBytes().ignoreResult(); + } + + HttpHeaderTable table; + + uint countdown = kj::maxValue; + kj::Maybe suspendedRequest; +}; + +KJ_TEST("HttpServer can suspend a request") { + // This test sends a single request to an HttpServer three times. First it writes the request to + // its pipe and arranges for the HttpServer to suspend the request. Then it resumes the suspended + // request and arranges for this resumption to be suspended as well. Then it resumes once more and + // arranges for the request to be completed. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + kj::Maybe suspendedRequest; + + SuspendAfter factory; + + { + // Observe the HttpServer suspend. + + factory.suspendAfter(0); + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "foobar\r\n" + "0\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + // The listen promise is fulfilled with false. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + + // And we have a SuspendedRequest. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + // Observe the HttpServer suspend again without reading from the connection. + + factory.suspendAfter(0); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + // The listen promise is again fulfilled with false. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + + // We again have a suspendedRequest. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + // The SuspendedRequest is completed. + + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); + + // We need to read the response for the HttpServer to drain. + auto readPromise = pipe.ends[1]->readAllText(); + + // This time, the server drained cleanly. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + + drainPromise.wait(waitScope); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + + auto response = readPromise.wait(waitScope); + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(RESPONSE == response); + } +} + +KJ_TEST("HttpServer can suspend and resume pipelined requests") { + // This test sends multiple requests with both Content-Length and Transfer-Encoding: chunked + // bodies, and verifies that suspending both kinds does not corrupt the stream. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + // We'll suspend the second request. + kj::Maybe suspendedRequest; + SuspendAfter factory; + + static constexpr kj::StringPtr LENGTHFUL_REQUEST = + "POST / HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "\r\n" + "foobar"_kj; + static constexpr kj::StringPtr CHUNKED_REQUEST = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "foobar\r\n" + "0\r\n" + "\r\n"_kj; + + // Set up several requests; we'll suspend and transfer the second and third one. + auto writePromise = pipe.ends[1]->write(LENGTHFUL_REQUEST.begin(), LENGTHFUL_REQUEST.size()) + .then([&]() { + return pipe.ends[1]->write(CHUNKED_REQUEST.begin(), CHUNKED_REQUEST.size()); + }).then([&]() { + return pipe.ends[1]->write(LENGTHFUL_REQUEST.begin(), LENGTHFUL_REQUEST.size()); + }).then([&]() { + return pipe.ends[1]->write(CHUNKED_REQUEST.begin(), CHUNKED_REQUEST.size()); + }); + + auto readPromise = pipe.ends[1]->readAllText(); + + { + // Observe the HttpServer suspend the second request. + + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + // Let's resume one request and suspend the next pipelined request. + + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + // Resume again and run to completion. + + factory.suspendAfter(kj::maxValue); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); + + // This time, the server drained cleanly. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + // No suspended request this time. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest == nullptr); + + drainPromise.wait(waitScope); + } + + writePromise.wait(waitScope); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + + auto responses = readPromise.wait(waitScope); + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(kj::str(kj::delimited(kj::repeat(RESPONSE, 4), "")) == responses); +} + +KJ_TEST("HttpServer can suspend a request with no leftover") { + // This test verifies that if the request loop's read perfectly ends at the end of message + // headers, leaving no leftover section, we can still successfully suspend and resume. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + kj::Maybe suspendedRequest; + + SuspendAfter factory; + + { + factory.suspendAfter(0); + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + static constexpr kj::StringPtr REQUEST_HEADERS = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST_HEADERS.begin(), REQUEST_HEADERS.size()).wait(waitScope); + + // The listen promise is fulfilled with false. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(!listenPromise.wait(waitScope)); + + // And we have a SuspendedRequest. We know that it has no leftover, because we only wrote + // headers, no body yet. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest != nullptr); + } + + { + factory.suspendAfter(1); + auto listenPromise = server.listenHttpCleanDrain( + *pipe.ends[0], factory, kj::mv(suspendedRequest)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); + + // We need to read the response for the HttpServer to drain. + auto readPromise = pipe.ends[1]->readAllText(); + + static constexpr kj::StringPtr REQUEST_BODY = + "6\r\n" + "foobar\r\n" + "0\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST_BODY.begin(), REQUEST_BODY.size()).wait(waitScope); + + // Clean drain. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + + drainPromise.wait(waitScope); + + // No SuspendedRequest. + suspendedRequest = factory.getSuspended(); + KJ_EXPECT(suspendedRequest == nullptr); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + + auto response = readPromise.wait(waitScope); + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(RESPONSE == response); + } +} + +KJ_TEST("HttpServer::listenHttpCleanDrain() factory-created services outlive requests") { + // Test that the lifetimes of factory-created Own objects are handled correctly. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + // This HttpService will not actually be used, because we're passing a factory in to + // listenHttpCleanDrain(). + HangingHttpService service; + HttpServer server(timer, table, service); + + uint serviceCount = 0; + + // A factory which returns a service whose request() function responds asynchronously. + auto factory = [&](HttpServer::SuspendableRequest&) -> kj::Own { + class ServiceImpl final: public HttpService { + public: + explicit ServiceImpl(uint& serviceCount): serviceCount(++serviceCount) {} + ~ServiceImpl() noexcept(false) { --serviceCount; } + KJ_DISALLOW_COPY_AND_MOVE(ServiceImpl); + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + return evalLater([&serviceCount = serviceCount, &table = table, &requestBody, &response]() { + // This KJ_EXPECT here is the entire point of this test. + KJ_EXPECT(serviceCount == 1) + HttpHeaders responseHeaders(table); + response.send(200, "OK", responseHeaders); + return requestBody.readAllBytes().ignoreResult(); + }); + } + + private: + HttpHeaderTable table; + + uint& serviceCount; + }; + + return kj::heap(serviceCount); + }; + + auto listenPromise = server.listenHttpCleanDrain(*pipe.ends[0], factory); + + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "\r\n" + "foobar"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + // We need to read the response for the HttpServer to drain. + auto readPromise = pipe.ends[1]->readAllText(); + + // http-socketpair-test quirk: we must drive the request loop past the point of receiving request + // headers so that our call to server.drain() doesn't prematurely cancel the request. + KJ_EXPECT(!listenPromise.poll(waitScope)); + + auto drainPromise = kj::evalLast([&]() { + return server.drain(); + }); + + // Clean drain. + KJ_EXPECT(listenPromise.poll(waitScope)); + KJ_EXPECT(listenPromise.wait(waitScope)); + + drainPromise.wait(waitScope); + + // Close the server side of the pipe so our read promise completes. + pipe.ends[0] = nullptr; + auto response = readPromise.wait(waitScope); + + static constexpr kj::StringPtr RESPONSE = + "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "0\r\n" + "\r\n"_kj; + KJ_EXPECT(RESPONSE == response); +} + +// ----------------------------------------------------------------------------- + +KJ_TEST("newHttpService from HttpClient") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::Promise writeResponsesPromise = kj::READY_NOW; + for (auto& testCase: PIPELINE_TESTS) { + writeResponsesPromise = writeResponsesPromise + .then([&]() { + return expectRead(*backPipe.ends[1], testCase.request.raw); + }).then([&]() { + return backPipe.ends[1]->write(testCase.response.raw.begin(), testCase.response.raw.size()); + }); + } + + { + HttpHeaderTable table; + auto backClient = newHttpClient(table, *backPipe.ends[0]); + auto frontService = newHttpService(*backClient); + HttpServer frontServer(timer, table, *frontService); + auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + + for (auto& testCase: PIPELINE_TESTS) { + KJ_CONTEXT(testCase.request.raw, testCase.response.raw); + + frontPipe.ends[0]->write(testCase.request.raw.begin(), testCase.request.raw.size()) + .wait(waitScope); + + expectRead(*frontPipe.ends[0], testCase.response.raw).wait(waitScope); + } + + frontPipe.ends[0]->shutdownWrite(); + listenTask.wait(waitScope); + } + + backPipe.ends[0]->shutdownWrite(); + writeResponsesPromise.wait(waitScope); +} + +KJ_TEST("newHttpService from HttpClient WebSockets") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) + .then([&]() { return writeA(*backPipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) + .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_REPLY_MESSAGE); }) + .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_CLOSE); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_REPLY_CLOSE); }) + .then([&]() { return expectEnd(*backPipe.ends[1]); }) + .then([&]() { backPipe.ends[1]->shutdownWrite(); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + { + HttpHeaderTable table; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto backClientStream = kj::mv(backPipe.ends[0]); + auto backClient = newHttpClient(table, *backClientStream, clientSettings); + auto frontService = newHttpService(*backClient); + HttpServer frontServer(timer, table, *frontService); + auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + + writeA(*frontPipe.ends[0], request.asBytes()).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + + expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); + writeA(*frontPipe.ends[0], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_MESSAGE).wait(waitScope); + writeA(*frontPipe.ends[0], WEBSOCKET_SEND_CLOSE).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_CLOSE).wait(waitScope); + + frontPipe.ends[0]->shutdownWrite(); + listenTask.wait(waitScope); + } + + writeResponsesPromise.wait(waitScope); +} + +KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE; + auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE; + + auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); + auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) + .then([&]() { return writeA(*backPipe.ends[1], asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)); }) + .then([&]() { return writeA(*backPipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE); }) + .then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); }) + .then([&]() { backPipe.ends[1]->shutdownWrite(); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + { + HttpHeaderTable table; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto backClient = newHttpClient(table, *backPipe.ends[0], clientSettings); + auto frontService = newHttpService(*backClient); + HttpServer frontServer(timer, table, *frontService); + auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1])); + + writeA(*frontPipe.ends[0], request.asBytes()).wait(waitScope); + expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(waitScope); + + expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(waitScope); + writeA(*frontPipe.ends[0], WEBSOCKET_SEND_MESSAGE).wait(waitScope); + + KJ_EXPECT(frontPipe.ends[0]->readAllText().wait(waitScope) == ""); + + frontPipe.ends[0]->shutdownWrite(); + listenTask.wait(waitScope); + } + + writeResponsesPromise.wait(waitScope); +} + +// ----------------------------------------------------------------------------- + +KJ_TEST("newHttpClient from HttpService") { + auto PIPELINE_TESTS = pipelineTestCases(); + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + TestHttpService service(PIPELINE_TESTS, table); + auto client = newHttpClient(service); + + for (auto& testCase: PIPELINE_TESTS) { + testHttpClient(waitScope, table, *client, testCase); + } +} + +KJ_TEST("newHttpClient from HttpService WebSockets") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable::Builder tableBuilder; + HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); + auto headerTable = tableBuilder.build(); + TestWebSocketService service(*headerTable, hMyHeader); + auto client = newHttpClient(service); + + testWebSocketClient(waitScope, *headerTable, hMyHeader, *client); +} + +KJ_TEST("adapted client/server propagates request exceptions like non-adapted client") { + KJ_HTTP_TEST_SETUP_IO; + + HttpHeaderTable table; + HttpHeaders headers(table); + + class FailingHttpClient final: public HttpClient { + public: + Request request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_FAIL_ASSERT("request_fail"); + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + KJ_FAIL_ASSERT("websocket_fail"); + } + }; + + auto rawClient = kj::heap(); + + auto innerClient = kj::heap(); + auto adaptedService = kj::newHttpService(*innerClient).attach(kj::mv(innerClient)); + auto adaptedClient = kj::newHttpClient(*adaptedService).attach(kj::mv(adaptedService)); + + KJ_EXPECT_THROW_MESSAGE("request_fail", rawClient->request(HttpMethod::POST, "/"_kj, headers)); + KJ_EXPECT_THROW_MESSAGE("request_fail", adaptedClient->request(HttpMethod::POST, "/"_kj, headers)); + + KJ_EXPECT_THROW_MESSAGE("websocket_fail", rawClient->openWebSocket("/"_kj, headers)); + KJ_EXPECT_THROW_MESSAGE("websocket_fail", adaptedClient->openWebSocket("/"_kj, headers)); +} + +class DelayedCompletionHttpService final: public HttpService { +public: + DelayedCompletionHttpService(HttpHeaderTable& table, kj::Maybe expectedLength) + : table(table), expectedLength(expectedLength) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + auto stream = response.send(200, "OK", HttpHeaders(table), expectedLength); + auto promise = stream->write("foo", 3); + return promise.attach(kj::mv(stream)).then([this]() { + return kj::mv(paf.promise); + }); + } + + kj::PromiseFulfiller& getFulfiller() { return *paf.fulfiller; } + +private: + HttpHeaderTable& table; + kj::Maybe expectedLength; + kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); +}; + +void doDelayedCompletionTest(bool exception, kj::Maybe expectedLength) noexcept { + KJ_HTTP_TEST_SETUP_IO; + + HttpHeaderTable table; + + DelayedCompletionHttpService service(table, expectedLength); + auto client = newHttpClient(service); + + auto resp = client->request(HttpMethod::GET, "/", HttpHeaders(table), uint64_t(0)) + .response.wait(waitScope); + KJ_EXPECT(resp.statusCode == 200); + + // Read "foo" from the response body: works + char buffer[16]; + KJ_ASSERT(resp.body->tryRead(buffer, 1, sizeof(buffer)).wait(waitScope) == 3); + buffer[3] = '\0'; + KJ_EXPECT(buffer == "foo"_kj); + + // But reading any more hangs. + auto promise = resp.body->tryRead(buffer, 1, sizeof(buffer)); + + KJ_EXPECT(!promise.poll(waitScope)); + + // Until we cause the service to return. + if (exception) { + service.getFulfiller().reject(KJ_EXCEPTION(FAILED, "service-side failure")); + } else { + service.getFulfiller().fulfill(); + } + + KJ_ASSERT(promise.poll(waitScope)); + + if (exception) { + KJ_EXPECT_THROW_MESSAGE("service-side failure", promise.wait(waitScope)); + } else { + promise.wait(waitScope); + } +}; + +KJ_TEST("adapted client waits for service to complete before returning EOF on response stream") { + doDelayedCompletionTest(false, uint64_t(3)); +} + +KJ_TEST("adapted client waits for service to complete before returning EOF on chunked response") { + doDelayedCompletionTest(false, nullptr); +} + +KJ_TEST("adapted client propagates throw from service after complete response body sent") { + doDelayedCompletionTest(true, uint64_t(3)); +} + +KJ_TEST("adapted client propagates throw from service after incomplete response body sent") { + doDelayedCompletionTest(true, uint64_t(6)); +} + +KJ_TEST("adapted client propagates throw from service after chunked response body sent") { + doDelayedCompletionTest(true, nullptr); +} + +class DelayedCompletionWebSocketHttpService final: public HttpService { +public: + DelayedCompletionWebSocketHttpService(HttpHeaderTable& table, bool closeUpstreamFirst) + : table(table), closeUpstreamFirst(closeUpstreamFirst) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_ASSERT(headers.isWebSocket()); + + auto ws = response.acceptWebSocket(HttpHeaders(table)); + kj::Promise promise = kj::READY_NOW; + if (closeUpstreamFirst) { + // Wait for a close message from the client before starting. + promise = promise.then([&ws = *ws]() { return ws.receive(); }).ignoreResult(); + } + promise = promise + .then([&ws = *ws]() { return ws.send("foo"_kj); }) + .then([&ws = *ws]() { return ws.close(1234, "closed"_kj); }); + if (!closeUpstreamFirst) { + // Wait for a close message from the client at the end. + promise = promise.then([&ws = *ws]() { return ws.receive(); }).ignoreResult(); + } + return promise.attach(kj::mv(ws)).then([this]() { + return kj::mv(paf.promise); + }); + } + + kj::PromiseFulfiller& getFulfiller() { return *paf.fulfiller; } + +private: + HttpHeaderTable& table; + bool closeUpstreamFirst; + kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); +}; + +void doDelayedCompletionWebSocketTest(bool exception, bool closeUpstreamFirst) noexcept { + KJ_HTTP_TEST_SETUP_IO; + + HttpHeaderTable table; + + DelayedCompletionWebSocketHttpService service(table, closeUpstreamFirst); + auto client = newHttpClient(service); + + auto resp = client->openWebSocket("/", HttpHeaders(table)).wait(waitScope); + auto ws = kj::mv(KJ_ASSERT_NONNULL(resp.webSocketOrBody.tryGet>())); + + if (closeUpstreamFirst) { + // Send "close" immediately. + ws->close(1234, "whatever"_kj).wait(waitScope); + } + + // Read "foo" from the WebSocket: works + { + auto msg = ws->receive().wait(waitScope); + KJ_ASSERT(msg.is()); + KJ_ASSERT(msg.get() == "foo"); + } + + kj::Promise promise = nullptr; + if (closeUpstreamFirst) { + // Receiving the close hangs. + promise = ws->receive() + .then([](WebSocket::Message&& msg) { KJ_EXPECT(msg.is()); }); + } else { + auto msg = ws->receive().wait(waitScope); + KJ_ASSERT(msg.is()); + + // Sending a close hangs. + promise = ws->close(1234, "whatever"_kj); + } + KJ_EXPECT(!promise.poll(waitScope)); + + // Until we cause the service to return. + if (exception) { + service.getFulfiller().reject(KJ_EXCEPTION(FAILED, "service-side failure")); + } else { + service.getFulfiller().fulfill(); + } + + KJ_ASSERT(promise.poll(waitScope)); + + if (exception) { + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("service-side failure", promise.wait(waitScope)); + } else { + promise.wait(waitScope); + } +}; + +KJ_TEST("adapted client waits for service to complete before completing upstream close on WebSocket") { + doDelayedCompletionWebSocketTest(false, false); +} + +KJ_TEST("adapted client waits for service to complete before returning downstream close on WebSocket") { + doDelayedCompletionWebSocketTest(false, true); +} + +KJ_TEST("adapted client propagates throw from service after WebSocket upstream close sent") { + doDelayedCompletionWebSocketTest(true, false); +} + +KJ_TEST("adapted client propagates throw from service after WebSocket downstream close sent") { + doDelayedCompletionWebSocketTest(true, true); +} + +// ----------------------------------------------------------------------------- + +class CountingIoStream final: public kj::AsyncIoStream { + // An AsyncIoStream wrapper which decrements a counter when destroyed (allowing us to count how + // many connections are open). + +public: + CountingIoStream(kj::Own inner, uint& count) + : inner(kj::mv(inner)), count(count) {} + ~CountingIoStream() noexcept(false) { + --count; + } + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->read(buffer, minBytes, maxBytes); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->tryRead(buffer, minBytes, maxBytes); + } + kj::Maybe tryGetLength() override { + return inner->tryGetLength();; + } + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return inner->pumpTo(output, amount); + } + kj::Promise write(const void* buffer, size_t size) override { + return inner->write(buffer, size); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + return inner->write(pieces); + } + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return inner->tryPumpFrom(input, amount); + } + Promise whenWriteDisconnected() override { + return inner->whenWriteDisconnected(); + } + void shutdownWrite() override { + return inner->shutdownWrite(); + } + void abortRead() override { + return inner->abortRead(); + } + +public: + kj::Own inner; + uint& count; +}; + +class CountingNetworkAddress final: public kj::NetworkAddress { +public: + CountingNetworkAddress(kj::NetworkAddress& inner, uint& count, uint& cumulative) + : inner(inner), count(count), addrCount(ownAddrCount), cumulative(cumulative) {} + CountingNetworkAddress(kj::Own inner, uint& count, uint& addrCount) + : inner(*inner), ownInner(kj::mv(inner)), count(count), addrCount(addrCount), + cumulative(ownCumulative) {} + ~CountingNetworkAddress() noexcept(false) { + --addrCount; + } + + kj::Promise> connect() override { + ++count; + ++cumulative; + return inner.connect() + .then([this](kj::Own stream) -> kj::Own { + return kj::heap(kj::mv(stream), count); + }); + } + + kj::Own listen() override { KJ_UNIMPLEMENTED("test"); } + kj::Own clone() override { KJ_UNIMPLEMENTED("test"); } + kj::String toString() override { KJ_UNIMPLEMENTED("test"); } + +private: + kj::NetworkAddress& inner; + kj::Own ownInner; + uint& count; + uint ownAddrCount = 1; + uint& addrCount; + uint ownCumulative = 0; + uint& cumulative; +}; + +class ConnectionCountingNetwork final: public kj::Network { +public: + ConnectionCountingNetwork(kj::Network& inner, uint& count, uint& addrCount) + : inner(inner), count(count), addrCount(addrCount) {} + + Promise> parseAddress(StringPtr addr, uint portHint = 0) override { + ++addrCount; + return inner.parseAddress(addr, portHint) + .then([this](Own&& addr) -> Own { + return kj::heap(kj::mv(addr), count, addrCount); + }); + } + Own getSockaddr(const void* sockaddr, uint len) override { + KJ_UNIMPLEMENTED("test"); + } + Own restrictPeers( + kj::ArrayPtr allow, + kj::ArrayPtr deny = nullptr) override { + KJ_UNIMPLEMENTED("test"); + } + +private: + kj::Network& inner; + uint& count; + uint& addrCount; +}; + +class DummyService final: public HttpService { +public: + DummyService(HttpHeaderTable& headerTable): headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + if (!headers.isWebSocket()) { + if (url == "/throw") { + return KJ_EXCEPTION(FAILED, "client requested failure"); + } + + auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); + auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size()); + auto promises = kj::heapArrayBuilder>(2); + promises.add(stream->write(body.begin(), body.size())); + promises.add(requestBody.readAllBytes().ignoreResult()); + return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body)); + } else { + auto ws = response.acceptWebSocket(HttpHeaders(headerTable)); + auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); + auto sendPromise = ws->send(body); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(sendPromise.attach(kj::mv(body))); + promises.add(ws->receive().ignoreResult()); + return kj::joinPromises(promises.finish()).attach(kj::mv(ws)); + } + } + +private: + HttpHeaderTable& headerTable; +}; + +KJ_TEST("HttpClient connection management") { + KJ_HTTP_TEST_SETUP_IO; + KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + DummyService service(headerTable); + HttpServerSettings serverSettings; + HttpServer server(serverTimer, headerTable, service, serverSettings); + auto listenTask = server.listenHttp(*listener); + + uint count = 0; + uint cumulative = 0; + CountingNetworkAddress countingAddr(*addr, count, cumulative); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + auto client = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 0); + + uint i = 0; + auto doRequest = [&]() { + uint n = i++; + return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n](kj::String body) { + KJ_EXPECT(body == kj::str("null:/", n)); + }); + }; + + // We can do several requests in a row and only have one connection. + doRequest().wait(waitScope); + doRequest().wait(waitScope); + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 1); + + // But if we do two in parallel, we'll end up with two connections. + auto req1 = doRequest(); + auto req2 = doRequest(); + req1.wait(waitScope); + req2.wait(waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + + // We can reuse after a POST, provided we write the whole POST body properly. + { + auto req = client->request( + HttpMethod::POST, kj::str("/foo"), HttpHeaders(headerTable), size_t(6)); + req.body->write("foobar", 6).wait(waitScope); + req.response.wait(waitScope).body->readAllBytes().wait(waitScope); + } + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + doRequest().wait(waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + + // Advance time for half the timeout, then exercise one of the connections. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout / 2); + doRequest().wait(waitScope); + doRequest().wait(waitScope); + waitScope.poll(); + KJ_EXPECT(count == 2); + KJ_EXPECT(cumulative == 2); + + // Advance time past when the other connection should time out. It should be dropped. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout * 3 / 4); + waitScope.poll(); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 2); + + // Wait for the other to drop. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout / 2); + waitScope.poll(); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 2); + + // New request creates a new connection again. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 3); + + // WebSocket connections are not reused. + client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable)) + .wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 3); + + // Errored connections are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 4); + client->request(HttpMethod::GET, kj::str("/throw"), HttpHeaders(headerTable)).response + .wait(waitScope).body->readAllBytes().wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 4); + + // Connections where we failed to read the full response body are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 5); + client->request(HttpMethod::GET, kj::str("/foo"), HttpHeaders(headerTable)).response + .wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 5); + + // Connections where we didn't even wait for the response headers are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 6); + client->request(HttpMethod::GET, kj::str("/foo"), HttpHeaders(headerTable)); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 6); + + // Connections where we failed to write the full request body are not reused. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 7); + client->request(HttpMethod::POST, kj::str("/foo"), HttpHeaders(headerTable), size_t(6)).response + .wait(waitScope).body->readAllBytes().wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 7); + + // If the server times out the connection, we figure it out on the client. + doRequest().wait(waitScope); + + // TODO(someday): Figure out why the following poll is necessary for the test to pass on Windows + // and Mac. Without it, it seems that the request's connection never starts, so the + // subsequent advanceTo() does not actually time out the connection. + waitScope.poll(); + + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 8); + serverTimer.advanceTo(serverTimer.now() + serverSettings.pipelineTimeout * 2); + waitScope.poll(); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 8); + + // Can still make requests. + doRequest().wait(waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 9); +} + +KJ_TEST("HttpClient disable connection reuse") { + KJ_HTTP_TEST_SETUP_IO; + KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + DummyService service(headerTable); + HttpServerSettings serverSettings; + HttpServer server(serverTimer, headerTable, service, serverSettings); + auto listenTask = server.listenHttp(*listener); + + uint count = 0; + uint cumulative = 0; + CountingNetworkAddress countingAddr(*addr, count, cumulative); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.idleTimeout = 0 * kj::SECONDS; + auto client = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 0); + + uint i = 0; + auto doRequest = [&]() { + uint n = i++; + return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n](kj::String body) { + KJ_EXPECT(body == kj::str("null:/", n)); + }); + }; + + // Each serial request gets its own connection. + doRequest().wait(waitScope); + doRequest().wait(waitScope); + doRequest().wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 3); + + // Each parallel request gets its own connection. + auto req1 = doRequest(); + auto req2 = doRequest(); + req1.wait(waitScope); + req2.wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 5); +} + +KJ_TEST("HttpClient concurrency limiting") { +#if KJ_HTTP_TEST_USE_OS_PIPE && !__linux__ + // On Windows and Mac, OS event delivery is not always immediate, and that seems to make this + // test flakey. On Linux, events are always immediately delivered. For now, we compile the test + // but we don't run it outside of Linux. We do run the in-memory-pipes version on all OSs since + // that mode shouldn't depend on kernel behavior at all. + return; +#endif + + KJ_HTTP_TEST_SETUP_IO; + KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR; + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + DummyService service(headerTable); + HttpServerSettings serverSettings; + HttpServer server(serverTimer, headerTable, service, serverSettings); + auto listenTask = server.listenHttp(*listener); + + uint count = 0; + uint cumulative = 0; + CountingNetworkAddress countingAddr(*addr, count, cumulative); + + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + clientSettings.entropySource = entropySource; + clientSettings.idleTimeout = 0 * kj::SECONDS; + auto innerClient = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings); + + struct CallbackEvent { + uint runningCount; + uint pendingCount; + + bool operator==(const CallbackEvent& other) const { + return runningCount == other.runningCount && pendingCount == other.pendingCount; + } + bool operator!=(const CallbackEvent& other) const { return !(*this == other); } + // TODO(someday): Can use default spaceship operator in C++20: + //auto operator<=>(const CallbackEvent&) const = default; + }; + + kj::Vector callbackEvents; + auto callback = [&](uint runningCount, uint pendingCount) { + callbackEvents.add(CallbackEvent{runningCount, pendingCount}); + }; + auto client = newConcurrencyLimitingHttpClient(*innerClient, 1, kj::mv(callback)); + + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 0); + + uint i = 0; + auto doRequest = [&]() { + uint n = i++; + return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n](kj::String body) { + KJ_EXPECT(body == kj::str("null:/", n)); + }); + }; + + // Second connection blocked by first. + auto req1 = doRequest(); + + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + + auto req2 = doRequest(); + + // TODO(someday): Figure out why this poll() is necessary on Windows and macOS. + waitScope.poll(); + + KJ_EXPECT(req1.poll(waitScope)); + KJ_EXPECT(!req2.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 1); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 1} })); + callbackEvents.clear(); + + // Releasing first connection allows second to start. + req1.wait(waitScope); + KJ_EXPECT(req2.poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 2); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + + req2.wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 2); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); + callbackEvents.clear(); + + // Using body stream after releasing blocked response promise throws no exception + auto req3 = doRequest(); + { + kj::Own req4Body; + { + auto req4 = client->request(HttpMethod::GET, kj::str("/", ++i), HttpHeaders(headerTable)); + waitScope.poll(); + req4Body = kj::mv(req4.body); + } + auto writePromise = req4Body->write("a", 1); + KJ_EXPECT(!writePromise.poll(waitScope)); + } + req3.wait(waitScope); + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 3); + + // Similar connection limiting for web sockets + // TODO(someday): Figure out why the sequencing of websockets events does + // not work correctly on Windows (and maybe macOS?). The solution is not as + // simple as inserting poll()s as above, since doing so puts the websocket in + // a state that trips a "previous HTTP message body incomplete" assertion, + // while trying to write 500 network response. + callbackEvents.clear(); + auto ws1 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable))); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + auto ws2 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable))); + KJ_EXPECT(ws1->poll(waitScope)); + KJ_EXPECT(!ws2->poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 4); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 1} })); + callbackEvents.clear(); + + { + auto response1 = ws1->wait(waitScope); + KJ_EXPECT(!ws2->poll(waitScope)); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({})); + } + KJ_EXPECT(ws2->poll(waitScope)); + KJ_EXPECT(count == 1); + KJ_EXPECT(cumulative == 5); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {1, 0} })); + callbackEvents.clear(); + { + auto response2 = ws2->wait(waitScope); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({})); + } + KJ_EXPECT(count == 0); + KJ_EXPECT(cumulative == 5); + KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); +} + +#if KJ_HTTP_TEST_USE_OS_PIPE +// This test relies on access to the network. +KJ_TEST("NetworkHttpClient connect impl") { + KJ_HTTP_TEST_SETUP_IO; + auto listener1 = io.provider->getNetwork().parseAddress("localhost", 0) + .wait(io.waitScope)->listen(); + + auto ignored KJ_UNUSED = listener1->accept().then([](Own stream) { + auto buffer = kj::str("test"); + return stream->write(buffer.cStr(), buffer.size()).attach(kj::mv(stream), kj::mv(buffer)); + }).eagerlyEvaluate(nullptr); + + HttpClientSettings clientSettings; + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + auto client = newHttpClient(clientTimer, headerTable, + io.provider->getNetwork(), nullptr, clientSettings); + auto request = client->connect( + kj::str("localhost:", listener1->getPort()), HttpHeaders(headerTable), {}); + + auto buf = kj::heapArray(4); + return request.connection->tryRead(buf.begin(), 1, buf.size()) + .then([buf = kj::mv(buf)](size_t count) { + KJ_ASSERT(count == 4); + KJ_ASSERT(kj::str(buf.asChars()) == "test"); + }).attach(kj::mv(request.connection)).wait(io.waitScope); +} +#endif + +#if KJ_HTTP_TEST_USE_OS_PIPE +// TODO(someday): Implement mock kj::Network for userspace version of this test? +KJ_TEST("HttpClient multi host") { + auto io = kj::setupAsyncIo(); + + kj::TimerImpl serverTimer(kj::origin()); + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + + auto listener1 = io.provider->getNetwork().parseAddress("localhost", 0) + .wait(io.waitScope)->listen(); + auto listener2 = io.provider->getNetwork().parseAddress("localhost", 0) + .wait(io.waitScope)->listen(); + DummyService service(headerTable); + HttpServer server(serverTimer, headerTable, service); + auto listenTask1 = server.listenHttp(*listener1); + auto listenTask2 = server.listenHttp(*listener2); + + uint count = 0, addrCount = 0; + uint tlsCount = 0, tlsAddrCount = 0; + ConnectionCountingNetwork countingNetwork(io.provider->getNetwork(), count, addrCount); + ConnectionCountingNetwork countingTlsNetwork(io.provider->getNetwork(), tlsCount, tlsAddrCount); + + HttpClientSettings clientSettings; + auto client = newHttpClient(clientTimer, headerTable, + countingNetwork, countingTlsNetwork, clientSettings); + + KJ_EXPECT(count == 0); + + uint i = 0; + auto doRequest = [&](bool tls, uint port) { + uint n = i++; + // We stick a double-slash in the URL to test that it doesn't get coalesced into one slash, + // which was a bug in the past. + return client->request(HttpMethod::GET, + kj::str((tls ? "https://localhost:" : "http://localhost:"), port, "//", n), + HttpHeaders(headerTable)).response + .then([](HttpClient::Response&& response) { + auto promise = response.body->readAllText(); + return promise.attach(kj::mv(response.body)); + }).then([n, port](kj::String body) { + KJ_EXPECT(body == kj::str("localhost:", port, "://", n), body, port, n); + }); + }; + + uint port1 = listener1->getPort(); + uint port2 = listener2->getPort(); + + // We can do several requests in a row to the same host and only have one connection. + doRequest(false, port1).wait(io.waitScope); + doRequest(false, port1).wait(io.waitScope); + doRequest(false, port1).wait(io.waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 1); + KJ_EXPECT(tlsAddrCount == 0); + + // Request a different host, and now we have two connections. + doRequest(false, port2).wait(io.waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 0); + + // Try TLS. + doRequest(true, port1).wait(io.waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(tlsCount == 1); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 1); + + // Try first host again, no change in connection count. + doRequest(false, port1).wait(io.waitScope); + KJ_EXPECT(count == 2); + KJ_EXPECT(tlsCount == 1); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 1); + + // Multiple requests in parallel forces more connections to that host. + auto promise1 = doRequest(false, port1); + auto promise2 = doRequest(false, port1); + promise1.wait(io.waitScope); + promise2.wait(io.waitScope); + KJ_EXPECT(count == 3); + KJ_EXPECT(tlsCount == 1); + KJ_EXPECT(addrCount == 2); + KJ_EXPECT(tlsAddrCount == 1); + + // Let everything expire. + clientTimer.advanceTo(clientTimer.now() + clientSettings.idleTimeout * 2); + io.waitScope.poll(); + KJ_EXPECT(count == 0); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 0); + KJ_EXPECT(tlsAddrCount == 0); + + // We can still request those hosts again. + doRequest(false, port1).wait(io.waitScope); + KJ_EXPECT(count == 1); + KJ_EXPECT(tlsCount == 0); + KJ_EXPECT(addrCount == 1); + KJ_EXPECT(tlsAddrCount == 0); +} +#endif + +// ----------------------------------------------------------------------------- + +#if KJ_HTTP_TEST_USE_OS_PIPE +// This test only makes sense using the real network. +KJ_TEST("HttpClient to capnproto.org") { + auto io = kj::setupAsyncIo(); + + auto maybeConn = io.provider->getNetwork().parseAddress("capnproto.org", 80) + .then([](kj::Own addr) { + auto promise = addr->connect(); + return promise.attach(kj::mv(addr)); + }).then([](kj::Own&& connection) -> kj::Maybe> { + return kj::mv(connection); + }, [](kj::Exception&& e) -> kj::Maybe> { + KJ_LOG(WARNING, "skipping test because couldn't connect to capnproto.org"); + return nullptr; + }).wait(io.waitScope); + + KJ_IF_MAYBE(conn, maybeConn) { + // Successfully connected to capnproto.org. Try doing GET /. We expect to get a redirect to + // HTTPS, because what kind of horrible web site would serve in plaintext, really? + + HttpHeaderTable table; + auto client = newHttpClient(table, **conn); + + HttpHeaders headers(table); + headers.set(HttpHeaderId::HOST, "capnproto.org"); + + auto response = client->request(HttpMethod::GET, "/", headers).response.wait(io.waitScope); + KJ_EXPECT(response.statusCode / 100 == 3); + auto location = KJ_ASSERT_NONNULL(response.headers->get(HttpHeaderId::LOCATION)); + KJ_EXPECT(location == "https://capnproto.org/"); + + auto body = response.body->readAllText().wait(io.waitScope); + } +} +#endif + +// ======================================================================================= +// Misc bugfix tests + +class ReadCancelHttpService final: public HttpService { + // HttpService that tries to read all request data but cancels after 1ms and sends a response. +public: + ReadCancelHttpService(kj::Timer& timer, HttpHeaderTable& headerTable) + : timer(timer), headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + if (method == HttpMethod::POST) { + // Try to read all content, but cancel after 1ms. + + // Actually, we can't literally cancel mid-read, because this leaves the stream in an + // unknown state which requires closing the connection. Instead, we know that the sender + // will send 5 bytes, so we read that, then pause. + static char junk[5]; + return requestBody.read(junk, 5) + .then([]() -> kj::Promise { return kj::NEVER_DONE; }) + .exclusiveJoin(timer.afterDelay(1 * kj::MILLISECONDS)) + .then([this, &responseSender]() { + responseSender.send(408, "Request Timeout", kj::HttpHeaders(headerTable), uint64_t(0)); + }); + } else { + responseSender.send(200, "OK", kj::HttpHeaders(headerTable), uint64_t(0)); + return kj::READY_NOW; + } + } + +private: + kj::Timer& timer; + HttpHeaderTable& headerTable; +}; + +KJ_TEST("canceling a length stream mid-read correctly discards rest of request") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + ReadCancelHttpService service(timer, table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + { + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Content-Length: 6\r\n" + "\r\n" + "fooba"_kj; // incomplete + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 408 Request Timeout\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + + KJ_EXPECT(!promise.poll(waitScope)); + + // Trigger timeout, then response should be sent. + timer.advanceTo(timer.now() + 1 * kj::MILLISECONDS); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } + + // We left our request stream hanging. The server will try to read and discard the request body. + // Let's give it the rest of the data, followed by a second request. + { + static constexpr kj::StringPtr REQUEST = + "r" + "GET / HTTP/1.1\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } +} + +KJ_TEST("canceling a chunked stream mid-read correctly discards rest of request") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + ReadCancelHttpService service(timer, table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + { + static constexpr kj::StringPtr REQUEST = + "POST / HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n" + "fooba"_kj; // incomplete chunk + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 408 Request Timeout\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + + KJ_EXPECT(!promise.poll(waitScope)); + + // Trigger timeout, then response should be sent. + timer.advanceTo(timer.now() + 1 * kj::MILLISECONDS); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } + + // We left our request stream hanging. The server will try to read and discard the request body. + // Let's give it the rest of the data, followed by a second request. + { + static constexpr kj::StringPtr REQUEST = + "r\r\n" + "4a\r\n" + "this is some text that is the body of a chunk and not a valid chunk header\r\n" + "0\r\n" + "\r\n" + "GET / HTTP/1.1\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + + auto promise = expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "\r\n"_kj); + KJ_ASSERT(promise.poll(waitScope)); + promise.wait(waitScope); + } +} + +KJ_TEST("drain() doesn't lose bytes when called at the wrong moment") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttpCleanDrain(*pipe.ends[0]); + + // Do a regular request. + static constexpr kj::StringPtr REQUEST = + "GET / HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST.begin(), REQUEST.size()).wait(waitScope); + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 13\r\n" + "\r\n" + "example.com:/"_kj).wait(waitScope); + + // Make sure the server is blocked on the next read from the socket. + kj::Promise(kj::NEVER_DONE).poll(waitScope); + + // Now simultaneously deliver a new request AND drain the socket. + auto drainPromise = server.drain(); + static constexpr kj::StringPtr REQUEST2 = + "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST2.begin(), REQUEST2.size()).wait(waitScope); + +#if KJ_HTTP_TEST_USE_OS_PIPE + // In the case of an OS pipe, the drain will complete before any data is read from the socket. + drainPromise.wait(waitScope); + + // The HTTP server should indicate the connection was released but still valid. + KJ_ASSERT(listenTask.wait(waitScope)); + + // The request will not have been read off the socket. We can read it now. + pipe.ends[1]->shutdownWrite(); + KJ_EXPECT(pipe.ends[0]->readAllText().wait(waitScope) == REQUEST2); + +#else + // In the case of an in-memory pipe, the write() will have delivered bytes directly to the + // destination buffer synchronously, which means that the server must handle the request + // before draining. + KJ_EXPECT(!drainPromise.poll(waitScope)); + + // The HTTP request should get a response. + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 16\r\n" + "\r\n" + "example.com:/foo"_kj).wait(waitScope); + + // Now the drain completes. + drainPromise.wait(waitScope); + + // The HTTP server should indicate the connection was released but still valid. + KJ_ASSERT(listenTask.wait(waitScope)); +#endif +} + +KJ_TEST("drain() does not cancel the first request on a new connection") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttpCleanDrain(*pipe.ends[0]); + + // Request a drain(). It won't complete, because the newly-connected socket is considered to have + // an in-flight request. + auto drainPromise = server.drain(); + KJ_EXPECT(!drainPromise.poll(waitScope)); + + // Deliver the request. + static constexpr kj::StringPtr REQUEST2 = + "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST2.begin(), REQUEST2.size()).wait(waitScope); + + // It should get a response. + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Content-Length: 16\r\n" + "\r\n" + "example.com:/foo"_kj).wait(waitScope); + + // Now the drain completes. + drainPromise.wait(waitScope); + + // The HTTP server should indicate the connection was released but still valid. + KJ_ASSERT(listenTask.wait(waitScope)); +} + +KJ_TEST("drain() when NOT using listenHttpCleanDrain() sends Connection: close header") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + // Request a drain(). It won't complete, because the newly-connected socket is considered to have + // an in-flight request. + auto drainPromise = server.drain(); + KJ_EXPECT(!drainPromise.poll(waitScope)); + + // Deliver the request. + static constexpr kj::StringPtr REQUEST2 = + "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj; + pipe.ends[1]->write(REQUEST2.begin(), REQUEST2.size()).wait(waitScope); + + // It should get a response. + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "Connection: close\r\n" + "Content-Length: 16\r\n" + "\r\n" + "example.com:/foo"_kj).wait(waitScope); + + // And then EOF. + auto rest = pipe.ends[1]->readAllText(); + KJ_ASSERT(rest.poll(waitScope)); + KJ_EXPECT(rest.wait(waitScope) == nullptr); + + // The drain task and listen task are done. + drainPromise.wait(waitScope); + listenTask.wait(waitScope); +} + +class BrokenConnectionListener final: public kj::ConnectionReceiver { +public: + void fulfillOne(kj::Own stream) { + fulfiller->fulfill(kj::mv(stream)); + } + + kj::Promise> accept() override { + auto paf = kj::newPromiseAndFulfiller>(); + fulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + uint getPort() override { + KJ_UNIMPLEMENTED("not used"); + } private: - kj::Maybe exception; + kj::Own>> fulfiller; +}; + +class BrokenConnection final: public kj::AsyncIoStream { +public: + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return KJ_EXCEPTION(FAILED, "broken"); + } + Promise write(const void* buffer, size_t size) override { + return KJ_EXCEPTION(FAILED, "broken"); + } + Promise write(ArrayPtr> pieces) override { + return KJ_EXCEPTION(FAILED, "broken"); + } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + void shutdownWrite() override {} +}; + +KJ_TEST("HttpServer.listenHttp() doesn't prematurely terminate if an accepted connection is broken") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + DummyService service(table); + HttpServer server(timer, table, service); + + BrokenConnectionListener listener; + auto promise = server.listenHttp(listener).eagerlyEvaluate(nullptr); + + // Loop is waiting for a connection. + KJ_ASSERT(!promise.poll(waitScope)); + + KJ_EXPECT_LOG(ERROR, "failed: broken"); + listener.fulfillOne(kj::heap()); + + // The loop should not have stopped, even though the connection was broken. + KJ_ASSERT(!promise.poll(waitScope)); +} + +KJ_TEST("HttpServer handles disconnected exception for clients disconnecting after headers") { + // This test case reproduces a race condition where a client could disconnect after the server + // sent response headers but before it sent the response body, resulting in a broken pipe + // "disconnected" exception when writing the body. The default handler for application errors + // tells the server to ignore "disconnected" exceptions and close the connection, but code + // after the handler exercised the broken connection, causing the server loop to instead fail + // with a "failed" exception. + + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + class SendErrorHttpService final: public HttpService { + // HttpService that serves an error page via sendError(). + public: + SendErrorHttpService(HttpHeaderTable& headerTable): headerTable(headerTable) {} + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& responseSender) override { + return responseSender.sendError(404, "Not Found", headerTable); + } + + private: + HttpHeaderTable& headerTable; + }; + + class DisconnectingAsyncIoStream final: public kj::AsyncIoStream { + public: + DisconnectingAsyncIoStream(AsyncIoStream& inner): inner(inner) {} + + Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.read(buffer, minBytes, maxBytes); + } + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner.tryRead(buffer, minBytes, maxBytes); + } + + Maybe tryGetLength() override { return inner.tryGetLength(); } + + Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return inner.pumpTo(output, amount); + } + + Promise write(const void* buffer, size_t size) override { + int writeId = writeCount++; + if (writeId == 0) { + // Allow first write (headers) to succeed. + auto promise = inner.write(buffer, size); + inner.shutdownWrite(); + return promise; + } else if (writeId == 1) { + // Fail subsequent write (body) with a disconnected exception. + return KJ_EXCEPTION(DISCONNECTED, "a_disconnected_exception"); + } else { + KJ_FAIL_ASSERT("Unexpected write"); + } + } + Promise write(ArrayPtr> pieces) override { + return inner.write(pieces); + } + + Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { + return inner.tryPumpFrom(input, amount); + } + + Promise whenWriteDisconnected() override { + return inner.whenWriteDisconnected(); + } + + void shutdownWrite() override { + return inner.shutdownWrite(); + } + + void abortRead() override { + return inner.abortRead(); + } + + void getsockopt(int level, int option, void* value, uint* length) override { + return inner.getsockopt(level, option, value, length); + } + void setsockopt(int level, int option, const void* value, uint length) override { + return inner.setsockopt(level, option, value, length); + } + + void getsockname(struct sockaddr* addr, uint* length) override { + return inner.getsockname(addr, length); + } + void getpeername(struct sockaddr* addr, uint* length) override { + return inner.getsockname(addr, length); + } + + int writeCount = 0; + + private: + kj::AsyncIoStream& inner; + }; + + class TestErrorHandler: public HttpServerErrorHandler { + public: + kj::Promise handleApplicationError( + kj::Exception exception, kj::Maybe response) override { + applicationErrorCount++; + if (exception.getType() == kj::Exception::Type::DISCONNECTED) { + // Tell HttpServer to ignore disconnected exceptions (the default behavior). + return kj::READY_NOW; + } + KJ_FAIL_ASSERT("Unexpected application error type", exception.getType()); + } + + int applicationErrorCount = 0; + }; + + TestErrorHandler testErrorHandler; + HttpServerSettings settings {}; + settings.errorHandler = testErrorHandler; + + HttpHeaderTable table; + SendErrorHttpService service(table); + HttpServer server(timer, table, service, settings); + + auto stream = kj::heap(*pipe.ends[0]); + auto listenPromise = server.listenHttpCleanDrain(*stream); + + static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj; + pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + // Client races to read headers but not body, then disconnects. (Note that the following code + // doesn't reliably reproduce the race condition by itself -- DisconnectingAsyncIoStream is + // needed to ensure the disconnected exception throws on the correct write promise.) + expectRead(*pipe.ends[1], + "HTTP/1.1 404 Not Found\r\n" + "Content-Length: 9\r\n" + "\r\n"_kj).wait(waitScope); + pipe.ends[1] = nullptr; + + // The race condition failure would manifest as a "previous HTTP message body incomplete" + // "FAILED" exception here: + bool canReuse = listenPromise.wait(waitScope); + + KJ_ASSERT(!canReuse); + KJ_ASSERT(stream->writeCount == 2); + KJ_ASSERT(testErrorHandler.applicationErrorCount == 1); +} + +// ======================================================================================= +// CONNECT tests + +class ConnectEchoService final: public HttpService { + // A simple CONNECT echo. It will always accept, and whatever data it + // receives will be echoed back. +public: + ConnectEchoService(HttpHeaderTable& headerTable, uint statusCodeToSend = 200) + : headerTable(headerTable), + statusCodeToSend(statusCodeToSend) { + KJ_ASSERT(statusCodeToSend >= 200 && statusCodeToSend < 300); + } + + uint connectCount = 0; + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + connectCount++; + response.accept(statusCodeToSend, "OK", HttpHeaders(headerTable)); + return connection.pumpTo(connection).ignoreResult(); + } + +private: + HttpHeaderTable& headerTable; + uint statusCodeToSend; +}; + +class ConnectRejectService final: public HttpService { + // A simple CONNECT implementation that always rejects. +public: + ConnectRejectService(HttpHeaderTable& headerTable, uint statusCodeToSend = 400) + : headerTable(headerTable), + statusCodeToSend(statusCodeToSend) { + KJ_ASSERT(statusCodeToSend >= 300); + } + + uint connectCount = 0; + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + connectCount++; + auto out = response.reject(statusCodeToSend, "Failed"_kj, HttpHeaders(headerTable), 4); + return out->write("boom", 4).attach(kj::mv(out)); + } + +private: + HttpHeaderTable& headerTable; + uint statusCodeToSend; +}; + +class ConnectCancelReadService final: public HttpService { + // A simple CONNECT server that will accept a connection then immediately + // cancel reading from it to test handling of abrupt termination. +public: + ConnectCancelReadService(HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(headerTable)); + // Return an immediately resolved promise and drop the connection + return kj::READY_NOW; + } + +private: + HttpHeaderTable& headerTable; +}; + +class ConnectCancelWriteService final: public HttpService { + // A simple CONNECT server that will accept a connection then immediately + // cancel writing to it to test handling of abrupt termination. +public: + ConnectCancelWriteService(HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(headerTable)); + + auto msg = "hello"_kj; + auto promise KJ_UNUSED = connection.write(msg.begin(), 5); + + // Return an immediately resolved promise and drop the io + return kj::READY_NOW; + } + +private: + HttpHeaderTable& headerTable; +}; + +class ConnectHttpService final: public HttpService { + // A CONNECT service that tunnels HTTP requests just to verify that, yes, the CONNECT + // impl can actually tunnel actual protocols. +public: + ConnectHttpService(HttpHeaderTable& table) + : timer(kj::origin()), + tunneledService(table), + server(timer, table, tunneledService) {} +private: + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(tunneledService.table)); + return server.listenHttp(kj::Own(&connection, kj::NullDisposer::instance)); + } + + class SimpleHttpService final: public HttpService { + public: + SimpleHttpService(HttpHeaderTable& table) : table(table) {} + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + auto out = response.send(200, "OK"_kj, HttpHeaders(table)); + auto msg = "hello there"_kj; + return out->write(msg.begin(), 11).attach(kj::mv(out)); + } + + HttpHeaderTable& table; + }; + + kj::TimerImpl timer; + SimpleHttpService tunneledService; + HttpServer server; +}; + +class ConnectCloseService final: public HttpService { + // A simple CONNECT server that will accept a connection then immediately + // shutdown the write side of the AsyncIoStream to simulate socket disconnection. +public: + ConnectCloseService(HttpHeaderTable& headerTable) + : headerTable(headerTable) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + KJ_UNIMPLEMENTED("Regular HTTP requests are not implemented here."); + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) override { + response.accept(200, "OK", HttpHeaders(headerTable)); + connection.shutdownWrite(); + return kj::READY_NOW; + } + +private: + HttpHeaderTable& headerTable; +}; + +KJ_TEST("Simple CONNECT Server works") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; -}; + ConnectEchoService service(table); + HttpServer server(timer, table, service); -KJ_TEST("HttpServer threw exception after starting response") { - auto PIPELINE_TESTS = pipelineTestCases(); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - auto io = kj::setupAsyncIo(); - auto pipe = io.provider->newTwoWayPipe(); + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "\r\n" + "hello"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("Simple CONNECT Client/Server works") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); HttpHeaderTable table; - PartialResponseService service; - HttpServer server(io.provider->getTimer(), table, service); + ConnectEchoService service(table); + HttpServer server(timer, table, service); auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - KJ_EXPECT_LOG(ERROR, "HttpService threw exception after generating a partial response"); + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + // Initiates a CONNECT with the echo server. Once established, sends a bit of data + // and waits for it to be echoed back. + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(io->write("hello", 5)); + promises.add(expectRead(*io, "hello"_kj)); + return kj::joinPromises(promises.finish()) + .then([io=kj::mv(io)]() mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }); + }).wait(waitScope); - // Do one request. - pipe.ends[1]->write(PIPELINE_TESTS[0].request.raw.begin(), PIPELINE_TESTS[0].request.raw.size()) - .wait(io.waitScope); - auto text = pipe.ends[1]->readAllText().wait(io.waitScope); + listenTask.wait(waitScope); - KJ_EXPECT(text == - "HTTP/1.1 200 OK\r\n" - "Content-Length: 32\r\n" - "\r\n" - "foo", text); + KJ_ASSERT(service.connectCount == 1); } -// ----------------------------------------------------------------------------- +KJ_TEST("CONNECT Server (201 status)") { + KJ_HTTP_TEST_SETUP_IO; -KJ_TEST("HttpClient to capnproto.org") { - auto io = kj::setupAsyncIo(); + // Test that CONNECT works with 2xx status codes that typically do + // not carry a response payload. - auto maybeConn = io.provider->getNetwork().parseAddress("capnproto.org", 80) - .then([](kj::Own addr) { - auto promise = addr->connect(); - return promise.attach(kj::mv(addr)); - }).then([](kj::Own&& connection) -> kj::Maybe> { - return kj::mv(connection); - }, [](kj::Exception&& e) -> kj::Maybe> { - KJ_LOG(WARNING, "skipping test because couldn't connect to capnproto.org"); - return nullptr; - }).wait(io.waitScope); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - KJ_IF_MAYBE(conn, maybeConn) { - // Successfully connected to capnproto.org. Try doing GET /. We expect to get a redirect to - // HTTPS, because what kind of horrible web site would serve in plaintext, really? + kj::TimerImpl timer(kj::origin()); - HttpHeaderTable table; - auto client = newHttpClient(table, **conn); + HttpHeaderTable table; + ConnectEchoService service(table, 201); + HttpServer server(timer, table, service); - HttpHeaders headers(table); - headers.set(HttpHeaderId::HOST, "capnproto.org"); + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - auto response = client->request(HttpMethod::GET, "/", headers).response.wait(io.waitScope); - KJ_EXPECT(response.statusCode / 100 == 3); - auto location = KJ_ASSERT_NONNULL(response.headers->get(HttpHeaderId::LOCATION)); - KJ_EXPECT(location == "https://capnproto.org/"); + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; - auto body = response.body->readAllText().wait(io.waitScope); - } + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 201 OK\r\n" + "\r\n" + "hello"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("CONNECT Client (204 status)") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + // Test that CONNECT works with 2xx status codes that typically do + // not carry a response payload. + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table, 204); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + // Initiates a CONNECT with the echo server. Once established, sends a bit of data + // and waits for it to be echoed back. + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 204); + KJ_ASSERT(status.statusText == "OK"_kj); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(io->write("hello", 5)); + promises.add(expectRead(*io, "hello"_kj)); + + return kj::joinPromises(promises.finish()) + .then([io=kj::mv(io)]() mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }); + }).wait(waitScope); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +KJ_TEST("CONNECT Server rejected") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectRejectService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 400 Failed\r\n" + "Connection: close\r\n" + "Content-Length: 4\r\n" + "\r\n" + "boom"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} + +#ifndef KJ_HTTP_TEST_USE_OS_PIPE +KJ_TEST("CONNECT Client rejected") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectRejectService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([](auto status) mutable { + KJ_ASSERT(status.statusCode == 400); + KJ_ASSERT(status.statusText == "Failed"_kj); + + auto& errorBody = KJ_ASSERT_NONNULL(status.errorBody); + + return expectRead(*errorBody, "boom"_kj).then([&errorBody=*errorBody]() { + return expectEnd(errorBody); + }).attach(kj::mv(errorBody)); + }).wait(waitScope); + + listenTask.wait(waitScope); + + KJ_ASSERT(service.connectCount == 1); +} +#endif + +KJ_TEST("CONNECT Server cancels read") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectCancelReadService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "\r\n"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); +} + +#ifndef KJ_HTTP_TEST_USE_OS_PIPE +KJ_TEST("CONNECT Server cancels read w/client") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectCancelReadService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + bool failed = false; + + HttpHeaderTable clientHeaders; + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([&failed, io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + return io->write("hello", 5).catch_([&](kj::Exception&& ex) { + KJ_ASSERT(ex.getType() == kj::Exception::Type::DISCONNECTED); + failed = true; + }).attach(kj::mv(io)); + }).wait(waitScope); + + KJ_ASSERT(failed, "the write promise should have failed"); + + listenTask.wait(waitScope); +} +#endif + +KJ_TEST("CONNECT Server cancels write") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectCancelWriteService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 200 OK\r\n" + "\r\n"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); +} + +#ifndef KJ_HTTP_TEST_USE_OS_PIPE +KJ_TEST("CONNECT Server cancels write w/client") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectCancelWriteService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable clientHeaders; + bool failed = false; + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(clientHeaders), {}); + + request.status.then([&failed, io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + + return io->write("hello", 5).catch_([&failed](kj::Exception&& ex) mutable { + KJ_ASSERT(ex.getType() == kj::Exception::Type::DISCONNECTED); + failed = true; + }).attach(kj::mv(io)); + }).wait(waitScope); + + KJ_ASSERT(failed, "the write promise should have failed"); + + listenTask.wait(waitScope); +} +#endif + +KJ_TEST("CONNECT rejects Transfer-Encoding") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "5\r\n" + "hello" + "0\r\n"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 18\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Bad Request"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); +} + +KJ_TEST("CONNECT rejects Content-Length") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + + HttpHeaderTable table; + ConnectEchoService service(table); + HttpServer server(timer, table, service); + + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto msg = "CONNECT https://example.org HTTP/1.1\r\n" + "Content-Length: 5\r\n" + "\r\n" + "hello"_kj; + + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + pipe.ends[1]->shutdownWrite(); + + expectRead(*pipe.ends[1], + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 18\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Bad Request"_kj).wait(waitScope); + + expectEnd(*pipe.ends[1]); + + listenTask.wait(waitScope); +} + +KJ_TEST("CONNECT HTTP-tunneled-over-CONNECT") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; + ConnectHttpService service(table); + HttpServer server(timer, table, service); + + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto request = client->connect( + "https://example.org"_kj, HttpHeaders(connectHeaderTable), {}); + + auto text = request.status.then([ + &tunneledHeaderTable, + &settings, + io=kj::mv(request.connection)](auto status) mutable { + KJ_ASSERT(status.statusCode == 200); + KJ_ASSERT(status.statusText == "OK"_kj); + auto client = newHttpClient(tunneledHeaderTable, *io, settings) + .attach(kj::mv(io)); + + return client->request(HttpMethod::GET, "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)) + .response.then([](HttpClient::Response&& response) { + return response.body->readAllText().attach(kj::mv(response)); + }).attach(kj::mv(client)); + }).wait(waitScope); + + KJ_ASSERT(text == "hello there"); +} + +KJ_TEST("CONNECT HTTP-tunneled-over-pipelined-CONNECT") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; + ConnectHttpService service(table); + HttpServer server(timer, table, service); + + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); + + auto client = newHttpClient(table, *pipe.ends[1]); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto request = client->connect( + "https://exmaple.org"_kj, HttpHeaders(connectHeaderTable), {}); + auto conn = kj::mv(request.connection); + auto proxyClient = newHttpClient(tunneledHeaderTable, *conn, settings).attach(kj::mv(conn)); + + auto get = proxyClient->request(HttpMethod::GET, + "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)); + auto text = get.response.then([](HttpClient::Response&& response) mutable { + return response.body->readAllText().attach(kj::mv(response)); + }).attach(kj::mv(proxyClient)).wait(waitScope); + + KJ_ASSERT(text == "hello there"); +} + +KJ_TEST("CONNECT pipelined via an adapter") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; + ConnectHttpService service(table); + HttpServer server(timer, table, service); + + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); + + bool acceptCalled = false; + + auto client = newHttpClient(table, *pipe.ends[1]); + auto adaptedService = kj::newHttpService(*client).attach(kj::mv(client)); + + // adaptedService is an HttpService that wraps an HttpClient that sends + // a request to server. + + auto clientPipe = newTwoWayPipe(); + + struct ResponseImpl final: public HttpService::ConnectResponse { + bool& acceptCalled; + ResponseImpl(bool& acceptCalled) : acceptCalled(acceptCalled) {} + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + acceptCalled = true; + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + KJ_UNREACHABLE; + } + }; + + ResponseImpl response(acceptCalled); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto promise = adaptedService->connect("https://example.org"_kj, + HttpHeaders(connectHeaderTable), + *clientPipe.ends[0], + response, + {}).attach(kj::mv(clientPipe.ends[0])); + + auto proxyClient = newHttpClient(tunneledHeaderTable, *clientPipe.ends[1], settings) + .attach(kj::mv(clientPipe.ends[1])); + + auto text = proxyClient->request(HttpMethod::GET, + "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)) + .response.then([](HttpClient::Response&& response) mutable { + return response.body->readAllText().attach(kj::mv(response)); + }).wait(waitScope); + + KJ_ASSERT(acceptCalled); + KJ_ASSERT(text == "hello there"); +} + +KJ_TEST("CONNECT pipelined via an adapter (reject)") { + KJ_HTTP_TEST_SETUP_IO; + + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + kj::TimerImpl timer(kj::origin()); + HttpHeaderTable table; + ConnectRejectService service(table); + HttpServer server(timer, table, service); + + auto listenTask KJ_UNUSED = server.listenHttp(kj::mv(pipe.ends[0])); + + bool rejectCalled = false; + bool failedAsExpected = false; + + auto client = newHttpClient(table, *pipe.ends[1]); + auto adaptedService = kj::newHttpService(*client).attach(kj::mv(client)); + + // adaptedService is an HttpService that wraps an HttpClient that sends + // a request to server. + + auto clientPipe = newTwoWayPipe(); + + struct ResponseImpl final: public HttpService::ConnectResponse { + bool& rejectCalled; + kj::OneWayPipe pipe; + ResponseImpl(bool& rejectCalled) + : rejectCalled(rejectCalled), + pipe(kj::newOneWayPipe()) {} + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + KJ_UNREACHABLE; + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + rejectCalled = true; + return kj::mv(pipe.out); + } + + kj::Own getRejectStream() { + return kj::mv(pipe.in); + } + }; + + ResponseImpl response(rejectCalled); + + HttpHeaderTable connectHeaderTable; + HttpHeaderTable tunneledHeaderTable; + HttpClientSettings settings; + + auto promise = adaptedService->connect("https://example.org"_kj, + HttpHeaders(connectHeaderTable), + *clientPipe.ends[0], + response, + {}).attach(kj::mv(clientPipe.ends[0])); + + auto proxyClient = newHttpClient(tunneledHeaderTable, *clientPipe.ends[1], settings) + .attach(kj::mv(clientPipe.ends[1])); + + auto text = proxyClient->request(HttpMethod::GET, + "http://example.org"_kj, + HttpHeaders(tunneledHeaderTable)) + .response.then([](HttpClient::Response&& response) mutable { + return response.body->readAllText().attach(kj::mv(response)); + }, [&](kj::Exception&& ex) -> kj::Promise { + // We fully expect the stream to fail here. + if (ex.getDescription() == "stream disconnected prematurely") { + failedAsExpected = true; + } + return kj::str("ok"); + }).wait(waitScope); + + auto rejectStream = response.getRejectStream(); + +#ifndef KJ_HTTP_TEST_USE_OS_PIPE + expectRead(*rejectStream, "boom"_kj).wait(waitScope); +#endif + + KJ_ASSERT(rejectCalled); + KJ_ASSERT(failedAsExpected); + KJ_ASSERT(text == "ok"); } } // namespace diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index af95ecf577..b1968f3576 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -20,13 +20,316 @@ // THE SOFTWARE. #include "http.h" +#include "kj/exception.h" +#include "url.h" #include #include +#include #include #include +#include +#include +#include +#include +#if KJ_HAS_ZLIB +#include +#endif // KJ_HAS_ZLIB namespace kj { +// ======================================================================================= +// SHA-1 implementation from https://github.com/clibs/sha1 +// +// The WebSocket standard depends on SHA-1. ARRRGGGHHHHH. +// +// Any old checksum would have served the purpose, or hell, even just returning the header +// verbatim. But NO, they decided to throw a whole complicated hash algorithm in there, AND +// THEY CHOSE A BROKEN ONE THAT WE OTHERWISE WOULDN'T NEED ANYMORE. +// +// TODO(cleanup): Move this to a shared hashing library. Maybe. Or maybe don't, because no one +// should be using SHA-1 anymore. +// +// THIS USAGE IS NOT SECURITY SENSITIVE. IF YOU REPORT A SECURITY ISSUE BECAUSE YOU SAW SHA1 IN THE +// SOURCE CODE I WILL MAKE FUN OF YOU. + +/* +SHA-1 in C +By Steve Reid +100% Public Domain +Test Vectors (from FIPS PUB 180-1) +"abc" + A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D +"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1 +A million repetitions of "a" + 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F +*/ + +/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */ +/* #define SHA1HANDSOFF * Copies data before messing with it. */ + +#define SHA1HANDSOFF + +typedef struct +{ + uint32_t state[5]; + uint32_t count[2]; + unsigned char buffer[64]; +} SHA1_CTX; + +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) + +/* blk0() and blk() perform the initial expand. */ +/* I got the idea of expanding during the round function from SSLeay */ +#if BYTE_ORDER == LITTLE_ENDIAN +#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \ + |(rol(block->l[i],8)&0x00FF00FF)) +#elif BYTE_ORDER == BIG_ENDIAN +#define blk0(i) block->l[i] +#else +#error "Endianness not defined!" +#endif +#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \ + ^block->l[(i+2)&15]^block->l[i&15],1)) + +/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */ +#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30); +#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30); +#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30); + + +/* Hash a single 512-bit block. This is the core of the algorithm. */ + +void SHA1Transform( + uint32_t state[5], + const unsigned char buffer[64] +) +{ + uint32_t a, b, c, d, e; + + typedef union + { + unsigned char c[64]; + uint32_t l[16]; + } CHAR64LONG16; + +#ifdef SHA1HANDSOFF + CHAR64LONG16 block[1]; /* use array to appear as a pointer */ + + memcpy(block, buffer, 64); +#else + /* The following had better never be used because it causes the + * pointer-to-const buffer to be cast into a pointer to non-const. + * And the result is written through. I threw a "const" in, hoping + * this will cause a diagnostic. + */ + CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer; +#endif + /* Copy context->state[] to working vars */ + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + /* 4 rounds of 20 operations each. Loop unrolled. */ + R0(a, b, c, d, e, 0); + R0(e, a, b, c, d, 1); + R0(d, e, a, b, c, 2); + R0(c, d, e, a, b, 3); + R0(b, c, d, e, a, 4); + R0(a, b, c, d, e, 5); + R0(e, a, b, c, d, 6); + R0(d, e, a, b, c, 7); + R0(c, d, e, a, b, 8); + R0(b, c, d, e, a, 9); + R0(a, b, c, d, e, 10); + R0(e, a, b, c, d, 11); + R0(d, e, a, b, c, 12); + R0(c, d, e, a, b, 13); + R0(b, c, d, e, a, 14); + R0(a, b, c, d, e, 15); + R1(e, a, b, c, d, 16); + R1(d, e, a, b, c, 17); + R1(c, d, e, a, b, 18); + R1(b, c, d, e, a, 19); + R2(a, b, c, d, e, 20); + R2(e, a, b, c, d, 21); + R2(d, e, a, b, c, 22); + R2(c, d, e, a, b, 23); + R2(b, c, d, e, a, 24); + R2(a, b, c, d, e, 25); + R2(e, a, b, c, d, 26); + R2(d, e, a, b, c, 27); + R2(c, d, e, a, b, 28); + R2(b, c, d, e, a, 29); + R2(a, b, c, d, e, 30); + R2(e, a, b, c, d, 31); + R2(d, e, a, b, c, 32); + R2(c, d, e, a, b, 33); + R2(b, c, d, e, a, 34); + R2(a, b, c, d, e, 35); + R2(e, a, b, c, d, 36); + R2(d, e, a, b, c, 37); + R2(c, d, e, a, b, 38); + R2(b, c, d, e, a, 39); + R3(a, b, c, d, e, 40); + R3(e, a, b, c, d, 41); + R3(d, e, a, b, c, 42); + R3(c, d, e, a, b, 43); + R3(b, c, d, e, a, 44); + R3(a, b, c, d, e, 45); + R3(e, a, b, c, d, 46); + R3(d, e, a, b, c, 47); + R3(c, d, e, a, b, 48); + R3(b, c, d, e, a, 49); + R3(a, b, c, d, e, 50); + R3(e, a, b, c, d, 51); + R3(d, e, a, b, c, 52); + R3(c, d, e, a, b, 53); + R3(b, c, d, e, a, 54); + R3(a, b, c, d, e, 55); + R3(e, a, b, c, d, 56); + R3(d, e, a, b, c, 57); + R3(c, d, e, a, b, 58); + R3(b, c, d, e, a, 59); + R4(a, b, c, d, e, 60); + R4(e, a, b, c, d, 61); + R4(d, e, a, b, c, 62); + R4(c, d, e, a, b, 63); + R4(b, c, d, e, a, 64); + R4(a, b, c, d, e, 65); + R4(e, a, b, c, d, 66); + R4(d, e, a, b, c, 67); + R4(c, d, e, a, b, 68); + R4(b, c, d, e, a, 69); + R4(a, b, c, d, e, 70); + R4(e, a, b, c, d, 71); + R4(d, e, a, b, c, 72); + R4(c, d, e, a, b, 73); + R4(b, c, d, e, a, 74); + R4(a, b, c, d, e, 75); + R4(e, a, b, c, d, 76); + R4(d, e, a, b, c, 77); + R4(c, d, e, a, b, 78); + R4(b, c, d, e, a, 79); + /* Add the working vars back into context.state[] */ + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + /* Wipe variables */ + a = b = c = d = e = 0; +#ifdef SHA1HANDSOFF + memset(block, '\0', sizeof(block)); +#endif +} + + +/* SHA1Init - Initialize new context */ + +void SHA1Init( + SHA1_CTX * context +) +{ + /* SHA1 initialization constants */ + context->state[0] = 0x67452301; + context->state[1] = 0xEFCDAB89; + context->state[2] = 0x98BADCFE; + context->state[3] = 0x10325476; + context->state[4] = 0xC3D2E1F0; + context->count[0] = context->count[1] = 0; +} + + +/* Run your data through this. */ + +void SHA1Update( + SHA1_CTX * context, + const unsigned char *data, + uint32_t len +) +{ + uint32_t i; + + uint32_t j; + + j = context->count[0]; + if ((context->count[0] += len << 3) < j) + context->count[1]++; + context->count[1] += (len >> 29); + j = (j >> 3) & 63; + if ((j + len) > 63) + { + memcpy(&context->buffer[j], data, (i = 64 - j)); + SHA1Transform(context->state, context->buffer); + for (; i + 63 < len; i += 64) + { + SHA1Transform(context->state, &data[i]); + } + j = 0; + } + else + i = 0; + memcpy(&context->buffer[j], &data[i], len - i); +} + + +/* Add padding and return the message digest. */ + +void SHA1Final( + unsigned char digest[20], + SHA1_CTX * context +) +{ + unsigned i; + + unsigned char finalcount[8]; + + unsigned char c; + +#if 0 /* untested "improvement" by DHR */ + /* Convert context->count to a sequence of bytes + * in finalcount. Second element first, but + * big-endian order within element. + * But we do it all backwards. + */ + unsigned char *fcp = &finalcount[8]; + for (i = 0; i < 2; i++) + { + uint32_t t = context->count[i]; + int j; + for (j = 0; j < 4; t >>= 8, j++) + *--fcp = (unsigned char) t} +#else + for (i = 0; i < 8; i++) + { + finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */ + } +#endif + c = 0200; + SHA1Update(context, &c, 1); + while ((context->count[0] & 504) != 448) + { + c = 0000; + SHA1Update(context, &c, 1); + } + SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */ + for (i = 0; i < 20; i++) + { + digest[i] = (unsigned char) + ((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255); + } + /* Wipe variables */ + memset(context, '\0', sizeof(*context)); + memset(&finalcount, '\0', sizeof(finalcount)); +} + +// End SHA-1 implementation. +// ======================================================================================= + static const char* METHOD_NAMES[] = { #define METHOD_NAME(id) #id, KJ_HTTP_FOR_EACH_METHOD(METHOD_NAME) @@ -37,22 +340,38 @@ kj::StringPtr KJ_STRINGIFY(HttpMethod method) { return METHOD_NAMES[static_cast(method)]; } -static kj::Maybe consumeHttpMethod(char*& ptr) { +kj::StringPtr KJ_STRINGIFY(HttpConnectMethod method) { + return "CONNECT"_kj; +} + +static kj::Maybe> consumeHttpMethod(char*& ptr) { char* p = ptr; #define EXPECT_REST(prefix, suffix) \ if (strncmp(p, #suffix, sizeof(#suffix)-1) == 0) { \ ptr = p + (sizeof(#suffix)-1); \ - return HttpMethod::prefix##suffix; \ + return kj::Maybe>(HttpMethod::prefix##suffix); \ } else { \ return nullptr; \ } switch (*p++) { + case 'A': EXPECT_REST(A,CL) case 'C': switch (*p++) { case 'H': EXPECT_REST(CH,ECKOUT) - case 'O': EXPECT_REST(CO,PY) + case 'O': + switch (*p++) { + case 'P': EXPECT_REST(COP,Y) + case 'N': + if (strncmp(p, "NECT", 4) == 0) { + ptr = p + 4; + return kj::Maybe>(HttpConnectMethod()); + } else { + return nullptr; + } + default: return nullptr; + } default: return nullptr; } case 'D': EXPECT_REST(D,ELETE) @@ -114,6 +433,19 @@ static kj::Maybe consumeHttpMethod(char*& ptr) { } kj::Maybe tryParseHttpMethod(kj::StringPtr name) { + KJ_IF_MAYBE(method, tryParseHttpMethodAllowingConnect(name)) { + KJ_SWITCH_ONEOF(*method) { + KJ_CASE_ONEOF(m, HttpMethod) { return m; } + KJ_CASE_ONEOF(m, HttpConnectMethod) { return nullptr; } + } + KJ_UNREACHABLE; + } else { + return nullptr; + } +} + +kj::Maybe> tryParseHttpMethodAllowingConnect( + kj::StringPtr name) { // const_cast OK because we don't actually access it. consumeHttpMethod() is also called by some // code later than explicitly needs to use a non-const pointer. char* ptr = const_cast(name.begin()); @@ -129,6 +461,20 @@ kj::Maybe tryParseHttpMethod(kj::StringPtr name) { namespace { +constexpr char WEBSOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +// From RFC6455. + +static kj::String generateWebSocketAccept(kj::StringPtr key) { + // WebSocket demands we do a SHA-1 here. ARRGHH WHY SHA-1 WHYYYYYY? + SHA1_CTX ctx; + byte digest[20]; + SHA1Init(&ctx); + SHA1Update(&ctx, key.asBytes().begin(), key.size()); + SHA1Update(&ctx, reinterpret_cast(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID)); + SHA1Final(digest, &ctx); + return kj::encodeBase64(digest); +} + constexpr auto HTTP_SEPARATOR_CHARS = kj::parse::anyOfChars("()<>@,;:\\\"/[]?={} \t"); // RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 @@ -149,9 +495,8 @@ static void requireValidHeaderName(kj::StringPtr name) { } static void requireValidHeaderValue(kj::StringPtr value) { - for (char c: value) { - KJ_REQUIRE(c >= 0x20, "invalid header value", value); - } + KJ_REQUIRE(HttpHeaders::isValidHeaderValue(value), "invalid header value", + kj::encodeCEscape(value)); } static const char* BUILTIN_HEADER_NAMES[] = { @@ -161,31 +506,14 @@ static const char* BUILTIN_HEADER_NAMES[] = { #undef HEADER_NAME }; -enum class BuiltinHeaderIndices { -#define HEADER_ID(id, name) id, - KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID) -#undef HEADER_ID -}; - -static constexpr size_t CONNECTION_HEADER_COUNT KJ_UNUSED = 0 -#define COUNT_HEADER(id, name) + 1 - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(COUNT_HEADER) -#undef COUNT_HEADER - ; +} // namespace -enum class ConnectionHeaderIndices { -#define HEADER_ID(id, name) id, - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HEADER_ID) +#define HEADER_ID(id, name) constexpr uint HttpHeaders::BuiltinIndices::id; + KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID) #undef HEADER_ID -}; - -static constexpr uint CONNECTION_HEADER_XOR = kj::maxValue; -static constexpr uint CONNECTION_HEADER_THRESHOLD = CONNECTION_HEADER_XOR >> 1; - -} // namespace #define DEFINE_HEADER(id, name) \ -const HttpHeaderId HttpHeaderId::id(nullptr, static_cast(BuiltinHeaderIndices::id)); +const HttpHeaderId HttpHeaderId::id(nullptr, HttpHeaders::BuiltinIndices::id); KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER) #undef DEFINE_HEADER @@ -233,7 +561,9 @@ struct HttpHeaderTable::IdsByNameMap { }; HttpHeaderTable::Builder::Builder() - : table(kj::heap()) {} + : table(kj::heap()) { + table->buildStatus = BuildStatus::BUILDING; +} HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) { requireValidHeaderName(name); @@ -247,21 +577,15 @@ HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) { HttpHeaderTable::HttpHeaderTable() : idsByName(kj::heap()) { -#define ADD_HEADER(id, name) \ - idsByName->map.insert(std::make_pair(name, \ - static_cast(ConnectionHeaderIndices::id) ^ CONNECTION_HEADER_XOR)); - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(ADD_HEADER); -#undef ADD_HEADER - #define ADD_HEADER(id, name) \ namesById.add(name); \ - idsByName->map.insert(std::make_pair(name, static_cast(BuiltinHeaderIndices::id))); + idsByName->map.insert(std::make_pair(name, HttpHeaders::BuiltinIndices::id)); KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER); #undef ADD_HEADER } HttpHeaderTable::~HttpHeaderTable() noexcept(false) {} -kj::Maybe HttpHeaderTable::stringToId(kj::StringPtr name) { +kj::Maybe HttpHeaderTable::stringToId(kj::StringPtr name) const { auto iter = idsByName->map.find(name); if (iter == idsByName->map.end()) { return nullptr; @@ -272,9 +596,26 @@ kj::Maybe HttpHeaderTable::stringToId(kj::StringPtr name) { // ======================================================================================= -HttpHeaders::HttpHeaders(HttpHeaderTable& table) +bool HttpHeaders::isValidHeaderValue(kj::StringPtr value) { + for (char c: value) { + // While the HTTP spec suggests that only printable ASCII characters are allowed in header + // values, reality has a different opinion. See: https://github.com/httpwg/http11bis/issues/19 + // We follow the browsers' lead. + if (c == '\0' || c == '\r' || c == '\n') { + return false; + } + } + + return true; +} + +HttpHeaders::HttpHeaders(const HttpHeaderTable& table) : table(&table), - indexedHeaders(kj::heapArray(table.idCount())) {} + indexedHeaders(kj::heapArray(table.idCount())) { + KJ_ASSERT( + table.isReady(), "HttpHeaders object was constructed from " + "HttpHeaderTable that wasn't fully built yet at the time of construction"); +} void HttpHeaders::clear() { for (auto& header: indexedHeaders) { @@ -284,6 +625,16 @@ void HttpHeaders::clear() { unindexedHeaders.clear(); } +size_t HttpHeaders::size() const { + size_t result = unindexedHeaders.size(); + for (auto i: kj::indices(indexedHeaders)) { + if (indexedHeaders[i] != nullptr) { + ++result; + } + } + return result; +} + HttpHeaders HttpHeaders::clone() const { HttpHeaders result(*table); @@ -326,6 +677,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) { return result; } + +namespace { + +template +constexpr bool fastCaseCmp(const char* actual); + +} // namespace + +bool HttpHeaders::isWebSocket() const { + return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>( + get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr()); +} + void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) { id.requireFrom(*table); requireValidHeaderValue(value); @@ -342,8 +706,7 @@ void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) { requireValidHeaderName(name); requireValidHeaderValue(value); - KJ_REQUIRE(addNoCheck(name, value) == nullptr, - "can't set connection-level headers on HttpHeaders", name, value) { break; } + addNoCheck(name, value); } void HttpHeaders::add(kj::StringPtr name, kj::String&& value) { @@ -357,25 +720,31 @@ void HttpHeaders::add(kj::String&& name, kj::String&& value) { takeOwnership(kj::mv(value)); } -kj::Maybe HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) { +void HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) { KJ_IF_MAYBE(id, table->stringToId(name)) { - if (id->id > CONNECTION_HEADER_THRESHOLD) { - return id->id ^ CONNECTION_HEADER_XOR; - } - if (indexedHeaders[id->id] == nullptr) { indexedHeaders[id->id] = value; } else { // Duplicate HTTP headers are equivalent to the values being separated by a comma. - auto concat = kj::str(indexedHeaders[id->id], ", ", value); - indexedHeaders[id->id] = concat; - ownedStrings.add(concat.releaseArray()); + +#if _MSC_VER + if (_stricmp(name.cStr(), "set-cookie") == 0) { +#else + if (strcasecmp(name.cStr(), "set-cookie") == 0) { +#endif + // Uh-oh, Set-Cookie will be corrupted if we try to concatenate it. We'll make it an + // unindexed header, which is weird, but the alternative is guaranteed corruption, so... + // TODO(cleanup): Maybe HttpHeaders should just special-case set-cookie in general? + unindexedHeaders.add(Header {name, value}); + } else { + auto concat = kj::str(indexedHeaders[id->id], ", ", value); + indexedHeaders[id->id] = concat; + ownedStrings.add(concat.releaseArray()); + } } } else { unindexedHeaders.add(Header {name, value}); } - - return nullptr; } void HttpHeaders::takeOwnership(kj::String&& string) { @@ -545,82 +914,124 @@ static char* trimHeaderEnding(kj::ArrayPtr content) { return end; } -kj::Maybe HttpHeaders::tryParseRequest(kj::ArrayPtr content) { +HttpHeaders::RequestOrProtocolError HttpHeaders::tryParseRequest(kj::ArrayPtr content) { + KJ_SWITCH_ONEOF(tryParseRequestOrConnect(content)) { + KJ_CASE_ONEOF(request, Request) { + return kj::mv(request); + } + KJ_CASE_ONEOF(error, ProtocolError) { + return kj::mv(error); + } + KJ_CASE_ONEOF(connect, ConnectRequest) { + return ProtocolError { 501, "Not Implemented", + "Unrecognized request method.", content }; + } + } + KJ_UNREACHABLE; +} + +HttpHeaders::RequestConnectOrProtocolError HttpHeaders::tryParseRequestOrConnect( + kj::ArrayPtr content) { char* end = trimHeaderEnding(content); - if (end == nullptr) return nullptr; + if (end == nullptr) { + return ProtocolError { 400, "Bad Request", + "Request headers have no terminal newline.", content }; + } char* ptr = content.begin(); - HttpHeaders::Request request; + HttpHeaders::RequestConnectOrProtocolError result; KJ_IF_MAYBE(method, consumeHttpMethod(ptr)) { - request.method = *method; if (*ptr != ' ' && *ptr != '\t') { - return nullptr; + return ProtocolError { 501, "Not Implemented", + "Unrecognized request method.", content }; } ++ptr; - } else { - return nullptr; - } - KJ_IF_MAYBE(path, consumeWord(ptr)) { - request.url = *path; + kj::Maybe path; + KJ_IF_MAYBE(p, consumeWord(ptr)) { + path = *p; + } else { + return ProtocolError { 400, "Bad Request", + "Invalid request line.", content }; + } + + KJ_SWITCH_ONEOF(*method) { + KJ_CASE_ONEOF(m, HttpMethod) { + result = HttpHeaders::Request { m, KJ_ASSERT_NONNULL(path) }; + } + KJ_CASE_ONEOF(m, HttpConnectMethod) { + result = HttpHeaders::ConnectRequest { KJ_ASSERT_NONNULL(path) }; + } + } } else { - return nullptr; + return ProtocolError { 501, "Not Implemented", + "Unrecognized request method.", content }; } // Ignore rest of line. Don't care about "HTTP/1.1" or whatever. consumeLine(ptr); - if (!parseHeaders(ptr, end, request.connectionHeaders)) return nullptr; + if (!parseHeaders(ptr, end)) { + return ProtocolError { 400, "Bad Request", + "The headers sent by your client are not valid.", content }; + } - return request; + return result; } -kj::Maybe HttpHeaders::tryParseResponse(kj::ArrayPtr content) { +HttpHeaders::ResponseOrProtocolError HttpHeaders::tryParseResponse(kj::ArrayPtr content) { char* end = trimHeaderEnding(content); - if (end == nullptr) return nullptr; + if (end == nullptr) { + return ProtocolError { 502, "Bad Gateway", + "Response headers have no terminal newline.", content }; + } char* ptr = content.begin(); HttpHeaders::Response response; KJ_IF_MAYBE(version, consumeWord(ptr)) { - if (!version->startsWith("HTTP/")) return nullptr; + if (!version->startsWith("HTTP/")) { + return ProtocolError { 502, "Bad Gateway", + "Invalid response status line (invalid protocol).", content }; + } } else { - return nullptr; + return ProtocolError { 502, "Bad Gateway", + "Invalid response status line (no spaces).", content }; } KJ_IF_MAYBE(code, consumeNumber(ptr)) { response.statusCode = *code; } else { - return nullptr; + return ProtocolError { 502, "Bad Gateway", + "Invalid response status line (invalid status code).", content }; } response.statusText = consumeLine(ptr); - if (!parseHeaders(ptr, end, response.connectionHeaders)) return nullptr; + if (!parseHeaders(ptr, end)) { + return ProtocolError { 502, "Bad Gateway", + "The headers sent by the server are not valid.", content }; + } return response; } -bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders) { +bool HttpHeaders::tryParse(kj::ArrayPtr content) { + char* end = trimHeaderEnding(content); + if (end == nullptr) return false; + + char* ptr = content.begin(); + return parseHeaders(ptr, end); +} + +bool HttpHeaders::parseHeaders(char* ptr, char* end) { while (*ptr != '\0') { KJ_IF_MAYBE(name, consumeHeaderName(ptr)) { kj::StringPtr line = consumeLine(ptr); - KJ_IF_MAYBE(connectionHeaderId, addNoCheck(*name, line)) { - // Parsed a connection header. - switch (*connectionHeaderId) { -#define HANDLE_HEADER(id, name) \ - case static_cast(ConnectionHeaderIndices::id): \ - connectionHeaders.id = line; \ - break; - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER) -#undef HANDLE_HEADER - default: - KJ_UNREACHABLE; - } - } + addNoCheck(*name, line); } else { return false; } @@ -631,13 +1042,21 @@ bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connecti // ----------------------------------------------------------------------------- -kj::String HttpHeaders::serializeRequest(HttpMethod method, kj::StringPtr url, - const ConnectionHeaders& connectionHeaders) const { +kj::String HttpHeaders::serializeRequest( + HttpMethod method, kj::StringPtr url, + kj::ArrayPtr connectionHeaders) const { return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders); } -kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusText, - const ConnectionHeaders& connectionHeaders) const { +kj::String HttpHeaders::serializeConnectRequest( + kj::StringPtr authority, + kj::ArrayPtr connectionHeaders) const { + return serialize("CONNECT"_kj, authority, kj::StringPtr("HTTP/1.1"), connectionHeaders); +} + +kj::String HttpHeaders::serializeResponse( + uint statusCode, kj::StringPtr statusText, + kj::ArrayPtr connectionHeaders) const { auto statusCodeStr = kj::toCharSequence(statusCode); return serialize(kj::StringPtr("HTTP/1.1"), statusCodeStr, statusText, connectionHeaders); @@ -646,7 +1065,7 @@ kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusT kj::String HttpHeaders::serialize(kj::ArrayPtr word1, kj::ArrayPtr word2, kj::ArrayPtr word3, - const ConnectionHeaders& connectionHeaders) const { + kj::ArrayPtr connectionHeaders) const { const kj::StringPtr space = " "; const kj::StringPtr newline = "\r\n"; const kj::StringPtr colon = ": "; @@ -655,15 +1074,11 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr word1, if (word1 != nullptr) { size += word1.size() + word2.size() + word3.size() + 4; } -#define HANDLE_HEADER(id, name) \ - if (connectionHeaders.id != nullptr) { \ - size += connectionHeaders.id.size() + (sizeof(name) + 3); \ - } - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER) -#undef HANDLE_HEADER + KJ_ASSERT(connectionHeaders.size() <= indexedHeaders.size()); for (auto i: kj::indices(indexedHeaders)) { - if (indexedHeaders[i] != nullptr) { - size += table->idToString(HttpHeaderId(table, i)).size() + indexedHeaders[i].size() + 4; + kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i]; + if (value != nullptr) { + size += table->idToString(HttpHeaderId(table, i)).size() + value.size() + 4; } } for (auto& header: unindexedHeaders) { @@ -676,16 +1091,10 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr word1, if (word1 != nullptr) { ptr = kj::_::fill(ptr, word1, space, word2, space, word3, newline); } -#define HANDLE_HEADER(id, name) \ - if (connectionHeaders.id != nullptr) { \ - ptr = kj::_::fill(ptr, kj::StringPtr(name), colon, connectionHeaders.id, newline); \ - } - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER) -#undef HANDLE_HEADER for (auto i: kj::indices(indexedHeaders)) { - if (indexedHeaders[i] != nullptr) { - ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, - indexedHeaders[i], newline); + kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i]; + if (value != nullptr) { + ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, value, newline); } } for (auto& header: unindexedHeaders) { @@ -698,23 +1107,210 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr word1, } kj::String HttpHeaders::toString() const { - return serialize(nullptr, nullptr, nullptr, ConnectionHeaders()); + return serialize(nullptr, nullptr, nullptr, nullptr); } // ======================================================================================= namespace { +template +class WrappableStreamMixin { + // Both HttpInputStreamImpl and HttpOutputStream are commonly wrapped by a class that implements + // a particular type of body stream, such as a chunked body or a fixed-length body. That wrapper + // stream is passed back to the application to represent the specific request/response body, but + // the inner stream is associated with the connection and can be reused several times. + // + // It's easy for applications to screw up and hold on to a body stream beyond the lifetime of the + // underlying connection stream. This used to lead to UAF. This mixin class implements behavior + // that detached the wrapper if it outlives the wrapped stream, so that we log errors and + +public: + WrappableStreamMixin() = default; + WrappableStreamMixin(WrappableStreamMixin&& other) { + // This constructor is only needed by HttpServer::Connection::makeHttpInput() which constructs + // a new stream and returns it. Technically the constructor will always be elided anyway. + KJ_REQUIRE(other.currentWrapper == nullptr, "can't move a wrappable object that has wrappers!"); + } + KJ_DISALLOW_COPY(WrappableStreamMixin); + + ~WrappableStreamMixin() noexcept(false) { + KJ_IF_MAYBE(w, currentWrapper) { + KJ_LOG(ERROR, "HTTP connection destroyed while HTTP body streams still exist", + kj::getStackTrace()); + *w = nullptr; + } + } + + void setCurrentWrapper(kj::Maybe& weakRef) { + // Tracks the current `HttpEntityBodyReader` instance which is wrapping this stream. There can + // be only one wrapper at a time, and the wrapper must be destroyed before the underlying HTTP + // connection is torn down. The purpose of tracking the wrapper here is to detect when these + // rules are violated by apps, and log an error instead of going UB. + // + // `weakRef` is the wrapper's pointer to this object. If the underlying stream is destroyed + // before the wrapper, then `weakRef` will be nulled out. + + // The API should prevent an app from obtaining multiple wrappers with the same backing stream. + KJ_ASSERT(currentWrapper == nullptr, + "bug in KJ HTTP: only one HTTP stream wrapper can exist at a time"); + + currentWrapper = weakRef; + weakRef = static_cast(*this); + } + + void unsetCurrentWrapper(kj::Maybe& weakRef) { + auto& current = KJ_ASSERT_NONNULL(currentWrapper); + KJ_ASSERT(¤t == &weakRef, + "bug in KJ HTTP: unsetCurrentWrapper() passed the wrong wrapper"); + weakRef = nullptr; + currentWrapper = nullptr; + } + +private: + kj::Maybe&> currentWrapper; +}; + +// ======================================================================================= + static constexpr size_t MIN_BUFFER = 4096; -static constexpr size_t MAX_BUFFER = 65536; +static constexpr size_t MAX_BUFFER = 128 * 1024; static constexpr size_t MAX_CHUNK_HEADER_SIZE = 32; -class HttpInputStream { +class HttpInputStreamImpl final: public HttpInputStream, + public WrappableStreamMixin { +private: + static kj::OneOf getResumingRequest( + kj::OneOf method, + kj::StringPtr url) { + KJ_SWITCH_ONEOF(method) { + KJ_CASE_ONEOF(m, HttpMethod) { + return HttpHeaders::Request { m, url }; + } + KJ_CASE_ONEOF(m, HttpConnectMethod) { + return HttpHeaders::ConnectRequest { url }; + } + } + KJ_UNREACHABLE; + } public: - explicit HttpInputStream(AsyncIoStream& inner, HttpHeaderTable& table) + explicit HttpInputStreamImpl(AsyncInputStream& inner, const HttpHeaderTable& table) : inner(inner), headerBuffer(kj::heapArray(MIN_BUFFER)), headers(table) { } + explicit HttpInputStreamImpl(AsyncInputStream& inner, + kj::Array headerBufferParam, + kj::ArrayPtr leftoverParam, + kj::OneOf method, + kj::StringPtr url, + HttpHeaders headers) + : inner(inner), + headerBuffer(kj::mv(headerBufferParam)), + // Initialize `messageHeaderEnd` to a safe value, we'll adjust it below. + messageHeaderEnd(leftoverParam.begin() - headerBuffer.begin()), + leftover(leftoverParam), + headers(kj::mv(headers)), + resumingRequest(getResumingRequest(method, url)) { + // Constructor used for resuming a SuspendedRequest. + + // We expect headerBuffer to look like this: + // [CR] LF + // We initialized `messageHeaderEnd` to the beginning of `leftover`, but we want to point it at + // the CR (or LF if there's no CR). + KJ_REQUIRE(messageHeaderEnd >= 2 && leftover.end() <= headerBuffer.end(), + "invalid SuspendedRequest - leftover buffer not where it should be"); + KJ_REQUIRE(leftover.begin()[-1] == '\n', "invalid SuspendedRequest - missing LF"); + messageHeaderEnd -= 1 + (leftover.begin()[-2] == '\r'); + + // We're in the middle of a message, so set up our state as such. Note that the only way to + // resume a SuspendedRequest is via an HttpServer, but HttpServers never call + // `awaitNextMessage()` before fully reading request bodies, meaning we expect that + // `messageReadQueue` will never be used. + ++pendingMessageCount; + auto paf = kj::newPromiseAndFulfiller(); + onMessageDone = kj::mv(paf.fulfiller); + messageReadQueue = kj::mv(paf.promise); + } + + bool canReuse() { + return !broken && pendingMessageCount == 0; + } + + bool canSuspend() { + // We are at a suspendable point if we've parsed the headers, but haven't consumed anything + // beyond that. + // + // TODO(cleanup): This is a silly check; we need a more defined way to track the state of the + // stream. + bool messageHeaderEndLooksRight = + (leftover.begin() - (headerBuffer.begin() + messageHeaderEnd) == 2 && + leftover.begin()[-1] == '\n' && leftover.begin()[-2] == '\r') + || (leftover.begin() - (headerBuffer.begin() + messageHeaderEnd) == 1 && + leftover.begin()[-1] == '\n'); + + return !broken && headerBuffer.size() > 0 && messageHeaderEndLooksRight; + } + + // --------------------------------------------------------------------------- + // public interface + + kj::Promise readRequest() override { + return readRequestHeaders() + .then([this](HttpHeaders::RequestConnectOrProtocolError&& requestOrProtocolError) + -> HttpInputStream::Request { + auto request = KJ_REQUIRE_NONNULL( + requestOrProtocolError.tryGet(), "bad request"); + auto body = getEntityBody(HttpInputStreamImpl::REQUEST, request.method, 0, headers); + + return { request.method, request.url, headers, kj::mv(body) }; + }); + } + + kj::Promise> readRequestAllowingConnect() override { + return readRequestHeaders() + .then([this](HttpHeaders::RequestConnectOrProtocolError&& requestOrProtocolError) + -> kj::OneOf { + KJ_SWITCH_ONEOF(requestOrProtocolError) { + KJ_CASE_ONEOF(request, HttpHeaders::Request) { + auto body = getEntityBody(HttpInputStreamImpl::REQUEST, request.method, 0, headers); + return HttpInputStream::Request { request.method, request.url, headers, kj::mv(body) }; + } + KJ_CASE_ONEOF(request, HttpHeaders::ConnectRequest) { + auto body = getEntityBody(HttpInputStreamImpl::REQUEST, HttpConnectMethod(), 0, headers); + return HttpInputStream::Connect { request.authority, headers, kj::mv(body) }; + } + KJ_CASE_ONEOF(error, HttpHeaders::ProtocolError) { + KJ_FAIL_REQUIRE("bad request"); + } + } + KJ_UNREACHABLE; + }); + } + + kj::Promise readResponse(HttpMethod requestMethod) override { + return readResponseHeaders() + .then([this,requestMethod](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) + -> HttpInputStream::Response { + auto response = KJ_REQUIRE_NONNULL( + responseOrProtocolError.tryGet(), "bad response"); + auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, requestMethod, + response.statusCode, headers); + + return { response.statusCode, response.statusText, headers, kj::mv(body) }; + }); + } + + kj::Promise readMessage() override { + return readMessageHeaders() + .then([this](kj::ArrayPtr text) -> HttpInputStream::Message { + headers.clear(); + KJ_REQUIRE(headers.tryParse(text), "bad message"); + auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, 0, headers); + + return { headers, kj::mv(body) }; + }); + } + // --------------------------------------------------------------------------- // Stream locking: While an entity-body is being read, the body stream "locks" the underlying // HTTP stream. Once the entity-body is complete, we can read the next pipelined message. @@ -724,37 +1320,45 @@ public: KJ_REQUIRE_NONNULL(onMessageDone)->fulfill(); onMessageDone = nullptr; + --pendingMessageCount; } void abortRead() { // Called when a body input stream was destroyed without reading to the end. KJ_REQUIRE_NONNULL(onMessageDone)->reject(KJ_EXCEPTION(FAILED, - "client did not finish reading previous HTTP response body", - "can't read next pipelined response")); + "application did not finish reading previous HTTP response body", + "can't read next pipelined request/response")); onMessageDone = nullptr; + broken = true; } // --------------------------------------------------------------------------- - kj::Promise awaitNextMessage() { - // Waits until more data is available, but doesn't consume it. Only meant for server-side use, - // after a request is handled, to check for pipelined requests. Returns false on EOF. + kj::Promise awaitNextMessage() override { + // Waits until more data is available, but doesn't consume it. Returns false on EOF. + // + // Used on the server after a request is handled, to check for pipelined requests. + // + // Used on the client to detect when idle connections are closed from the server end. (In this + // case, the promise always returns false or is canceled.) - // Slightly-crappy code to snarf the expected line break. This will actually eat the leading - // regex /\r*\n?/. - while (lineBreakBeforeNextHeader && leftover.size() > 0) { - if (leftover[0] == '\r') { - leftover = leftover.slice(1, leftover.size()); - } else if (leftover[0] == '\n') { - leftover = leftover.slice(1, leftover.size()); - lineBreakBeforeNextHeader = false; - } else { - // Err, missing line break, whatever. - lineBreakBeforeNextHeader = false; - } + if (resumingRequest != nullptr) { + // We're resuming a request, so report that we have a message. + return true; + } + + if (onMessageDone != nullptr) { + // We're still working on reading the previous body. + auto fork = messageReadQueue.fork(); + messageReadQueue = fork.addBranch(); + return fork.addBranch().then([this]() { + return awaitNextMessage(); + }); } + snarfBufferedLineBreak(); + if (!lineBreakBeforeNextHeader && leftover != nullptr) { return true; } @@ -770,14 +1374,22 @@ public: }); } + bool isCleanDrain() { + // Returns whether we can cleanly drain the stream at this point. + if (onMessageDone != nullptr) return false; + snarfBufferedLineBreak(); + return !lineBreakBeforeNextHeader && leftover == nullptr; + } + kj::Promise> readMessageHeaders() { + ++pendingMessageCount; auto paf = kj::newPromiseAndFulfiller(); auto promise = messageReadQueue - .then(kj::mvCapture(paf.fulfiller, [this](kj::Own> fulfiller) { + .then([this,fulfiller=kj::mv(paf.fulfiller)]() mutable { onMessageDone = kj::mv(fulfiller); return readHeader(HeaderType::MESSAGE, 0, 0); - })); + }); messageReadQueue = kj::mv(paf.promise); @@ -801,9 +1413,8 @@ public: } else if ('A' <= c && c <= 'F') { value = value * 16 + (c - 'A' + 10); } else { - KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { - return value; - } + KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; } + return value; } } @@ -811,16 +1422,24 @@ public: }); } - inline kj::Promise> readRequestHeaders() { - headers.clear(); + inline kj::Promise readRequestHeaders() { + KJ_IF_MAYBE(resuming, resumingRequest) { + KJ_DEFER(resumingRequest = nullptr); + return HttpHeaders::RequestConnectOrProtocolError(*resuming); + } + return readMessageHeaders().then([this](kj::ArrayPtr text) { - return headers.tryParseRequest(text); + headers.clear(); + return headers.tryParseRequestOrConnect(text); }); } - inline kj::Promise> readResponseHeaders() { - headers.clear(); + inline kj::Promise readResponseHeaders() { + // Note: readResponseHeaders() could be called multiple times concurrently when pipelining + // requests. readMessageHeaders() will serialize these, but it's important not to mess with + // state (like calling headers.clear()) before said serialization has taken place. return readMessageHeaders().then([this](kj::ArrayPtr text) { + headers.clear(); return headers.tryParseResponse(text); }); } @@ -863,11 +1482,22 @@ public: }; kj::Own getEntityBody( - RequestOrResponse type, HttpMethod method, uint statusCode, - HttpHeaders::ConnectionHeaders& connectionHeaders); + RequestOrResponse type, + kj::OneOf method, + uint statusCode, + const kj::HttpHeaders& headers); + + struct ReleasedBuffer { + kj::Array buffer; + kj::ArrayPtr leftover; + }; + + ReleasedBuffer releaseBuffer() { + return { headerBuffer.releaseAsBytes(), leftover.asBytes() }; + } private: - AsyncIoStream& inner; + AsyncInputStream& inner; kj::Array headerBuffer; size_t messageHeaderEnd = 0; @@ -880,11 +1510,20 @@ private: HttpHeaders headers; // Parsed headers, after a call to parseAwaited*(). + kj::Maybe> resumingRequest; + // Non-null if we're resuming a SuspendedRequest. + bool lineBreakBeforeNextHeader = false; - // If true, the next await should expect to start with a spurrious '\n' or '\r\n'. This happens + // If true, the next await should expect to start with a spurious '\n' or '\r\n'. This happens // as a side-effect of HTTP chunked encoding, where such a newline is added to the end of each // chunk, for no good reason. + bool broken = false; + // Becomes true if the caller failed to read the whole entity-body before closing the stream. + + uint pendingMessageCount = 0; + // Number of reads we have queued up. + kj::Promise messageReadQueue = kj::READY_NOW; kj::Maybe>> onMessageDone; @@ -921,7 +1560,7 @@ private: readPromise = leftover.size(); leftover = nullptr; } else { - // Need to read more data from the unfderlying stream. + // Need to read more data from the underlying stream. if (bufferEnd == headerBuffer.size()) { // Out of buffer space. @@ -1034,59 +1673,104 @@ private: } }); } + + void snarfBufferedLineBreak() { + // Slightly-crappy code to snarf the expected line break. This will actually eat the leading + // regex /\r*\n?/. + while (lineBreakBeforeNextHeader && leftover.size() > 0) { + if (leftover[0] == '\r') { + leftover = leftover.slice(1, leftover.size()); + } else if (leftover[0] == '\n') { + leftover = leftover.slice(1, leftover.size()); + lineBreakBeforeNextHeader = false; + } else { + // Err, missing line break, whatever. + lineBreakBeforeNextHeader = false; + } + } + } }; // ----------------------------------------------------------------------------- class HttpEntityBodyReader: public kj::AsyncInputStream { public: - HttpEntityBodyReader(HttpInputStream& inner): inner(inner) {} + HttpEntityBodyReader(HttpInputStreamImpl& inner) { + inner.setCurrentWrapper(weakInner); + } ~HttpEntityBodyReader() noexcept(false) { if (!finished) { - inner.abortRead(); + KJ_IF_MAYBE(inner, weakInner) { + inner->unsetCurrentWrapper(weakInner); + inner->abortRead(); + } else { + // Since we're in a destructor, log an error instead of throwing. + KJ_LOG(ERROR, "HTTP body input stream outlived underlying connection", kj::getStackTrace()); + } } } protected: - HttpInputStream& inner; + HttpInputStreamImpl& getInner() { + KJ_IF_MAYBE(i, weakInner) { + return *i; + } else if (finished) { + // This is a bug in the implementations in this file, not the app. + KJ_FAIL_ASSERT("bug in KJ HTTP: tried to access inner stream after it had been released"); + } else { + KJ_FAIL_REQUIRE("HTTP body input stream outlived underlying connection"); + } + } void doneReading() { - KJ_REQUIRE(!finished); + auto& inner = getInner(); + inner.unsetCurrentWrapper(weakInner); finished = true; inner.finishRead(); } - inline bool alreadyDone() { return finished; } + inline bool alreadyDone() { return weakInner == nullptr; } private: + kj::Maybe weakInner; bool finished = false; }; class HttpNullEntityReader final: public HttpEntityBodyReader { - // Stream which reads until EOF. + // Stream for an entity-body which is not present. Always returns EOF on read, but tryGetLength() + // may indicate non-zero in the special case of a response to a HEAD request. public: - HttpNullEntityReader(HttpInputStream& inner) - : HttpEntityBodyReader(inner) { + HttpNullEntityReader(HttpInputStreamImpl& inner, kj::Maybe length) + : HttpEntityBodyReader(inner), length(length) { + // `length` is what to return from tryGetLength(). For a response to a HEAD request, this may + // be non-zero. doneReading(); } Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return size_t(0); + return constPromise(); + } + + Maybe tryGetLength() override { + return length; } + +private: + kj::Maybe length; }; class HttpConnectionCloseEntityReader final: public HttpEntityBodyReader { // Stream which reads until EOF. public: - HttpConnectionCloseEntityReader(HttpInputStream& inner) + HttpConnectionCloseEntityReader(HttpInputStreamImpl& inner) : HttpEntityBodyReader(inner) {} Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - if (alreadyDone()) return size_t(0); + if (alreadyDone()) return constPromise(); - return inner.tryRead(buffer, minBytes, maxBytes) + return getInner().tryRead(buffer, minBytes, maxBytes) .then([=](size_t amount) { if (amount < minBytes) { doneReading(); @@ -1100,7 +1784,7 @@ class HttpFixedLengthEntityReader final: public HttpEntityBodyReader { // Stream which reads only up to a fixed length from the underlying stream, then emulates EOF. public: - HttpFixedLengthEntityReader(HttpInputStream& inner, size_t length) + HttpFixedLengthEntityReader(HttpInputStreamImpl& inner, size_t length) : HttpEntityBodyReader(inner), length(length) { if (length == 0) doneReading(); } @@ -1110,46 +1794,72 @@ public: } Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - if (length == 0) return size_t(0); + KJ_REQUIRE(clean, "can't read more data after a previous read didn't complete"); + clean = false; + return tryReadInternal(buffer, minBytes, maxBytes, 0); + } - return inner.tryRead(buffer, kj::min(minBytes, length), kj::min(maxBytes, length)) - .then([=](size_t amount) { +private: + size_t length; + bool clean = true; + + Promise tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, + size_t alreadyRead) { + if (length == 0) { + clean = true; + return constPromise(); + } + + // We have to set minBytes to 1 here so that if we read any data at all, we update our + // counter immediately, so that we still know where we are in case of cancellation. + return getInner().tryRead(buffer, 1, kj::min(maxBytes, length)) + .then([=](size_t amount) -> kj::Promise { length -= amount; - if (length > 0 && amount < minBytes) { - kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, - "premature EOF in HTTP entity body; did not reach Content-Length")); + if (length > 0) { + // We haven't reached the end of the entity body yet. + if (amount == 0) { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, + "premature EOF in HTTP entity body; did not reach Content-Length")); + } else if (amount < minBytes) { + // We requested a minimum 1 byte above, but our own caller actually set a larger minimum + // which has not yet been reached. Keep trying until we reach it. + return tryReadInternal(reinterpret_cast(buffer) + amount, + minBytes - amount, maxBytes - amount, alreadyRead + amount); + } } else if (length == 0) { doneReading(); } - return amount; + clean = true; + return amount + alreadyRead; }); } - -private: - size_t length; }; class HttpChunkedEntityReader final: public HttpEntityBodyReader { // Stream which reads a Transfer-Encoding: Chunked stream. public: - HttpChunkedEntityReader(HttpInputStream& inner) + HttpChunkedEntityReader(HttpInputStreamImpl& inner) : HttpEntityBodyReader(inner) {} Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_REQUIRE(clean, "can't read more data after a previous read didn't complete"); + clean = false; return tryReadInternal(buffer, minBytes, maxBytes, 0); } private: size_t chunkSize = 0; + bool clean = true; Promise tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead) { if (alreadyDone()) { + clean = true; return alreadyRead; } else if (chunkSize == 0) { // Read next chunk header. - return inner.readChunkHeader().then([=](uint64_t nextChunkSize) { + return getInner().readChunkHeader().then([=](uint64_t nextChunkSize) { if (nextChunkSize == 0) { doneReading(); } @@ -1157,23 +1867,22 @@ private: chunkSize = nextChunkSize; return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); }); - } else if (chunkSize < minBytes) { - // Read entire current chunk and continue to next chunk. - return inner.tryRead(buffer, chunkSize, chunkSize) + } else { + // Read current chunk. + // We have to set minBytes to 1 here so that if we read any data at all, we update our + // counter immediately, so that we still know where we are in case of cancellation. + return getInner().tryRead(buffer, 1, kj::min(maxBytes, chunkSize)) .then([=](size_t amount) -> kj::Promise { chunkSize -= amount; - if (chunkSize > 0) { - return KJ_EXCEPTION(DISCONNECTED, "premature EOF in HTTP chunk"); + if (amount == 0) { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "premature EOF in HTTP chunk")); + } else if (amount < minBytes) { + // We requested a minimum 1 byte above, but our own caller actually set a larger minimum + // which has not yet been reached. Keep trying until we reach it. + return tryReadInternal(reinterpret_cast(buffer) + amount, + minBytes - amount, maxBytes - amount, alreadyRead + amount); } - - return tryReadInternal(reinterpret_cast(buffer) + amount, - minBytes - amount, maxBytes - amount, alreadyRead + amount); - }); - } else { - // Read only part of the current chunk. - return inner.tryRead(buffer, minBytes, kj::min(maxBytes, chunkSize)) - .then([=](size_t amount) -> size_t { - chunkSize -= amount; + clean = true; return alreadyRead + amount; }); } @@ -1187,10 +1896,8 @@ template struct FastCaseCmp { static constexpr bool apply(const char* actual) { return - 'a' <= first && first <= 'z' - ? (*actual | 0x20) == first && FastCaseCmp::apply(actual + 1) - : 'A' <= first && first <= 'Z' - ? (*actual & ~0x20) == first && FastCaseCmp::apply(actual + 1) + ('a' <= first && first <= 'z') || ('A' <= first && first <= 'Z') + ? (*actual | 0x20) == (first | 0x20) && FastCaseCmp::apply(actual + 1) : *actual == first && FastCaseCmp::apply(actual + 1); } }; @@ -1214,56 +1921,144 @@ static_assert(!fastCaseCmp<'n','O','o','B','1'>("FooB1"), ""); static_assert(!fastCaseCmp<'f','O','o','B'>("FooB1"), ""); static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), ""); -kj::Own HttpInputStream::getEntityBody( - RequestOrResponse type, HttpMethod method, uint statusCode, - HttpHeaders::ConnectionHeaders& connectionHeaders) { - if (type == RESPONSE && (method == HttpMethod::HEAD || - statusCode == 204 || statusCode == 205 || statusCode == 304)) { - // No body. - return kj::heap(*this); +kj::Own HttpInputStreamImpl::getEntityBody( + RequestOrResponse type, + kj::OneOf method, + uint statusCode, + const kj::HttpHeaders& headers) { + KJ_REQUIRE(headerBuffer.size() > 0, "Cannot get entity body after header buffer release."); + + auto isHeadRequest = method.tryGet().map([](auto& m) { + return m == HttpMethod::HEAD; + }).orDefault(false); + + auto isConnectRequest = method.is(); + + // Rules to determine how HTTP entity-body is delimited: + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // #1 + if (type == RESPONSE) { + if (isHeadRequest) { + // Body elided. + kj::Maybe length; + KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { + length = strtoull(cl->cStr(), nullptr, 10); + } else if (headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) { + // HACK: Neither Content-Length nor Transfer-Encoding header in response to HEAD + // request. Propagate this fact with a 0 expected body length. + length = uint64_t(0); + } + return kj::heap(*this, length); + } else if (isConnectRequest && statusCode >= 200 && statusCode < 300) { + KJ_FAIL_ASSERT("a CONNECT response with a 2xx status does not have an entity body to get"); + } else if (statusCode == 204 || statusCode == 304) { + // No body. + return kj::heap(*this, uint64_t(0)); + } } - if (connectionHeaders.transferEncoding != nullptr) { - // TODO(someday): Support plugable transfer encodings? Or at least gzip? - // TODO(0.7): Support stacked transfer encodings, e.g. "gzip, chunked". - if (fastCaseCmp<'c','h','u','n','k','e','d'>(connectionHeaders.transferEncoding.cStr())) { + // For CONNECT requests messages, we let the rest of the logic play out. + // We already check before here to ensure that Transfer-Encoding and + // Content-Length headers are not present in which case the code below + // does the right thing. + + // #3 + KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) { + // TODO(someday): Support pluggable transfer encodings? Or at least gzip? + // TODO(someday): Support stacked transfer encodings, e.g. "gzip, chunked". + + // NOTE: #3¶3 is ambiguous about what should happen if Transfer-Encoding and Content-Length are + // both present. It says that Transfer-Encoding takes precedence, but also that the request + // "ought to be handled as an error", and that proxies "MUST" drop the Content-Length before + // forwarding. We ignore the vague "ought to" part and implement the other two. (The + // dropping of Content-Length will happen naturally if/when the message is sent back out to + // the network.) + if (fastCaseCmp<'c','h','u','n','k','e','d'>(te->cStr())) { + // #3¶1 return kj::heap(*this); - } else { - KJ_FAIL_REQUIRE("unknown transfer encoding") { break; } + } else if (fastCaseCmp<'i','d','e','n','t','i','t','y'>(te->cStr())) { + // #3¶2 + KJ_REQUIRE(type != REQUEST, "request body cannot have Transfer-Encoding other than chunked"); + return kj::heap(*this); } + + KJ_FAIL_REQUIRE("unknown transfer encoding", *te) { break; }; } - if (connectionHeaders.contentLength != nullptr) { - return kj::heap(*this, - strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10)); + // #4 and #5 + KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { + // NOTE: By spec, multiple Content-Length values are allowed as long as they are the same, e.g. + // "Content-Length: 5, 5, 5". Hopefully no one actually does that... + char* end; + uint64_t length = strtoull(cl->cStr(), &end, 10); + if (end > cl->begin() && *end == '\0') { + // #5 + return kj::heap(*this, length); + } else { + // #4 (bad content-length) + KJ_FAIL_REQUIRE("invalid Content-Length header value", *cl); + } } + // #6 if (type == REQUEST) { // Lack of a Content-Length or Transfer-Encoding means no body for requests. - return kj::heap(*this); + return kj::heap(*this, uint64_t(0)); } - if (connectionHeaders.connection != nullptr) { - // TODO(0.7): Connection header can actually have multiple tokens... but no one ever uses - // that feature? - if (fastCaseCmp<'c','l','o','s','e'>(connectionHeaders.connection.cStr())) { - return kj::heap(*this); + // RFC 2616 permitted "multipart/byteranges" responses to be self-delimiting, but this was + // mercifully removed in RFC 7230, and new exceptions of this type are disallowed: + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.4 + // https://tools.ietf.org/html/rfc7230#page-81 + // To be extra-safe, we'll reject a multipart/byteranges response that lacks transfer-encoding + // and content-length. + KJ_IF_MAYBE(type, headers.get(HttpHeaderId::CONTENT_TYPE)) { + if (type->startsWith("multipart/byteranges")) { + KJ_FAIL_REQUIRE( + "refusing to handle multipart/byteranges response without transfer-encoding nor " + "content-length due to ambiguity between RFC 2616 vs RFC 7230."); } } - KJ_FAIL_REQUIRE("don't know how HTTP body is delimited", headers); - return kj::heap(*this); + // #7 + return kj::heap(*this); +} + +} // namespace + +kj::Own newHttpInputStream( + kj::AsyncInputStream& input, const HttpHeaderTable& table) { + return kj::heap(input, table); } // ======================================================================================= -class HttpOutputStream { +namespace { + +class HttpOutputStream: public WrappableStreamMixin { public: HttpOutputStream(AsyncOutputStream& inner): inner(inner) {} + bool isInBody() { + return inBody; + } + + bool canReuse() { + return !inBody && !broken && !writeInProgress; + } + + bool canWriteBodyData() { + return !writeInProgress && inBody; + } + + bool isBroken() { + return broken; + } + void writeHeaders(String content) { // Writes some header content and begins a new entity body. + KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; } KJ_REQUIRE(!inBody, "previous HTTP message body incomplete; can't write more messages"); inBody = true; @@ -1271,42 +2066,56 @@ public: } void writeBodyData(kj::String content) { + KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; } KJ_REQUIRE(inBody) { return; } queueWrite(kj::mv(content)); } kj::Promise writeBodyData(const void* buffer, size_t size) { + KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; } KJ_REQUIRE(inBody) { return kj::READY_NOW; } - auto fork = writeQueue.then([this,buffer,size]() { - return inner.write(buffer, size); - }).fork(); - + writeInProgress = true; + auto fork = writeQueue.fork(); writeQueue = fork.addBranch(); - return fork.addBranch(); + + return fork.addBranch().then([this,buffer,size]() { + return inner.write(buffer, size); + }).then([this]() { + writeInProgress = false; + }); } kj::Promise writeBodyData(kj::ArrayPtr> pieces) { + KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; } KJ_REQUIRE(inBody) { return kj::READY_NOW; } - auto fork = writeQueue.then([this,pieces]() { - return inner.write(pieces); - }).fork(); - + writeInProgress = true; + auto fork = writeQueue.fork(); writeQueue = fork.addBranch(); - return fork.addBranch(); + + return fork.addBranch().then([this,pieces]() { + return inner.write(pieces); + }).then([this]() { + writeInProgress = false; + }); } Promise pumpBodyFrom(AsyncInputStream& input, uint64_t amount) { + KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return uint64_t(0); } KJ_REQUIRE(inBody) { return uint64_t(0); } - auto fork = writeQueue.then([this,&input,amount]() { - return input.pumpTo(inner, amount); - }).fork(); + writeInProgress = true; + auto fork = writeQueue.fork(); + writeQueue = fork.addBranch(); - writeQueue = fork.addBranch().ignoreResult(); - return fork.addBranch(); + return fork.addBranch().then([this,&input,amount]() { + return input.pumpTo(inner, amount); + }).then([this](uint64_t actual) { + writeInProgress = false; + return actual; + }); } void finishBody() { @@ -1314,17 +2123,27 @@ public: KJ_REQUIRE(inBody) { return; } inBody = false; + + if (writeInProgress) { + // It looks like the last write never completed -- possibly because it was canceled or threw + // an exception. We must treat this equivalent to abortBody(). + broken = true; + + // Cancel any writes that are still queued. + writeQueue = KJ_EXCEPTION(FAILED, + "previous HTTP message body incomplete; can't write more messages"); + } } void abortBody() { // Called if the application failed to write all expected body bytes. KJ_REQUIRE(inBody) { return; } inBody = false; + broken = true; - writeQueue = writeQueue.then([]() -> kj::Promise { - return KJ_EXCEPTION(FAILED, - "previous HTTP message body incomplete; can't write more messages"); - }); + // Cancel any writes that are still queued. + writeQueue = KJ_EXCEPTION(FAILED, + "previous HTTP message body incomplete; can't write more messages"); } kj::Promise flush() { @@ -1333,20 +2152,84 @@ public: return fork.addBranch(); } + Promise whenWriteDisconnected() { + return inner.whenWriteDisconnected(); + } + + bool isWriteInProgress() { return writeInProgress; } + private: AsyncOutputStream& inner; kj::Promise writeQueue = kj::READY_NOW; bool inBody = false; + bool broken = false; + + bool writeInProgress = false; + // True if a write method has been called and has not completed successfully. In the case that + // a write throws an exception or is canceled, this remains true forever. In these cases, the + // underlying stream is in an inconsistent state and cannot be reused. void queueWrite(kj::String content) { - writeQueue = writeQueue.then(kj::mvCapture(content, [this](kj::String&& content) { + // We only use queueWrite() in cases where we can take ownership of the write buffer, and where + // it is convenient if we can return `void` rather than a promise. In particular, this is used + // to write headers and chunk boundaries. Writes of application data do not go into + // `writeQueue` because this would prevent cancellation. Instead, they wait until `writeQueue` + // is empty, then they make the write directly, using `writeInProgress` to detect and block + // concurrent writes. + + writeQueue = writeQueue.then([this,content=kj::mv(content)]() mutable { auto promise = inner.write(content.begin(), content.size()); return promise.attach(kj::mv(content)); - })); + }); + } +}; + +class HttpEntityBodyWriter: public kj::AsyncOutputStream { +public: + HttpEntityBodyWriter(HttpOutputStream& inner) { + inner.setCurrentWrapper(weakInner); + } + ~HttpEntityBodyWriter() noexcept(false) { + if (!finished) { + KJ_IF_MAYBE(inner, weakInner) { + inner->unsetCurrentWrapper(weakInner); + inner->abortBody(); + } else { + // Since we're in a destructor, log an error instead of throwing. + KJ_LOG(ERROR, "HTTP body output stream outlived underlying connection", + kj::getStackTrace()); + } + } + } + +protected: + HttpOutputStream& getInner() { + KJ_IF_MAYBE(i, weakInner) { + return *i; + } else if (finished) { + // This is a bug in the implementations in this file, not the app. + KJ_FAIL_ASSERT("bug in KJ HTTP: tried to access inner stream after it had been released"); + } else { + KJ_FAIL_REQUIRE("HTTP body output stream outlived underlying connection"); + } + } + + void doneWriting() { + auto& inner = getInner(); + inner.unsetCurrentWrapper(weakInner); + finished = true; + inner.finishBody(); } + + inline bool alreadyDone() { return weakInner == nullptr; } + +private: + kj::Maybe weakInner; + bool finished = false; }; class HttpNullEntityWriter final: public kj::AsyncOutputStream { + // Does not inherit HttpEntityBodyWriter because it doesn't actually write anything. public: Promise write(const void* buffer, size_t size) override { return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()"); @@ -1354,9 +2237,13 @@ public: Promise write(ArrayPtr> pieces) override { return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()"); } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } }; class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream { + // Does not inherit HttpEntityBodyWriter because it doesn't actually write anything. public: Promise write(const void* buffer, size_t size) override { return kj::READY_NOW; @@ -1364,64 +2251,114 @@ public: Promise write(ArrayPtr> pieces) override { return kj::READY_NOW; } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } }; -class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream { +class HttpFixedLengthEntityWriter final: public HttpEntityBodyWriter { public: HttpFixedLengthEntityWriter(HttpOutputStream& inner, uint64_t length) - : inner(inner), length(length) {} - ~HttpFixedLengthEntityWriter() noexcept(false) { - if (length > 0) inner.abortBody(); + : HttpEntityBodyWriter(inner), length(length) { + if (length == 0) doneWriting(); } Promise write(const void* buffer, size_t size) override { + if (size == 0) return kj::READY_NOW; KJ_REQUIRE(size <= length, "overwrote Content-Length"); length -= size; - return maybeFinishAfter(inner.writeBodyData(buffer, size)); + return maybeFinishAfter(getInner().writeBodyData(buffer, size)); } Promise write(ArrayPtr> pieces) override { uint64_t size = 0; for (auto& piece: pieces) size += piece.size(); + if (size == 0) return kj::READY_NOW; KJ_REQUIRE(size <= length, "overwrote Content-Length"); length -= size; - return maybeFinishAfter(inner.writeBodyData(pieces)); + return maybeFinishAfter(getInner().writeBodyData(pieces)); } Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { - KJ_REQUIRE(amount <= length, "overwrote Content-Length"); + if (amount == 0) return constPromise(); + + bool overshot = amount > length; + if (overshot) { + // Hmm, the requested amount was too large, but it's common to specify kj::max as the amount + // to pump, in which case we pump to EOF. Let's try to verify whether EOF is where we + // expect it to be. + KJ_IF_MAYBE(available, input.tryGetLength()) { + // Great, the stream knows how large it is. If it's indeed larger than the space available + // then let's abort. + KJ_REQUIRE(*available <= length, "overwrote Content-Length"); + } else { + // OK, we have no idea how large the input is, so we'll have to check later. + } + } + + amount = kj::min(amount, length); length -= amount; - return inner.pumpBodyFrom(input, amount).then([this,amount](uint64_t actual) { + auto promise = amount == 0 + ? kj::Promise(amount) + : getInner().pumpBodyFrom(input, amount).then([this,amount](uint64_t actual) { // Adjust for bytes not written. length += amount - actual; - if (length == 0) inner.finishBody(); + if (length == 0) doneWriting(); return actual; }); + + if (overshot) { + promise = promise.then([amount,&input](uint64_t actual) -> kj::Promise { + if (actual == amount) { + // We read exactly the amount expected. In order to detect an overshoot, we have to + // try reading one more byte. Ugh. + static byte junk; + return input.tryRead(&junk, 1, 1).then([actual](size_t extra) { + KJ_REQUIRE(extra == 0, "overwrote Content-Length"); + return actual; + }); + } else { + // We actually read less data than requested so we couldn't have overshot. In fact, we + // undershot. + return actual; + } + }); + } + + return kj::mv(promise); + } + + Promise whenWriteDisconnected() override { + return getInner().whenWriteDisconnected(); } private: - HttpOutputStream& inner; uint64_t length; kj::Promise maybeFinishAfter(kj::Promise promise) { if (length == 0) { - return promise.then([this]() { inner.finishBody(); }); + return promise.then([this]() { doneWriting(); }); } else { return kj::mv(promise); } } }; -class HttpChunkedEntityWriter final: public kj::AsyncOutputStream { +class HttpChunkedEntityWriter final: public HttpEntityBodyWriter { public: HttpChunkedEntityWriter(HttpOutputStream& inner) - : inner(inner) {} + : HttpEntityBodyWriter(inner) {} ~HttpChunkedEntityWriter() noexcept(false) { - inner.writeBodyData(kj::str("0\r\n\r\n")); - inner.finishBody(); + if (!alreadyDone()) { + auto& inner = getInner(); + if (inner.canWriteBodyData()) { + inner.writeBodyData(kj::str("0\r\n\r\n")); + doneWriting(); + } + } } Promise write(const void* buffer, size_t size) override { @@ -1433,7 +2370,7 @@ public: parts[1] = kj::arrayPtr(reinterpret_cast(buffer), size); parts[2] = kj::StringPtr("\r\n").asBytes(); - auto promise = inner.writeBodyData(parts.asPtr()); + auto promise = getInner().writeBodyData(parts.asPtr()); return promise.attach(kj::mv(header), kj::mv(parts)); } @@ -1443,8 +2380,8 @@ public: if (size == 0) return kj::READY_NOW; // can't encode zero-size chunk since it indicates EOF. - auto header = kj::str(size, "\r\n"); - auto partsBuilder = kj::heapArrayBuilder>(pieces.size()); + auto header = kj::str(kj::hex(size), "\r\n"); + auto partsBuilder = kj::heapArrayBuilder>(pieces.size() + 2); partsBuilder.add(header.asBytes()); for (auto& piece: pieces) { partsBuilder.add(piece); @@ -1452,19 +2389,21 @@ public: partsBuilder.add(kj::StringPtr("\r\n").asBytes()); auto parts = partsBuilder.finish(); - auto promise = inner.writeBodyData(parts.asPtr()); + auto promise = getInner().writeBodyData(parts.asPtr()); return promise.attach(kj::mv(header), kj::mv(parts)); } Maybe> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { - KJ_IF_MAYBE(length, input.tryGetLength()) { + KJ_IF_MAYBE(l, input.tryGetLength()) { // Hey, we know exactly how large the input is, so we can write just one chunk. - inner.writeBodyData(kj::str(*length, "\r\n")); - auto lengthValue = *length; - return inner.pumpBodyFrom(input, *length) - .then([this,lengthValue](uint64_t actual) { - if (actual < lengthValue) { + uint64_t length = kj::min(amount, *l); + auto& inner = getInner(); + inner.writeBodyData(kj::str(kj::hex(length), "\r\n")); + return inner.pumpBodyFrom(input, length) + .then([this,length](uint64_t actual) { + auto& inner = getInner(); + if (actual < length) { inner.abortBody(); KJ_FAIL_REQUIRE( "value returned by input.tryGetLength() was greater than actual bytes transferred") { @@ -1481,262 +2420,5118 @@ public: } } -private: - HttpOutputStream& inner; + Promise whenWriteDisconnected() override { + return getInner().whenWriteDisconnected(); + } }; // ======================================================================================= -class HttpClientImpl final: public HttpClient { +class WebSocketImpl final: public WebSocket, private WebSocketErrorHandler { public: - HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& rawStream) - : httpInput(rawStream, responseHeaderTable), - httpOutput(rawStream) {} + WebSocketImpl(kj::Own stream, + kj::Maybe maskKeyGenerator, + kj::Maybe compressionConfigParam = nullptr, + kj::Maybe errorHandler = nullptr, + kj::Array buffer = kj::heapArray(4096), + kj::ArrayPtr leftover = nullptr, + kj::Maybe> waitBeforeSend = nullptr) + : stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator), + compressionConfig(kj::mv(compressionConfigParam)), + errorHandler(errorHandler.orDefault(*this)), + sendingPong(kj::mv(waitBeforeSend)), + recvBuffer(kj::mv(buffer)), recvData(leftover) { +#if KJ_HAS_ZLIB + KJ_IF_MAYBE(config, compressionConfig) { + compressionContext.emplace(ZlibContext::Mode::COMPRESS, *config); + decompressionContext.emplace(ZlibContext::Mode::DECOMPRESS, *config); + } +#else + KJ_REQUIRE(compressionConfig == nullptr, + "WebSocket compression is only supported if KJ is compiled with Zlib."); +#endif // KJ_HAS_ZLIB + } - Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, - kj::Maybe expectedBodySize = nullptr) override { - HttpHeaders::ConnectionHeaders connectionHeaders; - kj::String lengthStr; + kj::Promise send(kj::ArrayPtr message) override { + return sendImpl(OPCODE_BINARY, message); + } - if (method == HttpMethod::GET || method == HttpMethod::HEAD) { - // No entity-body. - } else KJ_IF_MAYBE(s, expectedBodySize) { - lengthStr = kj::str(*s); - connectionHeaders.contentLength = lengthStr; + kj::Promise send(kj::ArrayPtr message) override { + return sendImpl(OPCODE_TEXT, message.asBytes()); + } + + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + kj::Array payload; + if (code == 1005) { + KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason"); + + // code 1005 -- leave payload empty } else { - connectionHeaders.transferEncoding = "chunked"; + payload = heapArray(reason.size() + 2); + payload[0] = code >> 8; + payload[1] = code; + memcpy(payload.begin() + 2, reason.begin(), reason.size()); } - httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders)); + auto promise = sendImpl(OPCODE_CLOSE, payload); + return promise.attach(kj::mv(payload)); + } - kj::Own bodyStream; - if (method == HttpMethod::GET || method == HttpMethod::HEAD) { - // No entity-body. - httpOutput.finishBody(); - bodyStream = heap(); - } else KJ_IF_MAYBE(s, expectedBodySize) { - bodyStream = heap(httpOutput, *s); - } else { - bodyStream = heap(httpOutput); + kj::Promise disconnect() override { + KJ_REQUIRE(!currentlySending, "another message send is already in progress"); + + KJ_IF_MAYBE(p, sendingPong) { + // We recently sent a pong, make sure it's finished before proceeding. + currentlySending = true; + auto promise = p->then([this]() { + currentlySending = false; + return disconnect(); + }); + sendingPong = nullptr; + return promise; } - auto responsePromise = httpInput.readResponseHeaders() - .then([this,method](kj::Maybe&& response) -> HttpClient::Response { - KJ_IF_MAYBE(r, response) { - return { - r->statusCode, - r->statusText, - &httpInput.getHeaders(), - httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode, - r->connectionHeaders) - }; - } else { - KJ_FAIL_REQUIRE("received invalid HTTP response") { break; } - return HttpClient::Response(); - } - }); + disconnected = true; - return { kj::mv(bodyStream), kj::mv(responsePromise) }; + stream->shutdownWrite(); + return kj::READY_NOW; } -private: - HttpInputStream httpInput; - HttpOutputStream httpOutput; -}; - -} // namespace + void abort() override { + queuedPong = nullptr; + sendingPong = nullptr; + disconnected = true; + stream->abortRead(); + stream->shutdownWrite(); + } -kj::Promise HttpClient::openWebSocket( - kj::StringPtr url, const HttpHeaders& headers, kj::Own downstream) { - return request(HttpMethod::GET, url, headers, nullptr) - .response.then([](HttpClient::Response&& response) -> WebSocketResponse { - kj::OneOf, kj::Own> body; - body.init>(kj::mv(response.body)); + kj::Promise whenAborted() override { + return stream->whenWriteDisconnected(); + } - return { - response.statusCode, - response.statusText, - response.headers, - kj::mv(body) - }; - }); -} + kj::Promise receive(size_t maxSize) override { + size_t headerSize = Header::headerSize(recvData.begin(), recvData.size()); -kj::Promise> HttpClient::connect(kj::String host) { - KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpClient"); -} + if (headerSize > recvData.size()) { + if (recvData.begin() != recvBuffer.begin()) { + // Move existing data to front of buffer. + if (recvData.size() > 0) { + memmove(recvBuffer.begin(), recvData.begin(), recvData.size()); + } + recvData = recvBuffer.slice(0, recvData.size()); + } -kj::Own newHttpClient( - HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream) { - return kj::heap(responseHeaderTable, stream); -} + return stream->tryRead(recvData.end(), 1, recvBuffer.end() - recvData.end()) + .then([this,maxSize](size_t actual) -> kj::Promise { + receivedBytes += actual; + if (actual == 0) { + if (recvData.size() > 0) { + return KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in frame header"); + } else { + // It's incorrect for the WebSocket to disconnect without sending `Close`. + return KJ_EXCEPTION(DISCONNECTED, + "WebSocket disconnected between frames without sending `Close`."); + } + } -// ======================================================================================= + recvData = recvBuffer.slice(0, recvData.size() + actual); + return receive(maxSize); + }); + } -kj::Promise HttpService::openWebSocket( - kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) { - class EmptyStream final: public kj::AsyncInputStream { - public: - Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return size_t(0); + auto& recvHeader = *reinterpret_cast(recvData.begin()); + if (recvHeader.hasRsv2or3()) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Received frame had RSV bits 2 or 3 set", + }); } - }; - auto requestBody = heap(); - auto promise = request(HttpMethod::GET, url, headers, *requestBody, response); - return promise.attach(kj::mv(requestBody)); -} + recvData = recvData.slice(headerSize, recvData.size()); -kj::Promise> HttpService::connect(kj::String host) { - KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService"); -} + size_t payloadLen = recvHeader.getPayloadLen(); + if (payloadLen > maxSize) { + return errorHandler.handleWebSocketProtocolError({ + 1009, kj::str("Message is too large: ", payloadLen, " > ", maxSize) + }); + } -class HttpServer::Connection final: private HttpService::Response { -public: - Connection(HttpServer& server, kj::AsyncIoStream& stream) - : server(server), - httpInput(stream, server.requestHeaderTable), - httpOutput(stream) { - ++server.connectionCount; - } - Connection(HttpServer& server, kj::Own&& stream) - : server(server), - httpInput(*stream, server.requestHeaderTable), - httpOutput(*stream), - ownStream(kj::mv(stream)) { - ++server.connectionCount; - } - ~Connection() noexcept(false) { - if (--server.connectionCount == 0) { - KJ_IF_MAYBE(f, server.zeroConnectionsFulfiller) { - f->get()->fulfill(); + auto opcode = recvHeader.getOpcode(); + bool isData = opcode < OPCODE_FIRST_CONTROL; + if (opcode == OPCODE_CONTINUATION) { + if (fragments.empty()) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Unexpected continuation frame" + }); } - } - } - kj::Promise loop() { - // If the timeout promise finishes before the headers do, we kill the connection. - auto timeoutPromise = server.timer.afterDelay(server.settings.headerTimeout) - .then([this]() -> kj::Maybe { - timedOut = true; - return nullptr; - }); + opcode = fragmentOpcode; + } else if (isData) { + if (!fragments.empty()) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Missing continuation frame" + }); + } + } - return httpInput.readRequestHeaders().exclusiveJoin(kj::mv(timeoutPromise)) - .then([this](kj::Maybe&& request) -> kj::Promise { - if (timedOut) { - return sendError(408, "Request Timeout", kj::str( - "ERROR: Your client took too long to send HTTP headers.")); + bool isFin = recvHeader.isFin(); + + kj::Array message; // space to allocate + byte* payloadTarget; // location into which to read payload (size is payloadLen) + kj::Maybe originalMaxSize; // maxSize from first `receive()` call + if (isFin) { + size_t amountToAllocate; + if (recvHeader.isCompressed() || fragmentCompressed) { + // Add 4 since we append 0x00 0x00 0xFF 0xFF to the tail of the payload. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2 + amountToAllocate = payloadLen + 4; + } else { + // Add space for NUL terminator when allocating text message. + amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin); } - KJ_IF_MAYBE(req, request) { - currentMethod = req->method; - auto body = httpInput.getEntityBody( - HttpInputStream::REQUEST, req->method, 0, req->connectionHeaders); + if (isData && !fragments.empty()) { + // Final frame of a fragmented message. Gather the fragments. + size_t offset = 0; + for (auto& fragment: fragments) offset += fragment.size(); + message = kj::heapArray(offset + amountToAllocate); + originalMaxSize = offset + maxSize; // gives us back the original maximum message size. + + offset = 0; + for (auto& fragment: fragments) { + memcpy(message.begin() + offset, fragment.begin(), fragment.size()); + offset += fragment.size(); + } + payloadTarget = message.begin() + offset; - // TODO(perf): If the client disconnects, should we cancel the response? Probably, to - // prevent permanent deadlock. It's slightly weird in that arguably the client should - // be able to shutdown the upstream but still wait on the downstream, but I believe many - // other HTTP servers do similar things. + fragments.clear(); + fragmentOpcode = 0; + fragmentCompressed = false; + } else { + // Single-frame message. + message = kj::heapArray(amountToAllocate); + originalMaxSize = maxSize; // gives us back the original maximum message size. + payloadTarget = message.begin(); + } + } else { + // Fragmented message, and this isn't the final fragment. + if (!isData) { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Received fragmented control frame" + }); + } - auto promise = server.service.request( - req->method, req->url, httpInput.getHeaders(), *body, *this); - return promise.attach(kj::mv(body)) - .then([this]() { return httpOutput.flush(); }) - .then([this]() -> kj::Promise { - // Response done. Await next request. + message = kj::heapArray(payloadLen); + payloadTarget = message.begin(); + if (fragments.empty()) { + // This is the first fragment, so set the opcode. + fragmentOpcode = opcode; + fragmentCompressed = recvHeader.isCompressed(); + } + } - if (currentMethod != nullptr) { - return sendError(500, "Internal Server Error", kj::str( - "ERROR: The HttpService did not generate a response.")); - } + Mask mask = recvHeader.getMask(); - if (server.draining) { - // Never mind, drain time. - return kj::READY_NOW; - } + auto handleMessage = + [this,opcode,payloadTarget,payloadLen,mask,isFin,maxSize,originalMaxSize,message=kj::mv(message)]() mutable + -> kj::Promise { + if (!mask.isZero()) { + mask.apply(kj::arrayPtr(payloadTarget, payloadLen)); + } - auto timeoutPromise = server.timer.afterDelay(server.settings.pipelineTimeout) - .then([]() { return false; }); - auto awaitPromise = httpInput.awaitNextMessage(); + if (!isFin) { + // Add fragment to the list and loop. + auto newMax = maxSize - message.size(); + fragments.add(kj::mv(message)); + return receive(newMax); + } - return timeoutPromise.exclusiveJoin(kj::mv(awaitPromise)) - .then([this](bool hasMore) -> kj::Promise { - if (hasMore) { - return loop(); - } else { - // In this case we assume the client has no more requests, so we simply close the - // connection. - return kj::READY_NOW; + switch (opcode) { + case OPCODE_CONTINUATION: + // Shouldn't get here; handled above. + KJ_UNREACHABLE; + case OPCODE_TEXT: +#if KJ_HAS_ZLIB + KJ_IF_MAYBE(config, compressionConfig) { + auto& decompressor = KJ_ASSERT_NONNULL(decompressionContext); + auto tail = message.slice(message.size() - 4, message.size()); + // Note that we added an additional 4 bytes to `message`s capacity to account for these + // extra bytes. See `amountToAllocate` in the if(recvHeader.isCompressed()) block above. + const byte tailBytes[] = {0x00, 0x00, 0xFF, 0xFF}; + memcpy(tail.begin(), tailBytes, sizeof(tailBytes)); + // We have to append 0x00 0x00 0xFF 0xFF to the message before inflating. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2 + if (config->inboundNoContextTakeover) { + // We must reset context on each message. + decompressor.reset(); + } + bool addNullTerminator = true; + // We want to add the null terminator when receiving a TEXT message. + auto decompressed = decompressor.processMessage(message, originalMaxSize, + addNullTerminator); + return Message(kj::String(decompressed.releaseAsChars())); + } +#endif // KJ_HAS_ZLIB + message.back() = '\0'; + return Message(kj::String(message.releaseAsChars())); + case OPCODE_BINARY: +#if KJ_HAS_ZLIB + KJ_IF_MAYBE(config, compressionConfig) { + auto& decompressor = KJ_ASSERT_NONNULL(decompressionContext); + auto tail = message.slice(message.size() - 4, message.size()); + // Note that we added an additional 4 bytes to `message`s capacity to account for these + // extra bytes. See `amountToAllocate` in the if(recvHeader.isCompressed()) block above. + const byte tailBytes[] = {0x00, 0x00, 0xFF, 0xFF}; + memcpy(tail.begin(), tailBytes, sizeof(tailBytes)); + // We have to append 0x00 0x00 0xFF 0xFF to the message before inflating. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2 + if (config->inboundNoContextTakeover) { + // We must reset context on each message. + decompressor.reset(); } + auto decompressed = decompressor.processMessage(message, originalMaxSize); + return Message(decompressed.releaseAsBytes()); + } +#endif // KJ_HAS_ZLIB + return Message(message.releaseAsBytes()); + case OPCODE_CLOSE: + if (message.size() < 2) { + return Message(Close { 1005, nullptr }); + } else { + uint16_t status = (static_cast(message[0]) << 8) + | (static_cast(message[1]) ); + return Message(Close { + status, kj::heapString(message.slice(2, message.size()).asChars()) + }); + } + case OPCODE_PING: + // Send back a pong. + queuePong(kj::mv(message)); + return receive(maxSize); + case OPCODE_PONG: + // Unsolicited pong. Ignore. + return receive(maxSize); + default: + return errorHandler.handleWebSocketProtocolError({ + 1002, kj::str("Unknown opcode ", opcode) }); - }); - } else { - // Bad request. - - return sendError(400, "Bad Request", kj::str( - "ERROR: The headers sent by your client were not valid.")); } - }).catch_([this](kj::Exception&& e) -> kj::Promise { - // Exception; report 500. + }; - if (currentMethod == nullptr) { - // Dang, already sent a partial response. Can't do anything else. - KJ_LOG(ERROR, "HttpService threw exception after generating a partial response", - "too late to report error to client", e); - return kj::READY_NOW; + if (payloadLen <= recvData.size()) { + // All data already received. + memcpy(payloadTarget, recvData.begin(), payloadLen); + recvData = recvData.slice(payloadLen, recvData.size()); + return handleMessage(); + } else { + // Need to read more data. + memcpy(payloadTarget, recvData.begin(), recvData.size()); + size_t remaining = payloadLen - recvData.size(); + auto promise = stream->tryRead(payloadTarget + recvData.size(), remaining, remaining) + .then([this, remaining](size_t amount) { + receivedBytes += amount; + if (amount < remaining) { + kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in message")); + } + }); + recvData = nullptr; + return promise.then(kj::mv(handleMessage)); + } + } + + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_IF_MAYBE(optOther, kj::dynamicDowncastIfAvailable(other)) { + // Both WebSockets are raw WebSockets, so we can pump the streams directly rather than read + // whole messages. + + if ((maskKeyGenerator == nullptr) == (optOther->maskKeyGenerator == nullptr)) { + // Oops, it appears that we either believe we are the client side of both sockets, or we + // are the server side of both sockets. Since clients must "mask" their outgoing frames but + // servers must *not* do so, we can't direct-pump. Sad. + return nullptr; } - if (e.getType() == kj::Exception::Type::OVERLOADED) { - return sendError(503, "Service Unavailable", kj::str( - "ERROR: The server is temporarily unable to handle your request. Details:\n\n", e)); - } else if (e.getType() == kj::Exception::Type::UNIMPLEMENTED) { - return sendError(501, "Not Implemented", kj::str( - "ERROR: The server does not implement this operation. Details:\n\n", e)); - } else if (e.getType() == kj::Exception::Type::DISCONNECTED) { - // How do we tell an HTTP client that there was a transient network error, and it should - // try again immediately? There's no HTTP status code for this (503 is meant for "try - // again later, not now"). Here's an idea: Don't send any response; just close the - // connection, so that it looks like the connection between the HTTP client and server - // was dropped. A good client should treat this exactly the way we want. - return kj::READY_NOW; + KJ_IF_MAYBE(config, compressionConfig) { + KJ_IF_MAYBE(otherConfig, optOther->compressionConfig) { + if (config->outboundMaxWindowBits != otherConfig->inboundMaxWindowBits || + config->inboundMaxWindowBits != otherConfig->outboundMaxWindowBits || + config->inboundNoContextTakeover!= otherConfig->outboundNoContextTakeover || + config->outboundNoContextTakeover!= otherConfig->inboundNoContextTakeover) { + // Compression configurations differ. + return nullptr; + } + } else { + // Only one websocket uses compression. + return nullptr; + } } else { - return sendError(500, "Internal Server Error", kj::str( - "ERROR: The server threw an exception. Details:\n\n", e)); + if (optOther->compressionConfig != nullptr) { + // Only one websocket uses compression. + return nullptr; + } } - }); - } + // Both websockets use compatible compression configurations so we can pump directly. + + // Check same error conditions as with sendImpl(). + KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); + KJ_REQUIRE(!currentlySending, "another message send is already in progress"); + currentlySending = true; + + // If the application chooses to pump messages out, but receives incoming messages normally + // with `receive()`, then we will receive pings and attempt to send pongs. But we can't + // safely insert a pong in the middle of a pumped stream. We kind of don't have a choice + // except to drop them on the floor, which is what will happen if we set `hasSentClose` true. + // Hopefully most apps that set up a pump do so in both directions at once, and so pings will + // flow through and pongs will flow back. + hasSentClose = true; + + return optOther->optimizedPumpTo(*this); + } + + return nullptr; + } + + uint64_t sentByteCount() override { return sentBytes; } + + uint64_t receivedByteCount() override { return receivedBytes; } + + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + if (maskKeyGenerator == nullptr) { + // `this` is the server side of a websocket. + if (ctx == ExtensionsContext::REQUEST) { + // The other WebSocket is (going to be) the client side of a WebSocket, i.e. this is a + // proxying pass-through scenario. Optimization is possible. Confusingly, we have to use + // generateExtensionResponse() (even though we're generating headers to be passed in a + // request) because this is the function that correctly maps our config's inbound/outbound + // to client/server. + KJ_IF_MAYBE(c, compressionConfig) { + return _::generateExtensionResponse(*c); + } else { + return kj::String(nullptr); // recommend no compression + } + } else { + // We're apparently arranging to pump from the server side of one WebSocket to the server + // side of another; i.e., we are a server, we have two clients, and we're trying to pump + // between them. We cannot optimize this case, because the masking requirements are + // different for client->server vs. server->client messages. Since we have to parse out + // the messages anyway there's no point in trying to match extensions, so return null. + return nullptr; + } + } else { + // `this` is the client side of a websocket. + if (ctx == ExtensionsContext::RESPONSE) { + // The other WebSocket is (going to be) the server side of a WebSocket, i.e. this is a + // proxying pass-through scenario. Optimization is possible. Confusingly, we have to use + // generateExtensionRequest() (even though we're generating headers to be passed in a + // response) because this is the function that correctly maps our config's inbound/outbound + // to server/client. + KJ_IF_MAYBE(c, compressionConfig) { + CompressionParameters arr[1]{*c}; + return _::generateExtensionRequest(arr); + } else { + return kj::String(nullptr); // recommend no compression + } + } else { + // We're apparently arranging to pump from the client side of one WebSocket to the client + // side of another; i.e., we are a client, we are connected to two servers, and we're + // trying to pump between them. We cannot optimize this case, because the masking + // requirements are different for client->server vs. server->client messages. Since we have + // to parse out the messages anyway there's no point in trying to match extensions, so + // return null. + return nullptr; + } + } + } + +private: + class Mask { + public: + Mask(): maskBytes { 0, 0, 0, 0 } {} + Mask(const byte* ptr) { memcpy(maskBytes, ptr, 4); } + + Mask(kj::Maybe generator) { + KJ_IF_MAYBE(g, generator) { + g->generate(maskBytes); + } else { + memset(maskBytes, 0, 4); + } + } + + void apply(kj::ArrayPtr bytes) const { + apply(bytes.begin(), bytes.size()); + } + + void copyTo(byte* output) const { + memcpy(output, maskBytes, 4); + } + + bool isZero() const { + return (maskBytes[0] | maskBytes[1] | maskBytes[2] | maskBytes[3]) == 0; + } + + private: + byte maskBytes[4]; + + void apply(byte* __restrict__ bytes, size_t size) const { + for (size_t i = 0; i < size; i++) { + bytes[i] ^= maskBytes[i % 4]; + } + } + }; + + class Header { + public: + kj::ArrayPtr compose(bool fin, bool compressed, byte opcode, uint64_t payloadLen, + Mask mask) { + bytes[0] = (fin ? FIN_MASK : 0) | (compressed ? RSV1_MASK : 0) | opcode; + // Note that we can only set the compressed bit on DATA frames. + bool hasMask = !mask.isZero(); + + size_t fill; + + if (payloadLen < 126) { + bytes[1] = (hasMask ? USE_MASK_MASK : 0) | payloadLen; + if (hasMask) { + mask.copyTo(bytes + 2); + fill = 6; + } else { + fill = 2; + } + } else if (payloadLen < 65536) { + bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 126; + bytes[2] = static_cast(payloadLen >> 8); + bytes[3] = static_cast(payloadLen ); + if (hasMask) { + mask.copyTo(bytes + 4); + fill = 8; + } else { + fill = 4; + } + } else { + bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 127; + bytes[2] = static_cast(payloadLen >> 56); + bytes[3] = static_cast(payloadLen >> 48); + bytes[4] = static_cast(payloadLen >> 40); + bytes[5] = static_cast(payloadLen >> 42); + bytes[6] = static_cast(payloadLen >> 24); + bytes[7] = static_cast(payloadLen >> 16); + bytes[8] = static_cast(payloadLen >> 8); + bytes[9] = static_cast(payloadLen ); + if (hasMask) { + mask.copyTo(bytes + 10); + fill = 14; + } else { + fill = 10; + } + } + + return arrayPtr(bytes, fill); + } + + bool isFin() const { + return bytes[0] & FIN_MASK; + } + + bool isCompressed() const { + return bytes[0] & RSV1_MASK; + } + + bool hasRsv2or3() const { + return bytes[0] & RSV2_3_MASK; + } + + byte getOpcode() const { + return bytes[0] & OPCODE_MASK; + } + + uint64_t getPayloadLen() const { + byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK; + if (payloadLen == 127) { + return (static_cast(bytes[2]) << 56) + | (static_cast(bytes[3]) << 48) + | (static_cast(bytes[4]) << 40) + | (static_cast(bytes[5]) << 32) + | (static_cast(bytes[6]) << 24) + | (static_cast(bytes[7]) << 16) + | (static_cast(bytes[8]) << 8) + | (static_cast(bytes[9]) ); + } else if (payloadLen == 126) { + return (static_cast(bytes[2]) << 8) + | (static_cast(bytes[3]) ); + } else { + return payloadLen; + } + } + + Mask getMask() const { + if (bytes[1] & USE_MASK_MASK) { + byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK; + if (payloadLen == 127) { + return Mask(bytes + 10); + } else if (payloadLen == 126) { + return Mask(bytes + 4); + } else { + return Mask(bytes + 2); + } + } else { + return Mask(); + } + } + + static size_t headerSize(byte const* bytes, size_t sizeSoFar) { + if (sizeSoFar < 2) return 2; + + size_t required = 2; + + if (bytes[1] & USE_MASK_MASK) { + required += 4; + } + + byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK; + if (payloadLen == 127) { + required += 8; + } else if (payloadLen == 126) { + required += 2; + } + + return required; + } + + private: + byte bytes[14]; + + static constexpr byte FIN_MASK = 0x80; + static constexpr byte RSV2_3_MASK = 0x30; + static constexpr byte RSV1_MASK = 0x40; + static constexpr byte OPCODE_MASK = 0x0f; + + static constexpr byte USE_MASK_MASK = 0x80; + static constexpr byte PAYLOAD_LEN_MASK = 0x7f; + }; + +#if KJ_HAS_ZLIB + class ZlibContext { + // `ZlibContext` is the WebSocket's interface to Zlib's compression/decompression functions. + // Depending on the `mode`, `ZlibContext` will act as a compressor or a decompressor. + public: + enum class Mode { + COMPRESS, + DECOMPRESS, + }; + + struct Result { + int processResult = 0; + kj::Array buffer; + size_t size = 0; // Number of bytes used; size <= buffer.size(). + }; + + ZlibContext(Mode mode, const CompressionParameters& config) : mode(mode) { + switch (mode) { + case Mode::COMPRESS: { + int windowBits = -config.outboundMaxWindowBits.orDefault(15); + // We use negative values because we want to use raw deflate. + if(windowBits == -8) { + // Zlib cannot accept `windowBits` of 8 for the deflater. However, due to an + // implementation quirk, `windowBits` of 8 and 9 would both use 250 bytes. + // Therefore, a decompressor using `windowBits` of 8 could safely inflate a message + // that a zlib client compressed using `windowBits` = 9. + // https://bugs.chromium.org/p/chromium/issues/detail?id=691074 + windowBits = -9; + } + int result = deflateInit2( + &ctx, + Z_DEFAULT_COMPRESSION, + Z_DEFLATED, + windowBits, + 8, // memLevel = 8 is the default + Z_DEFAULT_STRATEGY); + KJ_REQUIRE(result == Z_OK, "Failed to initialize compression context (deflate)."); + break; + } + case Mode::DECOMPRESS: { + int windowBits = -config.inboundMaxWindowBits.orDefault(15); + // We use negative values because we want to use raw inflate. + int result = inflateInit2(&ctx, windowBits); + KJ_REQUIRE(result == Z_OK, "Failed to initialize decompression context (inflate)."); + break; + } + } + } + + ~ZlibContext() noexcept(false) { + switch (mode) { + case Mode::COMPRESS: + deflateEnd(&ctx); + break; + case Mode::DECOMPRESS: + inflateEnd(&ctx); + break; + } + } + + KJ_DISALLOW_COPY_AND_MOVE(ZlibContext); + + kj::Array processMessage(kj::ArrayPtr message, + kj::Maybe maxSize = nullptr, + bool addNullTerminator = false) { + // If `this` is the compressor, calling `processMessage()` will compress the `message`. + // Likewise, if `this` is the decompressor, `processMessage()` will decompress the `message`. + // + // `maxSize` is only passed in when decompressing, since we want to ensure the decompressed + // message is smaller than the `maxSize` passed to `receive()`. + // + // If (de)compression is successful, the result is returned as a Vector, otherwise, + // an Exception is thrown. + + ctx.next_in = const_cast(reinterpret_cast(message.begin())); + ctx.avail_in = message.size(); + + kj::Vector parts(processLoop(maxSize)); + + size_t amountToAllocate = 0; + for (const auto& part : parts) { + amountToAllocate += part.size; + } + + if (addNullTerminator) { + // Add space for the null-terminator. + amountToAllocate += 1; + } + + kj::Array processedMessage = kj::heapArray(amountToAllocate); + size_t currentIndex = 0; // Current index into processedMessage. + for (const auto& part : parts) { + memcpy(&processedMessage[currentIndex], part.buffer.begin(), part.size); + // We need to use `part.size` to determine the number of useful bytes, since data after + // `part.size` is unused (and probably junk). + currentIndex += part.size; + } + + if (addNullTerminator) { + processedMessage[currentIndex++] = '\0'; + } + + KJ_ASSERT(currentIndex == processedMessage.size()); + + return kj::mv(processedMessage); + } + + void reset() { + // Resets the (de)compression context. This should only be called when the (de)compressor uses + // client/server_no_context_takeover. + switch (mode) { + case Mode::COMPRESS: { + KJ_ASSERT(deflateReset(&ctx) == Z_OK, "deflateReset() failed."); + break; + } + case Mode::DECOMPRESS: { + KJ_ASSERT(inflateReset(&ctx) == Z_OK, "inflateReset failed."); + break; + } + } + + } + + private: + Result pumpOnce() { + // Prepares Zlib's internal state for a call to deflate/inflate, then calls the relevant + // function to process the input buffer. It is assumed that the caller has already set up + // Zlib's input buffer. + // + // Since calls to deflate/inflate will process data until the input is empty, or until the + // output is full, multiple calls to `pumpOnce()` may be required to process the entire + // message. We're done processing once either `result` is `Z_STREAM_END`, or we get + // `Z_BUF_ERROR` and did not write any more output. + size_t bufSize = 4096; + Array buffer = kj::heapArray(bufSize); + ctx.next_out = buffer.begin(); + ctx.avail_out = bufSize; + + int result = Z_OK; + + switch (mode) { + case Mode::COMPRESS: + result = deflate(&ctx, Z_SYNC_FLUSH); + KJ_REQUIRE(result == Z_OK || result == Z_BUF_ERROR || result == Z_STREAM_END, + "Compression failed", result); + break; + case Mode::DECOMPRESS: + result = inflate(&ctx, Z_SYNC_FLUSH); + KJ_REQUIRE(result == Z_OK || result == Z_BUF_ERROR || result == Z_STREAM_END, + "Decompression failed", result, " with reason", ctx.msg); + break; + } + + return Result { + result, + kj::mv(buffer), + bufSize - ctx.avail_out, + }; + } + + kj::Vector processLoop(kj::Maybe maxSize) { + // Since Zlib buffers the writes, we want to continue processing until there's nothing left. + kj::Vector output; + size_t totalBytesProcessed = 0; + for (;;) { + Result result = pumpOnce(); + + auto status = result.processResult; + auto bytesProcessed = result.size; + if (bytesProcessed > 0) { + output.add(kj::mv(result)); + totalBytesProcessed += bytesProcessed; + KJ_IF_MAYBE(m, maxSize) { + // This is only non-null for `receive` calls, so we must be decompressing. We don't want + // the decompressed message to OOM us, so let's make sure it's not too big. + KJ_REQUIRE(totalBytesProcessed < *m, + "Decompressed WebSocket message is too large"); + } + } + + if ((ctx.avail_in == 0 && ctx.avail_out != 0) || status == Z_STREAM_END) { + // If we're out of input to consume, and we have space in the output buffer, then we must + // have flushed the remaining message, so we're done pumping. Alternatively, if we found a + // BFINAL deflate block, then we know the stream is completely finished. + if (status == Z_STREAM_END) { + reset(); + } + return kj::mv(output); + } + } + } + + Mode mode; + z_stream ctx = {}; + }; +#endif // KJ_HAS_ZLIB + + static constexpr byte OPCODE_CONTINUATION = 0; + static constexpr byte OPCODE_TEXT = 1; + static constexpr byte OPCODE_BINARY = 2; + static constexpr byte OPCODE_CLOSE = 8; + static constexpr byte OPCODE_PING = 9; + static constexpr byte OPCODE_PONG = 10; + + static constexpr byte OPCODE_FIRST_CONTROL = 8; + + // --------------------------------------------------------------------------- + + kj::Own stream; + kj::Maybe maskKeyGenerator; + kj::Maybe compressionConfig; + WebSocketErrorHandler& errorHandler; +#if KJ_HAS_ZLIB + kj::Maybe compressionContext; + kj::Maybe decompressionContext; +#endif // KJ_HAS_ZLIB + + bool hasSentClose = false; + bool disconnected = false; + bool currentlySending = false; + Header sendHeader; + kj::ArrayPtr sendParts[2]; + + kj::Maybe> queuedPong; + // If a Ping is received while currentlySending is true, then queuedPong is set to the body of + // a pong message that should be sent once the current send is complete. + + kj::Maybe> sendingPong; + // If a Pong is being sent asynchronously in response to a Ping, this is a promise for the + // completion of that send. + // + // Additionally, this member is used if we need to block our first send on WebSocket startup, + // e.g. because we need to wait for HTTP handshake writes to flush before we can start sending + // WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same. + // Perhaps it should be renamed to `blockSend` or `writeQueue`. + + uint fragmentOpcode = 0; + bool fragmentCompressed = false; + // For fragmented messages, was the first frame compressed? + // Note that subsequent frames of a compressed message will not set the RSV1 bit. + kj::Vector> fragments; + // If `fragments` is non-empty, we've already received some fragments of a message. + // `fragmentOpcode` is the original opcode. + + kj::Array recvBuffer; + kj::ArrayPtr recvData; + + uint64_t sentBytes = 0; + uint64_t receivedBytes = 0; + + kj::Promise sendImpl(byte opcode, kj::ArrayPtr message) { + KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); + KJ_REQUIRE(!currentlySending, "another message send is already in progress"); + + currentlySending = true; + + KJ_IF_MAYBE(p, sendingPong) { + // We recently sent a pong, make sure it's finished before proceeding. + auto promise = p->then([this, opcode, message]() { + currentlySending = false; + return sendImpl(opcode, message); + }); + sendingPong = nullptr; + return promise; + } + + // We don't stop the application from sending further messages after close() -- this is the + // application's error to make. But, we do want to make sure we don't send any PONGs after a + // close, since that would be our error. So we stack whether we closed for that reason. + hasSentClose = hasSentClose || opcode == OPCODE_CLOSE; + + Mask mask(maskKeyGenerator); + + bool useCompression = false; + kj::Maybe> compressedMessage; + if (opcode == OPCODE_BINARY || opcode == OPCODE_TEXT) { + // We can only compress data frames. +#if KJ_HAS_ZLIB + KJ_IF_MAYBE(config, compressionConfig) { + useCompression = true; + // Compress `message` according to `compressionConfig`s outbound parameters. + auto& compressor = KJ_ASSERT_NONNULL(compressionContext); + if (config->outboundNoContextTakeover) { + // We must reset context on each message. + compressor.reset(); + } + auto& innerMessage = compressedMessage.emplace(compressor.processMessage(message)); + if (message.size() > 0) { + KJ_ASSERT(innerMessage.asPtr().endsWith({0x00, 0x00, 0xFF, 0xFF})); + message = innerMessage.slice(0, innerMessage.size() - 4); + // Strip 0x00 0x00 0xFF 0xFF off the tail. + // See: https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1 + } else { + // RFC 7692 (7.2.3.6) specifies that an empty uncompressed DEFLATE block (0x00) should be + // built if the compression library doesn't generate data when the input is empty. + message = compressedMessage.emplace(kj::heapArray({0x00})); + } + } +#endif // KJ_HAS_ZLIB + } + + kj::Array ownMessage; + if (!mask.isZero()) { + // Sadness, we have to make a copy to apply the mask. + ownMessage = kj::heapArray(message); + mask.apply(ownMessage); + message = ownMessage; + } + + sendParts[0] = sendHeader.compose(true, useCompression, opcode, message.size(), mask); + sendParts[1] = message; + KJ_ASSERT(!sendHeader.hasRsv2or3(), "RSV bits 2 and 3 must be 0, as we do not currently " + "support an extension that would set these bits"); + + auto promise = stream->write(sendParts).attach(kj::mv(compressedMessage)); + if (!mask.isZero()) { + promise = promise.attach(kj::mv(ownMessage)); + } + return promise.then([this, size = sendParts[0].size() + sendParts[1].size()]() { + currentlySending = false; + + // Send queued pong if needed. + KJ_IF_MAYBE(q, queuedPong) { + kj::Array payload = kj::mv(*q); + queuedPong = nullptr; + queuePong(kj::mv(payload)); + } + sentBytes += size; + }); + } + + void queuePong(kj::Array payload) { + if (currentlySending) { + // There is a message-send in progress, so we cannot write to the stream now. + // + // Note: According to spec, if the server receives a second ping before responding to the + // previous one, it can opt to respond only to the last ping. So we don't have to check if + // queuedPong is already non-null. + queuedPong = kj::mv(payload); + } else KJ_IF_MAYBE(promise, sendingPong) { + // We're still sending a previous pong. Wait for it to finish before sending ours. + sendingPong = promise->then([this,payload=kj::mv(payload)]() mutable { + return sendPong(kj::mv(payload)); + }); + } else { + // We're not sending any pong currently. + sendingPong = sendPong(kj::mv(payload)); + } + } + + kj::Promise sendPong(kj::Array payload) { + if (hasSentClose || disconnected) { + return kj::READY_NOW; + } + + sendParts[0] = sendHeader.compose(true, false, OPCODE_PONG, + payload.size(), Mask(maskKeyGenerator)); + sendParts[1] = payload; + return stream->write(sendParts).attach(kj::mv(payload)); + } + + kj::Promise optimizedPumpTo(WebSocketImpl& other) { + KJ_IF_MAYBE(p, other.sendingPong) { + // We recently sent a pong, make sure it's finished before proceeding. + auto promise = p->then([this, &other]() { + return optimizedPumpTo(other); + }); + other.sendingPong = nullptr; + return promise; + } + + if (recvData.size() > 0) { + // We have some data buffered. Write it first. + return other.stream->write(recvData.begin(), recvData.size()) + .then([this, &other, size = recvData.size()]() { + recvData = nullptr; + other.sentBytes += size; + return optimizedPumpTo(other); + }); + } + + auto cancelPromise = other.stream->whenWriteDisconnected() + .then([this]() -> kj::Promise { + this->abort(); + return KJ_EXCEPTION(DISCONNECTED, + "destination of WebSocket pump disconnected prematurely"); + }); + + // There's no buffered incoming data, so start pumping stream now. + return stream->pumpTo(*other.stream).then([this, &other](size_t s) -> kj::Promise { + // WebSocket pumps are expected to include end-of-stream. + other.disconnected = true; + other.stream->shutdownWrite(); + receivedBytes += s; + other.sentBytes += s; + return kj::READY_NOW; + }, [&other](kj::Exception&& e) -> kj::Promise { + // We don't know if it was a read or a write that threw. If it was a read that threw, we need + // to send a disconnect on the destination. If it was the destination that threw, it + // shouldn't hurt to disconnect() it again, but we'll catch and squelch any exceptions. + other.disconnected = true; + kj::runCatchingExceptions([&other]() { other.stream->shutdownWrite(); }); + return kj::mv(e); + }).exclusiveJoin(kj::mv(cancelPromise)); + } +}; + +kj::Own upgradeToWebSocket( + kj::Own stream, HttpInputStreamImpl& httpInput, HttpOutputStream& httpOutput, + kj::Maybe maskKeyGenerator, + kj::Maybe compressionConfig = nullptr, + kj::Maybe errorHandler = nullptr) { + // Create a WebSocket upgraded from an HTTP stream. + auto releasedBuffer = httpInput.releaseBuffer(); + return kj::heap(kj::mv(stream), maskKeyGenerator, + kj::mv(compressionConfig), errorHandler, + kj::mv(releasedBuffer.buffer), + releasedBuffer.leftover, httpOutput.flush()); +} + +} // namespace + +kj::Own newWebSocket(kj::Own stream, + kj::Maybe maskKeyGenerator, + kj::Maybe compressionConfig, + kj::Maybe errorHandler) { + return kj::heap(kj::mv(stream), maskKeyGenerator, kj::mv(compressionConfig), errorHandler); +} + +static kj::Promise pumpWebSocketLoop(WebSocket& from, WebSocket& to) { + return from.receive().then([&from,&to](WebSocket::Message&& message) { + KJ_SWITCH_ONEOF(message) { + KJ_CASE_ONEOF(text, kj::String) { + return to.send(text) + .attach(kj::mv(text)) + .then([&from,&to]() { return pumpWebSocketLoop(from, to); }); + } + KJ_CASE_ONEOF(data, kj::Array) { + return to.send(data) + .attach(kj::mv(data)) + .then([&from,&to]() { return pumpWebSocketLoop(from, to); }); + } + KJ_CASE_ONEOF(close, WebSocket::Close) { + // Once a close has passed through, the pump is complete. + return to.close(close.code, close.reason) + .attach(kj::mv(close)); + } + } + KJ_UNREACHABLE; + }, [&to](kj::Exception&& e) { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return to.disconnect(); + } else { + return to.close(1002, e.getDescription()); + } + }); +} + +kj::Promise WebSocket::pumpTo(WebSocket& other) { + KJ_IF_MAYBE(p, other.tryPumpFrom(*this)) { + // Yay, optimized pump! + return kj::mv(*p); + } else { + // Fall back to default implementation. + return kj::evalNow([&]() { + auto cancelPromise = other.whenAborted().then([this]() -> kj::Promise { + this->abort(); + return KJ_EXCEPTION(DISCONNECTED, + "destination of WebSocket pump disconnected prematurely"); + }); + return pumpWebSocketLoop(*this, other).exclusiveJoin(kj::mv(cancelPromise)); + }); + } +} + +kj::Maybe> WebSocket::tryPumpFrom(WebSocket& other) { + return nullptr; +} + +namespace { + +class WebSocketPipeImpl final: public WebSocket, public kj::Refcounted { + // Represents one direction of a WebSocket pipe. + // + // This class behaves as a "loopback" WebSocket: a message sent using send() is received using + // receive(), on the same object. This is *not* how WebSocket implementations usually behave. + // But, this object is actually used to implement only one direction of a bidirectional pipe. At + // another layer above this, the pipe is actually composed of two WebSocketPipeEnd instances, + // which layer on top of two WebSocketPipeImpl instances representing the two directions. So, + // send() calls on a WebSocketPipeImpl instance always come from one of the two WebSocketPipeEnds + // while receive() calls come from the other end. + +public: + ~WebSocketPipeImpl() noexcept(false) { + KJ_REQUIRE(state == nullptr || ownState.get() != nullptr, + "destroying WebSocketPipe with operation still in-progress; probably going to segfault") { + // Don't std::terminate(). + break; + } + } + + void abort() override { + KJ_IF_MAYBE(s, state) { + s->abort(); + } else { + ownState = heap(); + state = *ownState; + + aborted = true; + KJ_IF_MAYBE(f, abortedFulfiller) { + f->get()->fulfill(); + abortedFulfiller = nullptr; + } + } + } + + kj::Promise send(kj::ArrayPtr message) override { + KJ_IF_MAYBE(s, state) { + return s->send(message).then([&, size = message.size()]() { transferredBytes += size; }); + } else { + return newAdaptedPromise(*this, MessagePtr(message)) + .then([&, size = message.size()]() { transferredBytes += size; }); + } + } + kj::Promise send(kj::ArrayPtr message) override { + KJ_IF_MAYBE(s, state) { + return s->send(message).then([&, size = message.size()]() { transferredBytes += size; }); + } else { + return newAdaptedPromise(*this, MessagePtr(message)) + .then([&, size = message.size()]() { transferredBytes += size; }); + } + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + KJ_IF_MAYBE(s, state) { + return s->close(code, reason) + .then([&, size = reason.size()]() { transferredBytes += (2 +size); }); + } else { + return newAdaptedPromise(*this, MessagePtr(ClosePtr { code, reason })) + .then([&, size = reason.size()]() { transferredBytes += (2 +size); }); + } + } + kj::Promise disconnect() override { + KJ_IF_MAYBE(s, state) { + return s->disconnect(); + } else { + ownState = heap(); + state = *ownState; + return kj::READY_NOW; + } + } + kj::Promise whenAborted() override { + if (aborted) { + return kj::READY_NOW; + } else KJ_IF_MAYBE(p, abortedPromise) { + return p->addBranch(); + } else { + auto paf = newPromiseAndFulfiller(); + abortedFulfiller = kj::mv(paf.fulfiller); + auto fork = paf.promise.fork(); + auto result = fork.addBranch(); + abortedPromise = kj::mv(fork); + return result; + } + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_IF_MAYBE(s, state) { + return s->tryPumpFrom(other); + } else { + return newAdaptedPromise(*this, other); + } + } + + kj::Promise receive(size_t maxSize) override { + KJ_IF_MAYBE(s, state) { + return s->receive(maxSize); + } else { + return newAdaptedPromise(*this, maxSize); + } + } + kj::Promise pumpTo(WebSocket& other) override { + auto onAbort = other.whenAborted() + .then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "WebSocket was aborted"); + }); + + KJ_IF_MAYBE(s, state) { + auto before = other.receivedByteCount(); + return s->pumpTo(other).attach(kj::defer([this, &other, before]() { + transferredBytes += other.receivedByteCount() - before; + })).exclusiveJoin(kj::mv(onAbort)); + } else { + return newAdaptedPromise(*this, other).exclusiveJoin(kj::mv(onAbort)); + } + } + + uint64_t sentByteCount() override { + return transferredBytes; + } + uint64_t receivedByteCount() override { + return transferredBytes; + } + +private: + kj::Maybe state; + // Object-oriented state! If any method call is blocked waiting on activity from the other end, + // then `state` is non-null and method calls should be forwarded to it. If no calls are + // outstanding, `state` is null. + + kj::Own ownState; + + uint64_t transferredBytes = 0; + + bool aborted = false; + Maybe>> abortedFulfiller = nullptr; + Maybe> abortedPromise = nullptr; + + void endState(WebSocket& obj) { + KJ_IF_MAYBE(s, state) { + if (s == &obj) { + state = nullptr; + } + } + } + + struct ClosePtr { + uint16_t code; + kj::StringPtr reason; + }; + typedef kj::OneOf, kj::ArrayPtr, ClosePtr> MessagePtr; + + class BlockedSend final: public WebSocket { + public: + BlockedSend(kj::PromiseFulfiller& fulfiller, WebSocketPipeImpl& pipe, MessagePtr message) + : fulfiller(fulfiller), pipe(pipe), message(kj::mv(message)) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + ~BlockedSend() noexcept(false) { + pipe.endState(*this); + } + + void abort() override { + canceler.cancel("other end of WebSocketPipe was destroyed"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed")); + pipe.endState(*this); + pipe.abort(); + } + kj::Promise whenAborted() override { + KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl"); + } + + kj::Promise send(kj::ArrayPtr message) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Promise send(kj::ArrayPtr message) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Promise disconnect() override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + + kj::Promise receive(size_t maxSize) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + fulfiller.fulfill(); + pipe.endState(*this); + KJ_SWITCH_ONEOF(message) { + KJ_CASE_ONEOF(arr, kj::ArrayPtr) { + return Message(kj::str(arr)); + } + KJ_CASE_ONEOF(arr, kj::ArrayPtr) { + auto copy = kj::heapArray(arr.size()); + memcpy(copy.begin(), arr.begin(), arr.size()); + return Message(kj::mv(copy)); + } + KJ_CASE_ONEOF(close, ClosePtr) { + return Message(Close { close.code, kj::str(close.reason) }); + } + } + KJ_UNREACHABLE; + } + kj::Promise pumpTo(WebSocket& other) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + kj::Promise promise = nullptr; + KJ_SWITCH_ONEOF(message) { + KJ_CASE_ONEOF(arr, kj::ArrayPtr) { + promise = other.send(arr); + } + KJ_CASE_ONEOF(arr, kj::ArrayPtr) { + promise = other.send(arr); + } + KJ_CASE_ONEOF(close, ClosePtr) { + promise = other.close(close.code, close.reason); + } + } + return canceler.wrap(promise.then([this,&other]() { + canceler.release(); + fulfiller.fulfill(); + pipe.endState(*this); + return pipe.pumpTo(other); + }, [this](kj::Exception&& e) -> kj::Promise { + canceler.release(); + fulfiller.reject(kj::cp(e)); + pipe.endState(*this); + return kj::mv(e); + })); + } + + uint64_t sentByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + uint64_t receivedByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + + private: + kj::PromiseFulfiller& fulfiller; + WebSocketPipeImpl& pipe; + MessagePtr message; + Canceler canceler; + }; + + class BlockedPumpFrom final: public WebSocket { + public: + BlockedPumpFrom(kj::PromiseFulfiller& fulfiller, WebSocketPipeImpl& pipe, + WebSocket& input) + : fulfiller(fulfiller), pipe(pipe), input(input) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + ~BlockedPumpFrom() noexcept(false) { + pipe.endState(*this); + } + + void abort() override { + canceler.cancel("other end of WebSocketPipe was destroyed"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed")); + pipe.endState(*this); + pipe.abort(); + } + kj::Promise whenAborted() override { + KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl"); + } + + kj::Promise send(kj::ArrayPtr message) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Promise send(kj::ArrayPtr message) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Promise disconnect() override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_FAIL_ASSERT("another message send is already in progress"); + } + + kj::Promise receive(size_t maxSize) override { + KJ_REQUIRE(canceler.isEmpty(), "another message receive is already in progress"); + return canceler.wrap(input.receive(maxSize) + .then([this](Message message) { + if (message.is()) { + canceler.release(); + fulfiller.fulfill(); + pipe.endState(*this); + } + return kj::mv(message); + }, [this](kj::Exception&& e) -> Message { + canceler.release(); + fulfiller.reject(kj::cp(e)); + pipe.endState(*this); + kj::throwRecoverableException(kj::mv(e)); + return Message(kj::String()); + })); + } + kj::Promise pumpTo(WebSocket& other) override { + KJ_REQUIRE(canceler.isEmpty(), "another message receive is already in progress"); + return canceler.wrap(input.pumpTo(other) + .then([this]() { + canceler.release(); + fulfiller.fulfill(); + pipe.endState(*this); + }, [this](kj::Exception&& e) { + canceler.release(); + fulfiller.reject(kj::cp(e)); + pipe.endState(*this); + kj::throwRecoverableException(kj::mv(e)); + })); + } + + uint64_t sentByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + uint64_t receivedByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + + private: + kj::PromiseFulfiller& fulfiller; + WebSocketPipeImpl& pipe; + WebSocket& input; + Canceler canceler; + }; + + class BlockedReceive final: public WebSocket { + public: + BlockedReceive(kj::PromiseFulfiller& fulfiller, WebSocketPipeImpl& pipe, + size_t maxSize) + : fulfiller(fulfiller), pipe(pipe), maxSize(maxSize) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + ~BlockedReceive() noexcept(false) { + pipe.endState(*this); + } + + void abort() override { + canceler.cancel("other end of WebSocketPipe was destroyed"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed")); + pipe.endState(*this); + pipe.abort(); + } + kj::Promise whenAborted() override { + KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl"); + } + + kj::Promise send(kj::ArrayPtr message) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + auto copy = kj::heapArray(message.size()); + memcpy(copy.begin(), message.begin(), message.size()); + fulfiller.fulfill(Message(kj::mv(copy))); + pipe.endState(*this); + return kj::READY_NOW; + } + kj::Promise send(kj::ArrayPtr message) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + fulfiller.fulfill(Message(kj::str(message))); + pipe.endState(*this); + return kj::READY_NOW; + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + fulfiller.fulfill(Message(Close { code, kj::str(reason) })); + pipe.endState(*this); + return kj::READY_NOW; + } + kj::Promise disconnect() override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "WebSocket disconnected")); + pipe.endState(*this); + return pipe.disconnect(); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_REQUIRE(canceler.isEmpty(), "already pumping"); + return canceler.wrap(other.receive(maxSize).then([this,&other](Message message) { + canceler.release(); + fulfiller.fulfill(kj::mv(message)); + pipe.endState(*this); + return other.pumpTo(pipe); + }, [this](kj::Exception&& e) -> kj::Promise { + canceler.release(); + fulfiller.reject(kj::cp(e)); + pipe.endState(*this); + return kj::mv(e); + })); + } + + kj::Promise receive(size_t maxSize) override { + KJ_FAIL_ASSERT("another message receive is already in progress"); + } + kj::Promise pumpTo(WebSocket& other) override { + KJ_FAIL_ASSERT("another message receive is already in progress"); + } + + uint64_t sentByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + uint64_t receivedByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + + private: + kj::PromiseFulfiller& fulfiller; + WebSocketPipeImpl& pipe; + size_t maxSize; + Canceler canceler; + }; + + class BlockedPumpTo final: public WebSocket { + public: + BlockedPumpTo(kj::PromiseFulfiller& fulfiller, WebSocketPipeImpl& pipe, WebSocket& output) + : fulfiller(fulfiller), pipe(pipe), output(output) { + KJ_REQUIRE(pipe.state == nullptr); + pipe.state = *this; + } + ~BlockedPumpTo() noexcept(false) { + pipe.endState(*this); + } + + void abort() override { + canceler.cancel("other end of WebSocketPipe was destroyed"); + + // abort() is called when the pipe end is dropped. This should be treated as disconnecting, + // so pumpTo() should complete normally. + fulfiller.fulfill(); + + pipe.endState(*this); + pipe.abort(); + } + kj::Promise whenAborted() override { + KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl"); + } + + kj::Promise send(kj::ArrayPtr message) override { + KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress"); + return canceler.wrap(output.send(message)); + } + kj::Promise send(kj::ArrayPtr message) override { + KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress"); + return canceler.wrap(output.send(message)); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress"); + return canceler.wrap(output.close(code, reason).then([this]() { + // A pump is expected to end upon seeing a Close message. + canceler.release(); + pipe.endState(*this); + fulfiller.fulfill(); + }, [this](kj::Exception&& e) { + canceler.release(); + pipe.endState(*this); + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); + })); + } + kj::Promise disconnect() override { + KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress"); + return canceler.wrap(output.disconnect().then([this]() { + canceler.release(); + pipe.endState(*this); + fulfiller.fulfill(); + return pipe.disconnect(); + }, [this](kj::Exception&& e) { + canceler.release(); + pipe.endState(*this); + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); + })); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress"); + return canceler.wrap(other.pumpTo(output).then([this]() { + canceler.release(); + pipe.endState(*this); + fulfiller.fulfill(); + }, [this](kj::Exception&& e) { + canceler.release(); + pipe.endState(*this); + fulfiller.reject(kj::cp(e)); + kj::throwRecoverableException(kj::mv(e)); + })); + } + + kj::Promise receive(size_t maxSize) override { + KJ_FAIL_ASSERT("another message receive is already in progress"); + } + kj::Promise pumpTo(WebSocket& other) override { + KJ_FAIL_ASSERT("another message receive is already in progress"); + } + + uint64_t sentByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + uint64_t receivedByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + + private: + kj::PromiseFulfiller& fulfiller; + WebSocketPipeImpl& pipe; + WebSocket& output; + Canceler canceler; + }; + + class Disconnected final: public WebSocket { + public: + void abort() override { + // can ignore + } + kj::Promise whenAborted() override { + KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl"); + } + + kj::Promise send(kj::ArrayPtr message) override { + KJ_FAIL_REQUIRE("can't send() after disconnect()"); + } + kj::Promise send(kj::ArrayPtr message) override { + KJ_FAIL_REQUIRE("can't send() after disconnect()"); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + KJ_FAIL_REQUIRE("can't close() after disconnect()"); + } + kj::Promise disconnect() override { + return kj::READY_NOW; + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + KJ_FAIL_REQUIRE("can't tryPumpFrom() after disconnect()"); + } + + kj::Promise receive(size_t maxSize) override { + return KJ_EXCEPTION(DISCONNECTED, "WebSocket disconnected"); + } + kj::Promise pumpTo(WebSocket& other) override { + return kj::READY_NOW; + } + + uint64_t sentByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + uint64_t receivedByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + + }; + + class Aborted final: public WebSocket { + public: + void abort() override { + // can ignore + } + kj::Promise whenAborted() override { + KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl"); + } + + kj::Promise send(kj::ArrayPtr message) override { + return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); + } + kj::Promise send(kj::ArrayPtr message) override { + return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); + } + kj::Promise disconnect() override { + return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + return kj::Promise(KJ_EXCEPTION(DISCONNECTED, + "other end of WebSocketPipe was destroyed")); + } + + kj::Promise receive(size_t maxSize) override { + return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); + } + kj::Promise pumpTo(WebSocket& other) override { + return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); + } + + uint64_t sentByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + uint64_t receivedByteCount() override { + KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); + } + }; +}; + +class WebSocketPipeEnd final: public WebSocket { +public: + WebSocketPipeEnd(kj::Own in, kj::Own out) + : in(kj::mv(in)), out(kj::mv(out)) {} + ~WebSocketPipeEnd() noexcept(false) { + in->abort(); + out->abort(); + } + + kj::Promise send(kj::ArrayPtr message) override { + return out->send(message); + } + kj::Promise send(kj::ArrayPtr message) override { + return out->send(message); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + return out->close(code, reason); + } + kj::Promise disconnect() override { + return out->disconnect(); + } + void abort() override { + in->abort(); + out->abort(); + } + kj::Promise whenAborted() override { + return out->whenAborted(); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + return out->tryPumpFrom(other); + } + + kj::Promise receive(size_t maxSize) override { + return in->receive(maxSize); + } + kj::Promise pumpTo(WebSocket& other) override { + return in->pumpTo(other); + } + + uint64_t sentByteCount() override { return out->sentByteCount(); } + uint64_t receivedByteCount() override { return in->sentByteCount(); } + +private: + kj::Own in; + kj::Own out; +}; + +} // namespace + +WebSocketPipe newWebSocketPipe() { + auto pipe1 = kj::refcounted(); + auto pipe2 = kj::refcounted(); + + auto end1 = kj::heap(kj::addRef(*pipe1), kj::addRef(*pipe2)); + auto end2 = kj::heap(kj::mv(pipe2), kj::mv(pipe1)); + + return { { kj::mv(end1), kj::mv(end2) } }; +} + +// ======================================================================================= +class AsyncIoStreamWithInitialBuffer final: public kj::AsyncIoStream { + // An AsyncIoStream implementation that accepts an initial buffer of data + // to be read out first, and is optionally capable of deferring writes + // until a given waitBeforeSend promise is fulfilled. + // + // Instances are created with a leftoverBackingBuffer (a kj::Array) + // and a leftover kj::ArrayPtr that provides a view into the backing + // buffer representing the queued data that is pending to be read. Calling + // tryRead will consume the data from the leftover first. Once leftover has + // been fully consumed, reads will defer to the underlying stream. +public: + AsyncIoStreamWithInitialBuffer(kj::Own stream, + kj::Array leftoverBackingBuffer, + kj::ArrayPtr leftover) + : stream(kj::mv(stream)), + leftoverBackingBuffer(kj::mv(leftoverBackingBuffer)), + leftover(leftover) {} + + void shutdownWrite() override { + stream->shutdownWrite(); + } + + // AsyncInputStream + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + KJ_REQUIRE(maxBytes >= minBytes); + auto destination = static_cast(buffer); + + // If there are at least minBytes available in the leftover buffer... + if (leftover.size() >= minBytes) { + // We are going to immediately read up to maxBytes from the leftover buffer... + auto bytesToCopy = kj::min(maxBytes, leftover.size()); + memcpy(destination, leftover.begin(), bytesToCopy); + leftover = leftover.slice(bytesToCopy, leftover.size()); + + // If we've consumed all of the data in the leftover buffer, go ahead and free it. + if (leftover.size() == 0) { + leftoverBackingBuffer = nullptr; + } + + return bytesToCopy; + } else { + // We know here that leftover.size() is less than minBytes, but it might not + // be zero. Copy everything from leftover into the destination buffer then read + // the rest from the underlying stream. + auto bytesToCopy = leftover.size(); + KJ_DASSERT(bytesToCopy < minBytes); + + if (bytesToCopy > 0) { + memcpy(destination, leftover.begin(), bytesToCopy); + leftoverBackingBuffer = nullptr; + minBytes -= bytesToCopy; + maxBytes -= bytesToCopy; + KJ_DASSERT(minBytes >= 1); + KJ_DASSERT(maxBytes >= minBytes); + } + + return stream->tryRead(destination + bytesToCopy, minBytes, maxBytes) + .then([bytesToCopy](size_t amount) { return amount + bytesToCopy; }); + } + } + + Maybe tryGetLength() override { + // For a CONNECT pipe, we have no idea how much data there is going to be. + return nullptr; + } + + kj::Promise pumpTo(AsyncOutputStream& output, + uint64_t amount = kj::maxValue) override { + return pumpLoop(output, amount, 0); + } + + kj::Maybe> tryPumpFrom(AsyncInputStream& input, + uint64_t amount = kj::maxValue) override { + return input.pumpTo(*stream, amount); + } + + // AsyncOutputStream + Promise write(const void* buffer, size_t size) override { + return stream->write(buffer, size); + } + + Promise write(ArrayPtr> pieces) override { + return stream->write(pieces); + } + + Promise whenWriteDisconnected() override { + return stream->whenWriteDisconnected(); + } + +private: + + kj::Promise pumpLoop( + kj::AsyncOutputStream& output, + uint64_t remaining, + uint64_t total) { + // If there is any data remaining in the leftover queue, we'll write it out first to output. + if (leftover.size() > 0) { + auto bytesToWrite = kj::min(leftover.size(), remaining); + return output.write(leftover.begin(), bytesToWrite).then( + [this, &output, remaining, total, bytesToWrite]() mutable -> kj::Promise { + leftover = leftover.slice(bytesToWrite, leftover.size()); + // If the leftover buffer has been fully consumed, go ahead and free it now. + if (leftover.size() == 0) { + leftoverBackingBuffer = nullptr; + } + remaining -= bytesToWrite; + total += bytesToWrite; + + if (remaining == 0) { + return total; + } + return pumpLoop(output, remaining, total); + }); + } else { + // Otherwise, we are just going to defer to stream's pumpTo, making sure to + // account for the total amount we've already written from the leftover queue. + return stream->pumpTo(output, remaining).then([total](auto read) { + return total + read; + }); + } + }; + + kj::Own stream; + kj::Array leftoverBackingBuffer; + kj::ArrayPtr leftover; +}; + +class AsyncIoStreamWithGuards final: public kj::AsyncIoStream, + private kj::TaskSet::ErrorHandler { + // This AsyncIoStream adds separate kj::Promise guards to both the input and output, + // delaying reads and writes until each relevant guard is resolved. + // + // When the read guard promise resolves, it may provide a released buffer that will + // be read out first. + // The primary use case for this impl is to support pipelined CONNECT calls which + // optimistically allow outbound writes to happen while establishing the CONNECT + // tunnel has not yet been completed. If the guard promise rejects, the stream + // is permanently errored and existing pending calls (reads and writes) are canceled. +public: + AsyncIoStreamWithGuards( + kj::Own inner, + kj::Promise> readGuard, + kj::Promise writeGuard) + : inner(kj::mv(inner)), + readGuard(handleReadGuard(kj::mv(readGuard))), + writeGuard(handleWriteGuard(kj::mv(writeGuard))), + tasks(*this) {} + + // AsyncInputStream + Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + if (readGuardReleased) { + return inner->tryRead(buffer, minBytes, maxBytes); + } + return readGuard.addBranch().then([this, buffer, minBytes, maxBytes] { + return inner->tryRead(buffer, minBytes, maxBytes); + }); + } + + Maybe tryGetLength() override { + return nullptr; + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount = kj::maxValue) override { + if (readGuardReleased) { + return inner->pumpTo(output, amount); + } + return readGuard.addBranch().then([this, &output, amount] { + return inner->pumpTo(output, amount); + }); + } + + // AsyncOutputStream + + void shutdownWrite() override { + if (writeGuardReleased) { + inner->shutdownWrite(); + } else { + tasks.add(writeGuard.addBranch().then([this]() { inner->shutdownWrite(); })); + } + } + + kj::Maybe> tryPumpFrom(AsyncInputStream& input, + uint64_t amount = kj::maxValue) override { + if (writeGuardReleased) { + return input.pumpTo(*inner, amount); + } else { + return writeGuard.addBranch().then([this,&input,amount]() { + return input.pumpTo(*inner, amount); + }); + } + } + + Promise write(const void* buffer, size_t size) override { + if (writeGuardReleased) { + return inner->write(buffer, size); + } else { + return writeGuard.addBranch().then([this,buffer,size]() { + return inner->write(buffer, size); + }); + } + } + + Promise write(ArrayPtr> pieces) override { + if (writeGuardReleased) { + return inner->write(pieces); + } else { + return writeGuard.addBranch().then([this, pieces]() { + return inner->write(pieces); + }); + } + } + + Promise whenWriteDisconnected() override { + if (writeGuardReleased) { + return inner->whenWriteDisconnected(); + } else { + return writeGuard.addBranch().then([this]() { + return inner->whenWriteDisconnected(); + }, [](kj::Exception&& e) mutable -> kj::Promise { + if (e.getType() == kj::Exception::Type::DISCONNECTED) { + return kj::READY_NOW; + } else { + return kj::mv(e); + } + }); + } + } + +private: + kj::Own inner; + kj::ForkedPromise readGuard; + kj::ForkedPromise writeGuard; + bool readGuardReleased = false; + bool writeGuardReleased = false; + kj::TaskSet tasks; + // Set of tasks used to call `shutdownWrite` after write guard is released. + + void taskFailed(kj::Exception&& exception) override { + // This `taskFailed` callback is only used when `shutdownWrite` is being called. Because we + // don't care about DISCONNECTED exceptions when `shutdownWrite` is called we ignore this + // class of exceptions here. + if (exception.getType() != kj::Exception::Type::DISCONNECTED) { + KJ_LOG(ERROR, exception); + } + } + + kj::ForkedPromise handleWriteGuard(kj::Promise guard) { + return guard.then([this]() { + writeGuardReleased = true; + }).fork(); + } + + kj::ForkedPromise handleReadGuard( + kj::Promise> guard) { + return guard.then([this](kj::Maybe buffer) mutable { + readGuardReleased = true; + KJ_IF_MAYBE(b, buffer) { + if (b->leftover.size() > 0) { + // We only need to replace the inner stream if a non-empty buffer is provided. + inner = heap( + kj::mv(inner), + kj::mv(b->buffer), b->leftover); + } + } + }).fork(); + } +}; + +// ======================================================================================= + +namespace _ { // private implementation details + +kj::ArrayPtr splitNext(kj::ArrayPtr& cursor, char delimiter) { + // Consumes and returns the next item in a delimited list. + // + // If a delimiter is found: + // - `cursor` is updated to point to the rest of the string after the delimiter. + // - The text before the delimiter is returned. + // If no delimiter is found: + // - `cursor` is updated to an empty string. + // - The text that had been in `cursor` is returned. + // + // (It's up to the caller to stop the loop once `cursor` is empty.) + KJ_IF_MAYBE(index, cursor.findFirst(delimiter)) { + auto part = cursor.slice(0, *index); + cursor = cursor.slice(*index + 1, cursor.size()); + return part; + } + kj::ArrayPtr result(kj::mv(cursor)); + cursor = nullptr; + + return result; +} + +void stripLeadingAndTrailingSpace(ArrayPtr& str) { + // Remove any leading/trailing spaces from `str`, modifying it in-place. + while (str.size() > 0 && (str[0] == ' ' || str[0] == '\t')) { + str = str.slice(1, str.size()); + } + while (str.size() > 0 && (str.back() == ' ' || str.back() == '\t')) { + str = str.slice(0, str.size() - 1); + } +} + +kj::Vector> splitParts(kj::ArrayPtr input, char delim) { + // Given a string `input` and a delimiter `delim`, split the string into a vector of substrings, + // separated by the delimiter. Note that leading/trailing whitespace is stripped from each element. + kj::Vector> parts; + + while (input.size() != 0) { + auto part = splitNext(input, delim); + stripLeadingAndTrailingSpace(part); + parts.add(kj::mv(part)); + } + + return parts; +} + +kj::Array toKeysAndVals(const kj::ArrayPtr>& params) { + // Given a collection of parameters (a single offer), parse the parameters into + // pairs. If the parameter contains an `=`, we set the `key` to everything before, and the `value` + // to everything after. Otherwise, we set the `key` to be the entire parameter. + // Either way, both the key and value (if it exists) are stripped of leading & trailing whitespace. + auto result = kj::heapArray(params.size()); + size_t count = 0; + for (const auto& param : params) { + kj::ArrayPtr key; + kj::Maybe> value; + + KJ_IF_MAYBE(index, param.findFirst('=')) { + // Found '=' so we have a value. + key = param.slice(0, *index); + stripLeadingAndTrailingSpace(key); + value = param.slice(*index + 1, param.size()); + KJ_IF_MAYBE(v, value) { + stripLeadingAndTrailingSpace(*v); + } + } else { + key = kj::mv(param); + } + + result[count].key = kj::mv(key); + result[count].val = kj::mv(value); + ++count; + } + return kj::mv(result); +} + +struct ParamType { + enum { CLIENT, SERVER } side; + enum { NO_CONTEXT_TAKEOVER, MAX_WINDOW_BITS } property; +}; + +inline kj::Maybe parseKeyName(kj::ArrayPtr& key) { + // Returns a `ParamType` struct if the `key` is valid and nullptr if invalid. + + if (key == "client_no_context_takeover"_kj) { + return ParamType { ParamType::CLIENT, ParamType::NO_CONTEXT_TAKEOVER }; + } else if (key == "server_no_context_takeover"_kj) { + return ParamType { ParamType::SERVER, ParamType::NO_CONTEXT_TAKEOVER }; + } else if (key == "client_max_window_bits"_kj) { + return ParamType { ParamType::CLIENT, ParamType::MAX_WINDOW_BITS }; + } else if (key == "server_max_window_bits"_kj) { + return ParamType { ParamType::SERVER, ParamType::MAX_WINDOW_BITS }; + } + return nullptr; +} + +kj::Maybe populateUnverifiedConfig(kj::Array& params) { + // Given a collection of pairs, attempt to populate an `UnverifiedConfig` struct. + // If the struct cannot be populated, we return null. + // + // This function populates the struct with what it finds, it does not perform bounds checking or + // concern itself with valid `Value`s (so long as the `Value` is non-empty). + // + // The following issues would prevent a struct from being populated: + // Key issues: + // - `Key` is invalid (see `parseKeyName()`). + // - `Key` is repeated. + // Value issues: + // - Got a `Value` when none was expected (only the `max_window_bits` parameters expect values). + // - Got an empty `Value` (0 characters, or all whitespace characters). + + if (params.size() > 4) { + // We expect 4 `Key`s at most, having more implies repeats/invalid keys are present. + return nullptr; + } + + UnverifiedConfig config; + + for (auto& param : params) { + KJ_IF_MAYBE(paramType, parseKeyName(param.key)) { + // `Key` is valid, but we still want to check for repeats. + const auto& side = paramType->side; + const auto& property = paramType->property; + + if (property == ParamType::NO_CONTEXT_TAKEOVER) { + auto& takeOverSetting = (side == ParamType::CLIENT) ? + config.clientNoContextTakeover : config.serverNoContextTakeover; + + if (takeOverSetting == true) { + // This `Key` is a repeat; invalid config. + return nullptr; + } + + if (param.val != nullptr) { + // The `x_no_context_takeover` parameter shouldn't have a value; invalid config. + return nullptr; + } + + takeOverSetting = true; + } else if (property == ParamType::MAX_WINDOW_BITS) { + auto& maxBitsSetting = + (side == ParamType::CLIENT) ? config.clientMaxWindowBits : config.serverMaxWindowBits; + + if (maxBitsSetting != nullptr) { + // This `Key` is a repeat; invalid config. + return nullptr; + } + + KJ_IF_MAYBE(value, param.val) { + if (value->size() == 0) { + // This is equivalent to `x_max_window_bits=`, since we got an "=" we expected a token + // to follow. + return nullptr; + } + maxBitsSetting = param.val; + } else { + // We know we got this `max_window_bits` parameter in a Request/Response, and we also know + // that it didn't include an "=" (otherwise the value wouldn't be null). + // It's important to retain the information that the parameter was received *without* a + // corresponding value, as this may determine whether the offer is valid or not. + // + // To retain this information, we'll set `maxBitsSetting` to be an empty ArrayPtr so this + // can be dealt with properly later. + maxBitsSetting = ArrayPtr(); + } + } + } else { + // Invalid parameter. + return nullptr; + } + } + return kj::mv(config); +} + +kj::Maybe validateCompressionConfig(UnverifiedConfig&& config, + bool isAgreement) { + // Verifies that the `config` is valid depending on whether we're validating a Request (offer) or + // a Response (agreement). This essentially consumes the `UnverifiedConfig` and converts it into a + // `CompressionParameters` struct. + CompressionParameters result; + + KJ_IF_MAYBE(serverBits, config.serverMaxWindowBits) { + if (serverBits->size() == 0) { + // This means `server_max_window_bits` was passed without a value. Since a value is required, + // this config is invalid. + return nullptr; + } else { + KJ_IF_MAYBE(bits, kj::str(*serverBits).tryParseAs()) { + if (*bits < 8 || 15 < *bits) { + // Out of range -- invalid. + return nullptr; + } + if (isAgreement) { + result.inboundMaxWindowBits = *bits; + } else { + result.outboundMaxWindowBits = *bits; + } + } else { + // Invalid ABNF, expected 1*DIGIT. + return nullptr; + } + } + } + + KJ_IF_MAYBE(clientBits, config.clientMaxWindowBits) { + if (clientBits->size() == 0) { + if (!isAgreement) { + // `client_max_window_bits` does not need to have a value in an offer, let's set it to 15 + // to get the best level of compression. + result.inboundMaxWindowBits = 15; + } else { + // `client_max_window_bits` must have a value in a Response. + return nullptr; + } + } else { + KJ_IF_MAYBE(bits, kj::str(*clientBits).tryParseAs()) { + if (*bits < 8 || 15 < *bits) { + // Out of range -- invalid. + return nullptr; + } + if (isAgreement) { + result.outboundMaxWindowBits = *bits; + } else { + result.inboundMaxWindowBits = *bits; + } + } else { + // Invalid ABNF, expected 1*DIGIT. + return nullptr; + } + } + } + + if (isAgreement) { + result.outboundNoContextTakeover = config.clientNoContextTakeover; + result.inboundNoContextTakeover = config.serverNoContextTakeover; + } else { + result.inboundNoContextTakeover = config.clientNoContextTakeover; + result.outboundNoContextTakeover = config.serverNoContextTakeover; + } + return kj::mv(result); +} + +inline kj::Maybe tryExtractParameters( + kj::Vector>& configuration, + bool isAgreement) { + // If the `configuration` is structured correctly and has no invalid parameters/values, we will + // return a populated `CompressionParameters` struct. + if (configuration.size() == 1) { + // Plain `permessage-deflate`. + return CompressionParameters{}; + } + auto params = configuration.slice(1, configuration.size()); + auto keyMaybeValuePairs = toKeysAndVals(params); + // Parse parameter strings into parameter[=value] pairs. + auto maybeUnverified = populateUnverifiedConfig(keyMaybeValuePairs); + KJ_IF_MAYBE(unverified, maybeUnverified) { + // Parsing succeeded, i.e. the parameter (`key`) names are valid and we don't have + // values for `x_no_context_takeover` parameters (the configuration is structured correctly). + // All that's left is to check the `x_max_window_bits` values (if any are present). + KJ_IF_MAYBE(validConfig, validateCompressionConfig(kj::mv(*unverified), isAgreement)) { + return kj::mv(*validConfig); + } + } + return nullptr; +} + +kj::Vector findValidExtensionOffers(StringPtr offers) { + // A function to be called by the client that wants to offer extensions through + // `Sec-WebSocket-Extensions`. This function takes the value of the header (a string) and + // populates a Vector of all the valid offers. + kj::Vector result; + + auto extensions = splitParts(offers, ','); + + for (const auto& offer : extensions) { + auto splitOffer = splitParts(offer, ';'); + if (splitOffer.front() != "permessage-deflate"_kj) { + continue; + } + KJ_IF_MAYBE(validated, tryExtractParameters(splitOffer, false)) { + // We need to swap the inbound/outbound properties since `tryExtractParameters` thinks we're + // parsing as the server (`isAgreement` is false). + auto tempCtx = validated->inboundNoContextTakeover; + validated->inboundNoContextTakeover = validated->outboundNoContextTakeover; + validated->outboundNoContextTakeover = tempCtx; + auto tempWindow = validated->inboundMaxWindowBits; + validated->inboundMaxWindowBits = validated->outboundMaxWindowBits; + validated->outboundMaxWindowBits = tempWindow; + result.add(kj::mv(*validated)); + } + } + + return kj::mv(result); +} + +kj::String generateExtensionRequest(const ArrayPtr& extensions) { + // Build the `Sec-WebSocket-Extensions` request from the validated parameters. + constexpr auto EXT = "permessage-deflate"_kj; + auto offers = kj::heapArray(extensions.size()); + size_t i = 0; + for (const auto& offer : extensions) { + offers[i] = kj::str(EXT); + if (offer.outboundNoContextTakeover) { + offers[i] = kj::str(offers[i], "; client_no_context_takeover"); + } + if (offer.inboundNoContextTakeover) { + offers[i] = kj::str(offers[i], "; server_no_context_takeover"); + } + if (offer.outboundMaxWindowBits != nullptr) { + auto w = KJ_ASSERT_NONNULL(offer.outboundMaxWindowBits); + offers[i] = kj::str(offers[i], "; client_max_window_bits=", w); + } + if (offer.inboundMaxWindowBits != nullptr) { + auto w = KJ_ASSERT_NONNULL(offer.inboundMaxWindowBits); + offers[i] = kj::str(offers[i], "; server_max_window_bits=", w); + } + ++i; + } + return kj::strArray(offers, ", "); +} + +kj::Maybe tryParseExtensionOffers(StringPtr offers) { + // Given a string of offers, accept the first valid offer by returning a `CompressionParameters` + // struct. If there are no valid offers, return `nullptr`. + auto splitOffers = splitParts(offers, ','); + + for (const auto& offer : splitOffers) { + auto splitOffer = splitParts(offer, ';'); + + if (splitOffer.front() != "permessage-deflate"_kj) { + // Extension token was invalid. + continue; + } + KJ_IF_MAYBE(config, tryExtractParameters(splitOffer, false)) { + return kj::mv(*config); + } + } + return nullptr; +} + +kj::Maybe tryParseAllExtensionOffers(StringPtr offers, + CompressionParameters manualConfig) { + // Similar to `tryParseExtensionOffers()`, however, this function is called when parsing in + // `MANUAL_COMPRESSION` mode. In some cases, the server's configuration might not support the + // `server_no_context_takeover` or `server_max_window_bits` parameters. Essentially, this function + // will look at all the client's offers, and accept the first one that it can support. + // + // We differentiate these functions because in `AUTOMATIC_COMPRESSION` mode, KJ can support these + // server restricting compression parameters. + auto splitOffers = splitParts(offers, ','); + + for (const auto& offer : splitOffers) { + auto splitOffer = splitParts(offer, ';'); + + if (splitOffer.front() != "permessage-deflate"_kj) { + // Extension token was invalid. + continue; + } + KJ_IF_MAYBE(config, tryExtractParameters(splitOffer, false)) { + KJ_IF_MAYBE(finalConfig, compareClientAndServerConfigs(*config, manualConfig)) { + // Found a compatible configuration between the server's config and client's offer. + return kj::mv(*finalConfig); + } + } + } + return nullptr; +} + +kj::Maybe compareClientAndServerConfigs(CompressionParameters requestConfig, + CompressionParameters manualConfig) { + // We start from the `manualConfig` and go through a series of filters to get a compression + // configuration that both the client and the server can agree upon. If no agreement can be made, + // we return null. + + CompressionParameters acceptedParameters = manualConfig; + + // We only need to modify `client_no_context_takeover` and `server_no_context_takeover` when + // `manualConfig` doesn't include them. + if (manualConfig.inboundNoContextTakeover == false) { + acceptedParameters.inboundNoContextTakeover = false; + } + + if (manualConfig.outboundNoContextTakeover == false) { + acceptedParameters.outboundNoContextTakeover = false; + if (requestConfig.outboundNoContextTakeover == true) { + // The client has told the server to not use context takeover. This is not a "hint", + // rather it is a restriction on the server's configuration. If the server does not support + // the configuration, it must reject the offer. + return nullptr; + } + } + + // client_max_window_bits + if (requestConfig.inboundMaxWindowBits != nullptr && + manualConfig.inboundMaxWindowBits != nullptr) { + // We want `min(requestConfig, manualConfig)` in this case. + auto reqBits = KJ_ASSERT_NONNULL(requestConfig.inboundMaxWindowBits); + auto manualBits = KJ_ASSERT_NONNULL(manualConfig.inboundMaxWindowBits); + if (reqBits < manualBits) { + acceptedParameters.inboundMaxWindowBits = reqBits; + } + } else { + // We will not reply with `client_max_window_bits`. + acceptedParameters.inboundMaxWindowBits = nullptr; + } + + // server_max_window_bits + if (manualConfig.outboundMaxWindowBits != nullptr) { + auto manualBits = KJ_ASSERT_NONNULL(manualConfig.outboundMaxWindowBits); + if (requestConfig.outboundMaxWindowBits != nullptr) { + // We want `min(requestConfig, manualConfig)` in this case. + auto reqBits = KJ_ASSERT_NONNULL(requestConfig.outboundMaxWindowBits); + if (reqBits < manualBits) { + acceptedParameters.outboundMaxWindowBits = reqBits; + } + } + } else { + acceptedParameters.outboundMaxWindowBits = nullptr; + if (requestConfig.outboundMaxWindowBits != nullptr) { + // The client has told the server to use `server_max_window_bits`. This is not a "hint", + // rather it is a restriction on the server's configuration. If the server does not support + // the configuration, it must reject the offer. + return nullptr; + } + } + return acceptedParameters; +} + +kj::String generateExtensionResponse(const CompressionParameters& parameters) { + // Build the `Sec-WebSocket-Extensions` response from the agreed parameters. + kj::String response = kj::str("permessage-deflate"); + if (parameters.inboundNoContextTakeover) { + response = kj::str(response, "; client_no_context_takeover"); + } + if (parameters.outboundNoContextTakeover) { + response = kj::str(response, "; server_no_context_takeover"); + } + if (parameters.inboundMaxWindowBits != nullptr) { + auto w = KJ_REQUIRE_NONNULL(parameters.inboundMaxWindowBits); + response = kj::str(response, "; client_max_window_bits=", w); + } + if (parameters.outboundMaxWindowBits != nullptr) { + auto w = KJ_REQUIRE_NONNULL(parameters.outboundMaxWindowBits); + response = kj::str(response, "; server_max_window_bits=", w); + } + return kj::mv(response); +} + +kj::OneOf tryParseExtensionAgreement( + const Maybe& clientOffer, + StringPtr agreedParameters) { + // Like `tryParseExtensionOffers`, but called by the client when parsing the server's Response. + // If the client must decline the agreement, we want to provide some details about what went wrong + // (since the client has to fail the connection). + constexpr auto FAILURE = "Server failed WebSocket handshake: "_kj; + auto e = KJ_EXCEPTION(FAILED); + + if (clientOffer == nullptr) { + // We've received extensions when we did not send any in the first place. + e.setDescription( + kj::str(FAILURE, "added Sec-WebSocket-Extensions when client did not offer any.")); + return kj::mv(e); + } + + auto offers = splitParts(agreedParameters, ','); + if (offers.size() != 1) { + constexpr auto EXPECT = "expected exactly one extension (permessage-deflate) but received " + "more than one."_kj; + e.setDescription(kj::str(FAILURE, EXPECT)); + return kj::mv(e); + } + auto splitOffer = splitParts(offers.front(), ';'); + + if (splitOffer.front() != "permessage-deflate"_kj) { + e.setDescription(kj::str(FAILURE, "response included a Sec-WebSocket-Extensions value that was " + "not permessage-deflate.")); + return kj::mv(e); + } + + // Verify the parameters of our single extension, and compare it with the clients original offer. + KJ_IF_MAYBE(config, tryExtractParameters(splitOffer, true)) { + const auto& client = KJ_ASSERT_NONNULL(clientOffer); + // The server might have ignored the client's hints regarding its compressor's configuration. + // That's fine, but as the client, we still want to use those outbound compression parameters. + if (config->outboundMaxWindowBits == nullptr) { + config->outboundMaxWindowBits = client.outboundMaxWindowBits; + } else KJ_IF_MAYBE(value, client.outboundMaxWindowBits) { + if (*value < KJ_ASSERT_NONNULL(config->outboundMaxWindowBits)) { + // If the client asked for a value smaller than what the server responded with, use the + // value that the client originally specified. + config->outboundMaxWindowBits = *value; + } + } + if (config->outboundNoContextTakeover == false) { + config->outboundNoContextTakeover = client.outboundNoContextTakeover; + } + return kj::mv(*config); + } + + // There was a problem parsing the server's `Sec-WebSocket-Extensions` response. + e.setDescription(kj::str(FAILURE, "the Sec-WebSocket-Extensions header in the Response included " + "an invalid value.")); + return kj::mv(e); +} +} // namespace _ (private) +namespace { +class NullInputStream final: public kj::AsyncInputStream { +public: + NullInputStream(kj::Maybe expectedLength = size_t(0)) + : expectedLength(expectedLength) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return constPromise(); + } + + kj::Maybe tryGetLength() override { + return expectedLength; + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return constPromise(); + } + +private: + kj::Maybe expectedLength; +}; + +class NullOutputStream final: public kj::AsyncOutputStream { +public: + Promise write(const void* buffer, size_t size) override { + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + return kj::READY_NOW; + } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. +}; + +class NullIoStream final: public kj::AsyncIoStream { +public: + void shutdownWrite() override {} + + Promise write(const void* buffer, size_t size) override { + return kj::READY_NOW; + } + Promise write(ArrayPtr> pieces) override { + return kj::READY_NOW; + } + Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return constPromise(); + } + + kj::Maybe tryGetLength() override { + return kj::Maybe((uint64_t)0); + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return constPromise(); + } +}; + +class HttpClientImpl final: public HttpClient, + private HttpClientErrorHandler { +public: + HttpClientImpl(const HttpHeaderTable& responseHeaderTable, kj::Own rawStream, + HttpClientSettings settings) + : httpInput(*rawStream, responseHeaderTable), + httpOutput(*rawStream), + ownStream(kj::mv(rawStream)), + settings(kj::mv(settings)) {} + + bool canReuse() { + // Returns true if we can immediately reuse this HttpClient for another message (so all + // previous messages have been fully read). + + return !upgraded && !closed && httpInput.canReuse() && httpOutput.canReuse(); + } + + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_REQUIRE(!upgraded, + "can't make further requests on this HttpClient because it has been or is in the process " + "of being upgraded"); + KJ_REQUIRE(!closed, + "this HttpClient's connection has been closed by the server or due to an error"); + KJ_REQUIRE(httpOutput.canReuse(), + "can't start new request until previous request body has been fully written"); + closeWatcherTask = nullptr; + + kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT]; + kj::String lengthStr; + + bool isGet = method == HttpMethod::GET || method == HttpMethod::HEAD; + bool hasBody; + + KJ_IF_MAYBE(s, expectedBodySize) { + if (isGet && *s == 0) { + // GET with empty body; don't send any Content-Length. + hasBody = false; + } else { + lengthStr = kj::str(*s); + connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr; + hasBody = true; + } + } else { + if (isGet && headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) { + // GET with empty body; don't send any Transfer-Encoding. + hasBody = false; + } else { + // HACK: Normally GET requests shouldn't have bodies. But, if the caller set a + // Transfer-Encoding header on a GET, we use this as a special signal that it might + // actually want to send a body. This allows pass-through of a GET request with a chunked + // body to "just work". We strongly discourage writing any new code that sends + // full-bodied GETs. + connectionHeaders[HttpHeaders::BuiltinIndices::TRANSFER_ENCODING] = "chunked"; + hasBody = true; + } + } + + httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders)); + + kj::Own bodyStream; + if (!hasBody) { + // No entity-body. + httpOutput.finishBody(); + bodyStream = heap(); + } else KJ_IF_MAYBE(s, expectedBodySize) { + bodyStream = heap(httpOutput, *s); + } else { + bodyStream = heap(httpOutput); + } + + auto id = ++counter; + + auto responsePromise = httpInput.readResponseHeaders().then( + [this,method,id](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) + -> HttpClient::Response { + KJ_SWITCH_ONEOF(responseOrProtocolError) { + KJ_CASE_ONEOF(response, HttpHeaders::Response) { + auto& responseHeaders = httpInput.getHeaders(); + HttpClient::Response result { + response.statusCode, + response.statusText, + &responseHeaders, + httpInput.getEntityBody( + HttpInputStreamImpl::RESPONSE, method, response.statusCode, responseHeaders) + }; + + if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>( + responseHeaders.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) { + closed = true; + } else if (counter == id) { + watchForClose(); + } else { + // Another request was already queued after this one, so we don't want to watch for + // stream closure because we're fully expecting another response. + } + return result; + } + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + closed = true; + return settings.errorHandler.orDefault(*this).handleProtocolError( + kj::mv(protocolError)); + } + } + + KJ_UNREACHABLE; + }); + + return { kj::mv(bodyStream), kj::mv(responsePromise) }; + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + KJ_REQUIRE(!upgraded, + "can't make further requests on this HttpClient because it has been or is in the process " + "of being upgraded"); + KJ_REQUIRE(!closed, + "this HttpClient's connection has been closed by the server or due to an error"); + closeWatcherTask = nullptr; + + // Mark upgraded for now, even though the upgrade could fail, because we can't allow pipelined + // requests in the meantime. + upgraded = true; + + byte keyBytes[16]; + KJ_ASSERT_NONNULL(settings.entropySource, + "can't use openWebSocket() because no EntropySource was provided when creating the " + "HttpClient").generate(keyBytes); + auto keyBase64 = kj::encodeBase64(keyBytes); + + kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT]; + connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade"; + connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket"; + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_VERSION] = "13"; + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_KEY] = keyBase64; + + kj::Maybe offeredExtensions; + kj::Maybe clientOffer; + kj::Vector extensions; + auto compressionMode = settings.webSocketCompressionMode; + + if (compressionMode == HttpClientSettings::MANUAL_COMPRESSION) { + KJ_IF_MAYBE(value, headers.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // Strip all `Sec-WebSocket-Extensions` except for `permessage-deflate`. + extensions = _::findValidExtensionOffers(*value); + } + } else if (compressionMode == HttpClientSettings::AUTOMATIC_COMPRESSION) { + // If AUTOMATIC_COMPRESSION is enabled, we send `Sec-WebSocket-Extensions: permessage-deflate` + // to the server and ignore the `headers` provided by the caller. + extensions.add(CompressionParameters()); + } + + if (extensions.size() > 0) { + clientOffer = extensions.front(); + // We hold on to a copy of the client's most preferred offer so even if the server + // ignores `client_no_context_takeover` or `client_max_window_bits`, we can still refer to + // the original offer made by the client (thereby allowing the client to use these parameters). + // + // It's safe to ignore the remaining offers because: + // 1. Offers are ordered by preference. + // 2. `client_x` parameters are hints to the server and do not result in rejections, so the + // client is likely to put them in every offer anyways. + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_EXTENSIONS] = + offeredExtensions.emplace(_::generateExtensionRequest(extensions.asPtr())); + } + + httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders)); + + // No entity-body. + httpOutput.finishBody(); + + auto id = ++counter; + + return httpInput.readResponseHeaders() + .then([this,id,keyBase64 = kj::mv(keyBase64),clientOffer = kj::mv(clientOffer)]( + HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) + -> HttpClient::WebSocketResponse { + KJ_SWITCH_ONEOF(responseOrProtocolError) { + KJ_CASE_ONEOF(response, HttpHeaders::Response) { + auto& responseHeaders = httpInput.getHeaders(); + if (response.statusCode == 101) { + if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>( + responseHeaders.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) { + kj::String ownMessage; + kj::StringPtr message; + KJ_IF_MAYBE(actual, responseHeaders.get(HttpHeaderId::UPGRADE)) { + ownMessage = kj::str( + "Server failed WebSocket handshake: incorrect Upgrade header: " + "expected 'websocket', got '", *actual, "'."); + message = ownMessage; + } else { + message = "Server failed WebSocket handshake: missing Upgrade header."; + } + return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({ + 502, "Bad Gateway", message, nullptr + }); + } + + auto expectedAccept = generateWebSocketAccept(keyBase64); + if (responseHeaders.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault(nullptr) + != expectedAccept) { + kj::String ownMessage; + kj::StringPtr message; + KJ_IF_MAYBE(actual, responseHeaders.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT)) { + ownMessage = kj::str( + "Server failed WebSocket handshake: incorrect Sec-WebSocket-Accept header: " + "expected '", expectedAccept, "', got '", *actual, "'."); + message = ownMessage; + } else { + message = "Server failed WebSocket handshake: missing Upgrade header."; + } + return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({ + 502, "Bad Gateway", message, nullptr + }); + } + + kj::Maybe compressionParameters; + if (settings.webSocketCompressionMode != HttpClientSettings::NO_COMPRESSION) { + KJ_IF_MAYBE(agreedParameters, responseHeaders.get( + HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + + auto parseResult = _::tryParseExtensionAgreement(clientOffer, + *agreedParameters); + if (parseResult.is()) { + return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({ + 502, "Bad Gateway", parseResult.get().getDescription(), nullptr}); + } + compressionParameters.emplace(kj::mv(parseResult.get())); + } + } + + return { + response.statusCode, + response.statusText, + &httpInput.getHeaders(), + upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource, + kj::mv(compressionParameters)), + }; + } else { + upgraded = false; + HttpClient::WebSocketResponse result { + response.statusCode, + response.statusText, + &responseHeaders, + httpInput.getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, + response.statusCode, responseHeaders) + }; + if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>( + responseHeaders.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) { + closed = true; + } else if (counter == id) { + watchForClose(); + } else { + // Another request was already queued after this one, so we don't want to watch for + // stream closure because we're fully expecting another response. + } + return result; + } + } + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError( + kj::mv(protocolError)); + } + } + + KJ_UNREACHABLE; + }); + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + KJ_REQUIRE(!upgraded, + "can't make further requests on this HttpClient because it has been or is in the process " + "of being upgraded"); + KJ_REQUIRE(!closed, + "this HttpClient's connection has been closed by the server or due to an error"); + KJ_REQUIRE(httpOutput.canReuse(), + "can't start new request until previous request body has been fully written"); + + if (settings.useTls) { + KJ_UNIMPLEMENTED("This HttpClient does not support TLS."); + } + + closeWatcherTask = nullptr; + + // Mark upgraded for now even though the tunnel could fail, because we can't allow pipelined + // requests in the meantime. + upgraded = true; + + kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT]; + + httpOutput.writeHeaders(headers.serializeConnectRequest(host, connectionHeaders)); + + auto id = ++counter; + + auto split = httpInput.readResponseHeaders().then( + [this, id](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) mutable + -> kj::Tuple, + kj::Promise>> { + KJ_SWITCH_ONEOF(responseOrProtocolError) { + KJ_CASE_ONEOF(response, HttpHeaders::Response) { + auto& responseHeaders = httpInput.getHeaders(); + if (response.statusCode < 200 || response.statusCode >= 300) { + // Any statusCode that is not in the 2xx range in interpreted + // as an HTTP response. Any status code in the 2xx range is + // interpreted as a successful CONNECT response. + closed = true; + return kj::tuple(ConnectRequest::Status( + response.statusCode, + kj::str(response.statusText), + kj::heap(responseHeaders.clone()), + httpInput.getEntityBody( + HttpInputStreamImpl::RESPONSE, + HttpConnectMethod(), + response.statusCode, + responseHeaders)), + KJ_EXCEPTION(DISCONNECTED, "the connect request was rejected")); + } + KJ_ASSERT(counter == id); + return kj::tuple(ConnectRequest::Status( + response.statusCode, + kj::str(response.statusText), + kj::heap(responseHeaders.clone()) + ), kj::Maybe(httpInput.releaseBuffer())); + } + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + closed = true; + auto response = handleProtocolError(protocolError); + return kj::tuple(ConnectRequest::Status( + response.statusCode, + kj::str(response.statusText), + kj::heap(response.headers->clone()), + kj::mv(response.body) + ), KJ_EXCEPTION(DISCONNECTED, "the connect request errored")); + } + } + KJ_UNREACHABLE; + }).split(); + + return ConnectRequest { + kj::mv(kj::get<0>(split)), // Promise for the result + heap( + kj::mv(ownStream), + kj::mv(kj::get<1>(split)) /* read guard (Promise for the ReleasedBuffer) */, + httpOutput.flush() /* write guard (void Promise) */) + }; + } + +private: + HttpInputStreamImpl httpInput; + HttpOutputStream httpOutput; + kj::Own ownStream; + HttpClientSettings settings; + kj::Maybe> closeWatcherTask; + bool upgraded = false; + bool closed = false; + + uint counter = 0; + // Counts requests for the sole purpose of detecting if more requests have been made after some + // point in history. + + void watchForClose() { + closeWatcherTask = httpInput.awaitNextMessage() + .then([this](bool hasData) -> kj::Promise { + if (hasData) { + // Uhh... The server sent some data before we asked for anything. Perhaps due to properties + // of this application, the server somehow already knows what the next request will be, and + // it is trying to optimize. Or maybe this is some sort of test and the server is just + // replaying a script. In any case, we will humor it -- leave the data in the buffer and + // let it become the response to the next request. + return kj::READY_NOW; + } else { + // EOF -- server disconnected. + closed = true; + if (httpOutput.isInBody()) { + // Huh, the application is still sending a request. We should let it finish. We do not + // need to proactively free the socket in this case because we know that we're not + // sitting in a reusable connection pool, because we know the application is still + // actively using the connection. + return kj::READY_NOW; + } else { + return httpOutput.flush().then([this]() { + // We might be sitting in NetworkAddressHttpClient's `availableClients` pool. We don't + // have a way to notify it to remove this client from the pool; instead, when it tries + // to pull this client from the pool later, it will notice the client is dead and will + // discard it then. But, we would like to avoid holding on to a socket forever. So, + // destroy the socket now. + // TODO(cleanup): Maybe we should arrange to proactively remove ourselves? Seems + // like the code will be awkward. + ownStream = nullptr; + }); + } + } + }).eagerlyEvaluate(nullptr); + } +}; + +} // namespace + +kj::Promise HttpClient::openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) { + return request(HttpMethod::GET, url, headers, nullptr) + .response.then([](HttpClient::Response&& response) -> WebSocketResponse { + kj::OneOf, kj::Own> body; + body.init>(kj::mv(response.body)); + + return { + response.statusCode, + response.statusText, + response.headers, + kj::mv(body) + }; + }); +} + +HttpClient::ConnectRequest HttpClient::connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) { + KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpClient"); +} + +kj::Own newHttpClient( + const HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream, + HttpClientSettings settings) { + return kj::heap(responseHeaderTable, + kj::Own(&stream, kj::NullDisposer::instance), + kj::mv(settings)); +} + +HttpClient::Response HttpClientErrorHandler::handleProtocolError( + HttpHeaders::ProtocolError protocolError) { + KJ_FAIL_REQUIRE(protocolError.description) { break; } + return HttpClient::Response(); +} + +HttpClient::WebSocketResponse HttpClientErrorHandler::handleWebSocketProtocolError( + HttpHeaders::ProtocolError protocolError) { + auto response = handleProtocolError(protocolError); + return HttpClient::WebSocketResponse { + response.statusCode, response.statusText, response.headers, kj::mv(response.body) + }; +} + +kj::Exception WebSocketErrorHandler::handleWebSocketProtocolError( + WebSocket::ProtocolError protocolError) { + return KJ_EXCEPTION(FAILED, "WebSocket protocol error", protocolError.statusCode, protocolError.description); +} + +class PausableReadAsyncIoStream::PausableRead { +public: + PausableRead( + kj::PromiseFulfiller& fulfiller, PausableReadAsyncIoStream& parent, + void* buffer, size_t minBytes, size_t maxBytes) + : fulfiller(fulfiller), parent(parent), + operationBuffer(buffer), operationMinBytes(minBytes), operationMaxBytes(maxBytes), + innerRead(parent.tryReadImpl(operationBuffer, operationMinBytes, operationMaxBytes).then( + [&fulfiller](size_t size) mutable -> kj::Promise { + fulfiller.fulfill(kj::mv(size)); + return kj::READY_NOW; + }, [&fulfiller](kj::Exception&& err) { + fulfiller.reject(kj::mv(err)); + })) { + KJ_ASSERT(parent.maybePausableRead == nullptr); + parent.maybePausableRead = *this; + } + + ~PausableRead() noexcept(false) { + parent.maybePausableRead = nullptr; + } + + void pause() { + innerRead = nullptr; + } + + void unpause() { + innerRead = parent.tryReadImpl(operationBuffer, operationMinBytes, operationMaxBytes).then( + [this](size_t size) -> kj::Promise { + fulfiller.fulfill(kj::mv(size)); + return kj::READY_NOW; + }, [this](kj::Exception&& err) { + fulfiller.reject(kj::mv(err)); + }); + } + + void reject(kj::Exception&& exc) { + fulfiller.reject(kj::mv(exc)); + } +private: + kj::PromiseFulfiller& fulfiller; + PausableReadAsyncIoStream& parent; + + void* operationBuffer; + size_t operationMinBytes; + size_t operationMaxBytes; + // The parameters of the current tryRead call. Used to unpause a paused read. + + kj::Promise innerRead; + // The current pending read. +}; + +_::Deferred> PausableReadAsyncIoStream::trackRead() { + KJ_REQUIRE(!currentlyReading, "only one read is allowed at any one time"); + currentlyReading = true; + return kj::defer>([this]() { currentlyReading = false; }); +} + +_::Deferred> PausableReadAsyncIoStream::trackWrite() { + KJ_REQUIRE(!currentlyWriting, "only one write is allowed at any one time"); + currentlyWriting = true; + return kj::defer>([this]() { currentlyWriting = false; }); +} + +kj::Promise PausableReadAsyncIoStream::tryRead( + void* buffer, size_t minBytes, size_t maxBytes) { + return kj::newAdaptedPromise(*this, buffer, minBytes, maxBytes); +} + +kj::Promise PausableReadAsyncIoStream::tryReadImpl( + void* buffer, size_t minBytes, size_t maxBytes) { + // Hack: evalNow used here because `newAdaptedPromise` has a bug. We may need to change + // `PromiseDisposer::alloc` to not be `noexcept` but in order to do so we'll need to benchmark + // its performance. + return kj::evalNow([&]() -> kj::Promise { + return inner->tryRead(buffer, minBytes, maxBytes).attach(trackRead()); + }); +} + +kj::Maybe PausableReadAsyncIoStream::tryGetLength() { + return inner->tryGetLength(); +} + +kj::Promise PausableReadAsyncIoStream::pumpTo( + kj::AsyncOutputStream& output, uint64_t amount) { + return kj::unoptimizedPumpTo(*this, output, amount); +} + +kj::Promise PausableReadAsyncIoStream::write(const void* buffer, size_t size) { + return inner->write(buffer, size).attach(trackWrite()); +} + +kj::Promise PausableReadAsyncIoStream::write( + kj::ArrayPtr> pieces) { + return inner->write(pieces).attach(trackWrite()); +} + +kj::Maybe> PausableReadAsyncIoStream::tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount) { + auto result = inner->tryPumpFrom(input, amount); + KJ_IF_MAYBE(r, result) { + return r->attach(trackWrite()); + } else { + return nullptr; + } +} + +kj::Promise PausableReadAsyncIoStream::whenWriteDisconnected() { + return inner->whenWriteDisconnected(); +} + +void PausableReadAsyncIoStream::shutdownWrite() { + inner->shutdownWrite(); +} + +void PausableReadAsyncIoStream::abortRead() { + inner->abortRead(); +} + +kj::Maybe PausableReadAsyncIoStream::getFd() const { + return inner->getFd(); +} + +void PausableReadAsyncIoStream::pause() { + KJ_IF_MAYBE(pausable, maybePausableRead) { + pausable->pause(); + } +} + +void PausableReadAsyncIoStream::unpause() { + KJ_IF_MAYBE(pausable, maybePausableRead) { + pausable->unpause(); + } +} + +bool PausableReadAsyncIoStream::getCurrentlyReading() { + return currentlyReading; +} + +bool PausableReadAsyncIoStream::getCurrentlyWriting() { + return currentlyWriting; +} + +kj::Own PausableReadAsyncIoStream::takeStream() { + return kj::mv(inner); +} + +void PausableReadAsyncIoStream::replaceStream(kj::Own stream) { + inner = kj::mv(stream); +} + +void PausableReadAsyncIoStream::reject(kj::Exception&& exc) { + KJ_IF_MAYBE(pausable, maybePausableRead) { + pausable->reject(kj::mv(exc)); + } +} + +// ======================================================================================= + +namespace { + +class NetworkAddressHttpClient final: public HttpClient { +public: + NetworkAddressHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, + kj::Own address, HttpClientSettings settings) + : timer(timer), + responseHeaderTable(responseHeaderTable), + address(kj::mv(address)), + settings(kj::mv(settings)) {} + + bool isDrained() { + // Returns true if there are no open connections. + return activeConnectionCount == 0 && availableClients.empty(); + } + + kj::Promise onDrained() { + // Returns a promise which resolves the next time isDrained() transitions from false to true. + auto paf = kj::newPromiseAndFulfiller(); + drainedFulfiller = kj::mv(paf.fulfiller); + return kj::mv(paf.promise); + } + + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + auto refcounted = getClient(); + auto result = refcounted->client->request(method, url, headers, expectedBodySize); + result.body = result.body.attach(kj::addRef(*refcounted)); + result.response = result.response.then( + [refcounted=kj::mv(refcounted)](Response&& response) mutable { + response.body = response.body.attach(kj::mv(refcounted)); + return kj::mv(response); + }); + return result; + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + auto refcounted = getClient(); + auto result = refcounted->client->openWebSocket(url, headers); + return result.then( + [refcounted=kj::mv(refcounted)](WebSocketResponse&& response) mutable { + KJ_SWITCH_ONEOF(response.webSocketOrBody) { + KJ_CASE_ONEOF(body, kj::Own) { + response.webSocketOrBody = body.attach(kj::mv(refcounted)); + } + KJ_CASE_ONEOF(ws, kj::Own) { + // The only reason we need to attach the client to the WebSocket is because otherwise + // the response headers will be deleted prematurely. Otherwise, the WebSocket has taken + // ownership of the connection. + // + // TODO(perf): Maybe we could transfer ownership of the response headers specifically? + response.webSocketOrBody = ws.attach(kj::mv(refcounted)); + } + } + return kj::mv(response); + }); + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + auto refcounted = getClient(); + auto request = refcounted->client->connect(host, headers, settings); + return ConnectRequest { + request.status.attach(kj::addRef(*refcounted)), + request.connection.attach(kj::mv(refcounted)) + }; + } + +private: + kj::Timer& timer; + const HttpHeaderTable& responseHeaderTable; + kj::Own address; + HttpClientSettings settings; + + kj::Maybe>> drainedFulfiller; + uint activeConnectionCount = 0; + + bool timeoutsScheduled = false; + kj::Promise timeoutTask = nullptr; + + struct AvailableClient { + kj::Own client; + kj::TimePoint expires; + }; + + std::deque availableClients; + + struct RefcountedClient final: public kj::Refcounted { + RefcountedClient(NetworkAddressHttpClient& parent, kj::Own client) + : parent(parent), client(kj::mv(client)) { + ++parent.activeConnectionCount; + } + ~RefcountedClient() noexcept(false) { + --parent.activeConnectionCount; + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + parent.returnClientToAvailable(kj::mv(client)); + })) { + KJ_LOG(ERROR, *exception); + } + } + + NetworkAddressHttpClient& parent; + kj::Own client; + }; + + kj::Own getClient() { + for (;;) { + if (availableClients.empty()) { + auto stream = newPromisedStream(address->connect()); + return kj::refcounted(*this, + kj::heap(responseHeaderTable, kj::mv(stream), settings)); + } else { + auto client = kj::mv(availableClients.back().client); + availableClients.pop_back(); + if (client->canReuse()) { + return kj::refcounted(*this, kj::mv(client)); + } + // Whoops, this client's connection was closed by the server at some point. Discard. + } + } + } + + void returnClientToAvailable(kj::Own client) { + // Only return the connection to the pool if it is reusable and if our settings indicate we + // should reuse connections. + if (client->canReuse() && settings.idleTimeout > 0 * kj::SECONDS) { + availableClients.push_back(AvailableClient { + kj::mv(client), timer.now() + settings.idleTimeout + }); + } + + // Call this either way because it also signals onDrained(). + if (!timeoutsScheduled) { + timeoutsScheduled = true; + timeoutTask = applyTimeouts(); + } + } + + kj::Promise applyTimeouts() { + if (availableClients.empty()) { + timeoutsScheduled = false; + if (activeConnectionCount == 0) { + KJ_IF_MAYBE(f, drainedFulfiller) { + f->get()->fulfill(); + drainedFulfiller = nullptr; + } + } + return kj::READY_NOW; + } else { + auto time = availableClients.front().expires; + return timer.atTime(time).then([this,time]() { + while (!availableClients.empty() && availableClients.front().expires <= time) { + availableClients.pop_front(); + } + return applyTimeouts(); + }); + } + } +}; + +class TransitionaryAsyncIoStream final: public kj::AsyncIoStream { + // This specialised AsyncIoStream is used by NetworkHttpClient to support startTls. +public: + TransitionaryAsyncIoStream(kj::Own unencryptedStream) + : inner(kj::heap(kj::mv(unencryptedStream))) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->tryRead(buffer, minBytes, maxBytes); + } + + kj::Maybe tryGetLength() override { + return inner->tryGetLength(); + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return inner->pumpTo(output, amount); + } + + kj::Promise write(const void* buffer, size_t size) override { + return inner->write(buffer, size); + } + + kj::Promise write(kj::ArrayPtr> pieces) override { + return inner->write(pieces); + } + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { + return inner->tryPumpFrom(input, amount); + } + + kj::Promise whenWriteDisconnected() override { + return inner->whenWriteDisconnected(); + } + + void shutdownWrite() override { + inner->shutdownWrite(); + } + + void abortRead() override { + inner->abortRead(); + } + + kj::Maybe getFd() const override { + return inner->getFd(); + } + + void startTls( + kj::SecureNetworkWrapper* wrapper, kj::StringPtr expectedServerHostname) { + // Pause any potential pending reads. + inner->pause(); + + KJ_ON_SCOPE_FAILURE({ + inner->reject(KJ_EXCEPTION(FAILED, "StartTls failed.")); + }); + + KJ_ASSERT(!inner->getCurrentlyReading() && !inner->getCurrentlyWriting(), + "Cannot call startTls while reads/writes are outstanding"); + kj::Promise> secureStream = + wrapper->wrapClient(inner->takeStream(), expectedServerHostname); + inner->replaceStream(kj::newPromisedStream(kj::mv(secureStream))); + // Resume any previous pending reads. + inner->unpause(); + } + +private: + kj::Own inner; +}; + +class PromiseNetworkAddressHttpClient final: public HttpClient { + // An HttpClient which waits for a promise to resolve then forwards all calls to the promised + // client. + +public: + PromiseNetworkAddressHttpClient(kj::Promise> promise) + : promise(promise.then([this](kj::Own&& client) { + this->client = kj::mv(client); + }).fork()) {} + + bool isDrained() { + KJ_IF_MAYBE(c, client) { + return c->get()->isDrained(); + } else { + return failed; + } + } + + kj::Promise onDrained() { + KJ_IF_MAYBE(c, client) { + return c->get()->onDrained(); + } else { + return promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(client)->onDrained(); + }, [this](kj::Exception&& e) { + // Connecting failed. Treat as immediately drained. + failed = true; + return kj::READY_NOW; + }); + } + } + + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_IF_MAYBE(c, client) { + return c->get()->request(method, url, headers, expectedBodySize); + } else { + // This gets complicated since request() returns a pair of a stream and a promise. + auto urlCopy = kj::str(url); + auto headersCopy = headers.clone(); + auto combined = promise.addBranch().then( + [this,method,expectedBodySize,url=kj::mv(urlCopy), headers=kj::mv(headersCopy)]() + -> kj::Tuple, kj::Promise> { + auto req = KJ_ASSERT_NONNULL(client)->request(method, url, headers, expectedBodySize); + return kj::tuple(kj::mv(req.body), kj::mv(req.response)); + }); + + auto split = combined.split(); + return { + newPromisedStream(kj::mv(kj::get<0>(split))), + kj::mv(kj::get<1>(split)) + }; + } + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + KJ_IF_MAYBE(c, client) { + return c->get()->openWebSocket(url, headers); + } else { + auto urlCopy = kj::str(url); + auto headersCopy = headers.clone(); + return promise.addBranch().then( + [this,url=kj::mv(urlCopy),headers=kj::mv(headersCopy)]() { + return KJ_ASSERT_NONNULL(client)->openWebSocket(url, headers); + }); + } + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + KJ_IF_MAYBE(c, client) { + return c->get()->connect(host, headers, settings); + } else { + auto split = promise.addBranch().then( + [this, host=kj::str(host), headers=headers.clone(), settings]() mutable + -> kj::Tuple, + kj::Promise>> { + auto request = KJ_ASSERT_NONNULL(client)->connect(host, headers, kj::mv(settings)); + return kj::tuple(kj::mv(request.status), kj::mv(request.connection)); + }).split(); + + return ConnectRequest { + kj::mv(kj::get<0>(split)), + kj::newPromisedStream(kj::mv(kj::get<1>(split))) + }; + } + } + +private: + kj::ForkedPromise promise; + kj::Maybe> client; + bool failed = false; +}; + +class NetworkHttpClient final: public HttpClient, private kj::TaskSet::ErrorHandler { +public: + NetworkHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, + kj::Network& network, kj::Maybe tlsNetwork, + HttpClientSettings settings) + : timer(timer), + responseHeaderTable(responseHeaderTable), + network(network), + tlsNetwork(tlsNetwork), + settings(kj::mv(settings)), + tasks(*this) {} + + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + // We need to parse the proxy-style URL to convert it to host-style. + // Use URL parsing options that avoid unnecessary rewrites. + Url::Options urlOptions; + urlOptions.allowEmpty = true; + urlOptions.percentDecode = false; + + auto parsed = Url::parse(url, Url::HTTP_PROXY_REQUEST, urlOptions); + auto path = parsed.toString(Url::HTTP_REQUEST); + auto headersCopy = headers.clone(); + headersCopy.set(HttpHeaderId::HOST, parsed.host); + return getClient(parsed).request(method, path, headersCopy, expectedBodySize); + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + // We need to parse the proxy-style URL to convert it to host-style. + // Use URL parsing options that avoid unnecessary rewrites. + Url::Options urlOptions; + urlOptions.allowEmpty = true; + urlOptions.percentDecode = false; + + auto parsed = Url::parse(url, Url::HTTP_PROXY_REQUEST, urlOptions); + auto path = parsed.toString(Url::HTTP_REQUEST); + auto headersCopy = headers.clone(); + headersCopy.set(HttpHeaderId::HOST, parsed.host); + return getClient(parsed).openWebSocket(path, headersCopy); + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, + HttpConnectSettings connectSettings) override { + // We want to connect directly instead of going through a proxy here. + // https://github.com/capnproto/capnproto/pull/1454#discussion_r900414879 + kj::Maybe>> addr; + if (connectSettings.useTls) { + kj::Network& tlsNet = KJ_REQUIRE_NONNULL(tlsNetwork, "this HttpClient doesn't support TLS"); + addr = tlsNet.parseAddress(host); + } else { + addr = network.parseAddress(host); + } + + auto split = KJ_ASSERT_NONNULL(addr).then([this](auto address) { + return address->connect().then([this](auto connection) + -> kj::Tuple, + kj::Promise>> { + return kj::tuple( + ConnectRequest::Status( + 200, + kj::str("OK"), + kj::heap(responseHeaderTable) // Empty headers + ), + kj::mv(connection)); + }).attach(kj::mv(address)); + }).split(); + + auto connection = kj::newPromisedStream(kj::mv(kj::get<1>(split))); + + if (!connectSettings.useTls) { + KJ_IF_MAYBE(wrapper, settings.tlsContext) { + KJ_IF_MAYBE(tlsStarter, connectSettings.tlsStarter) { + auto transitConnectionRef = kj::refcountedWrapper( + kj::heap(kj::mv(connection))); + Function(kj::StringPtr)> cb = + [wrapper, ref1 = transitConnectionRef->addWrappedRef()]( + kj::StringPtr expectedServerHostname) mutable { + ref1->startTls(wrapper, expectedServerHostname); + return kj::READY_NOW; + }; + connection = transitConnectionRef->addWrappedRef(); + *tlsStarter = kj::mv(cb); + } + } + } + + return ConnectRequest { + kj::mv(kj::get<0>(split)), + kj::mv(connection) + }; + } + +private: + kj::Timer& timer; + const HttpHeaderTable& responseHeaderTable; + kj::Network& network; + kj::Maybe tlsNetwork; + HttpClientSettings settings; + + struct Host { + kj::String name; // including port, if non-default + kj::Own client; + }; + + std::map httpHosts; + std::map httpsHosts; + + struct RequestInfo { + HttpMethod method; + kj::String hostname; + kj::String path; + HttpHeaders headers; + kj::Maybe expectedBodySize; + }; + + kj::TaskSet tasks; + + HttpClient& getClient(kj::Url& parsed) { + bool isHttps = parsed.scheme == "https"; + bool isHttp = parsed.scheme == "http"; + KJ_REQUIRE(isHttp || isHttps); + + auto& hosts = isHttps ? httpsHosts : httpHosts; + + // Look for a cached client for this host. + // TODO(perf): It would be nice to recognize when different hosts have the same address and + // reuse the same connection pool, but: + // - We'd need a reliable way to compare NetworkAddresses, e.g. .equals() and .hashCode(). + // It's very Java... ick. + // - Correctly handling TLS would be tricky: we'd need to verify that the new hostname is + // on the certificate. When SNI is in use we might have to request an additional + // certificate (is that possible?). + auto iter = hosts.find(parsed.host); + + if (iter == hosts.end()) { + // Need to open a new connection. + kj::Network* networkToUse = &network; + if (isHttps) { + networkToUse = &KJ_REQUIRE_NONNULL(tlsNetwork, "this HttpClient doesn't support HTTPS"); + } + + auto promise = networkToUse->parseAddress(parsed.host, isHttps ? 443 : 80) + .then([this](kj::Own addr) { + return kj::heap( + timer, responseHeaderTable, kj::mv(addr), settings); + }); + + Host host { + kj::mv(parsed.host), + kj::heap(kj::mv(promise)) + }; + kj::StringPtr nameRef = host.name; + + auto insertResult = hosts.insert(std::make_pair(nameRef, kj::mv(host))); + KJ_ASSERT(insertResult.second); + iter = insertResult.first; + + tasks.add(handleCleanup(hosts, iter)); + } + + return *iter->second.client; + } + + kj::Promise handleCleanup(std::map& hosts, + std::map::iterator iter) { + return iter->second.client->onDrained() + .then([this,&hosts,iter]() -> kj::Promise { + // Double-check that it's really drained to avoid race conditions. + if (iter->second.client->isDrained()) { + hosts.erase(iter); + return kj::READY_NOW; + } else { + return handleCleanup(hosts, iter); + } + }); + } + + void taskFailed(kj::Exception&& exception) override { + KJ_LOG(ERROR, exception); + } +}; + +} // namespace + +kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, + kj::NetworkAddress& addr, HttpClientSettings settings) { + return kj::heap(timer, responseHeaderTable, + kj::Own(&addr, kj::NullDisposer::instance), kj::mv(settings)); +} + +kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, + kj::Network& network, kj::Maybe tlsNetwork, + HttpClientSettings settings) { + return kj::heap( + timer, responseHeaderTable, network, tlsNetwork, kj::mv(settings)); +} + +// ======================================================================================= + +namespace { + +class ConcurrencyLimitingHttpClient final: public HttpClient { +public: + KJ_DISALLOW_COPY_AND_MOVE(ConcurrencyLimitingHttpClient); + ConcurrencyLimitingHttpClient( + kj::HttpClient& inner, uint maxConcurrentRequests, + kj::Function countChangedCallback) + : inner(inner), + maxConcurrentRequests(maxConcurrentRequests), + countChangedCallback(kj::mv(countChangedCallback)) {} + + ~ConcurrencyLimitingHttpClient() noexcept(false) { + if (concurrentRequests > 0) { + static bool logOnce KJ_UNUSED = ([&] { + KJ_LOG(ERROR, "ConcurrencyLimitingHttpClient getting destroyed when concurrent requests " + "are still active", concurrentRequests); + return true; + })(); + } + } + + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + if (concurrentRequests < maxConcurrentRequests) { + auto counter = ConnectionCounter(*this); + auto request = inner.request(method, url, headers, expectedBodySize); + fireCountChanged(); + auto promise = attachCounter(kj::mv(request.response), kj::mv(counter)); + return { kj::mv(request.body), kj::mv(promise) }; + } + + auto paf = kj::newPromiseAndFulfiller(); + auto urlCopy = kj::str(url); + auto headersCopy = headers.clone(); + + auto combined = paf.promise + .then([this, + method, + urlCopy = kj::mv(urlCopy), + headersCopy = kj::mv(headersCopy), + expectedBodySize](ConnectionCounter&& counter) mutable { + auto req = inner.request(method, urlCopy, headersCopy, expectedBodySize); + return kj::tuple(kj::mv(req.body), attachCounter(kj::mv(req.response), kj::mv(counter))); + }); + auto split = combined.split(); + pendingRequests.push(kj::mv(paf.fulfiller)); + fireCountChanged(); + return { newPromisedStream(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) }; + } + + kj::Promise openWebSocket( + kj::StringPtr url, const kj::HttpHeaders& headers) override { + if (concurrentRequests < maxConcurrentRequests) { + auto counter = ConnectionCounter(*this); + auto response = inner.openWebSocket(url, headers); + fireCountChanged(); + return attachCounter(kj::mv(response), kj::mv(counter)); + } + + auto paf = kj::newPromiseAndFulfiller(); + auto urlCopy = kj::str(url); + auto headersCopy = headers.clone(); + + auto promise = paf.promise + .then([this, + urlCopy = kj::mv(urlCopy), + headersCopy = kj::mv(headersCopy)](ConnectionCounter&& counter) mutable { + return attachCounter(inner.openWebSocket(urlCopy, headersCopy), kj::mv(counter)); + }); + + pendingRequests.push(kj::mv(paf.fulfiller)); + fireCountChanged(); + return kj::mv(promise); + } + + ConnectRequest connect( + kj::StringPtr host, const kj::HttpHeaders& headers, HttpConnectSettings settings) override { + if (concurrentRequests < maxConcurrentRequests) { + auto counter = ConnectionCounter(*this); + auto response = inner.connect(host, headers, settings); + fireCountChanged(); + return attachCounter(kj::mv(response), kj::mv(counter)); + } + + auto paf = kj::newPromiseAndFulfiller(); + + auto split = paf.promise + .then([this, host=kj::str(host), headers=headers.clone(), settings] + (ConnectionCounter&& counter) mutable + -> kj::Tuple, + kj::Promise>> { + auto request = attachCounter(inner.connect(host, headers, settings), kj::mv(counter)); + return kj::tuple(kj::mv(request.status), kj::mv(request.connection)); + }).split(); + + pendingRequests.push(kj::mv(paf.fulfiller)); + fireCountChanged(); + + return ConnectRequest { + kj::mv(kj::get<0>(split)), + kj::newPromisedStream(kj::mv(kj::get<1>(split))) + }; + } + +private: + struct ConnectionCounter; + + kj::HttpClient& inner; + uint maxConcurrentRequests; + uint concurrentRequests = 0; + kj::Function countChangedCallback; + + std::queue>> pendingRequests; + // TODO(someday): want maximum cap on queue size? + + struct ConnectionCounter final { + ConnectionCounter(ConcurrencyLimitingHttpClient& client) : parent(&client) { + ++parent->concurrentRequests; + } + KJ_DISALLOW_COPY(ConnectionCounter); + ~ConnectionCounter() noexcept(false) { + if (parent != nullptr) { + --parent->concurrentRequests; + parent->serviceQueue(); + parent->fireCountChanged(); + } + } + ConnectionCounter(ConnectionCounter&& other) : parent(other.parent) { + other.parent = nullptr; + } + ConnectionCounter& operator=(ConnectionCounter&& other) { + if (this != &other) { + this->parent = other.parent; + other.parent = nullptr; + } + return *this; + } + + ConcurrencyLimitingHttpClient* parent; + }; + + void serviceQueue() { + while (concurrentRequests < maxConcurrentRequests && !pendingRequests.empty()) { + auto fulfiller = kj::mv(pendingRequests.front()); + pendingRequests.pop(); + // ConnectionCounter's destructor calls this function, so we can avoid unnecessary recursion + // if we only create a ConnectionCounter when we find a waiting fulfiller. + if (fulfiller->isWaiting()) { + fulfiller->fulfill(ConnectionCounter(*this)); + } + } + } + + void fireCountChanged() { + countChangedCallback(concurrentRequests, pendingRequests.size()); + } + + using WebSocketOrBody = kj::OneOf, kj::Own>; + static WebSocketOrBody attachCounter(WebSocketOrBody&& webSocketOrBody, + ConnectionCounter&& counter) { + KJ_SWITCH_ONEOF(webSocketOrBody) { + KJ_CASE_ONEOF(ws, kj::Own) { + return ws.attach(kj::mv(counter)); + } + KJ_CASE_ONEOF(body, kj::Own) { + return body.attach(kj::mv(counter)); + } + } + KJ_UNREACHABLE; + } + + static kj::Promise attachCounter(kj::Promise&& promise, + ConnectionCounter&& counter) { + return promise.then([counter = kj::mv(counter)](WebSocketResponse&& response) mutable { + return WebSocketResponse { + response.statusCode, + response.statusText, + response.headers, + attachCounter(kj::mv(response.webSocketOrBody), kj::mv(counter)) + }; + }); + } + + static kj::Promise attachCounter(kj::Promise&& promise, + ConnectionCounter&& counter) { + return promise.then([counter = kj::mv(counter)](Response&& response) mutable { + return Response { + response.statusCode, + response.statusText, + response.headers, + response.body.attach(kj::mv(counter)) + }; + }); + } + + static ConnectRequest attachCounter( + ConnectRequest&& request, + ConnectionCounter&& counter) { + // Notice here that we are only attaching the counter to the connection stream. In the case + // where the connect tunnel request is rejected and the status promise resolves with an + // errorBody, there is a possibility that the consuming code might drop the connection stream + // and the counter while the error body stream is still be consumed. Technically speaking that + // means we could potentially exceed our concurrency limit temporarily but we consider that + // acceptable here since the error body is an exception path (plus not requiring that we + // attach to the errorBody keeps ConnectionCounter from having to be a refcounted heap + // allocation). + request.connection = request.connection.attach(kj::mv(counter)); + return kj::mv(request); + } +}; + +} + +kj::Own newConcurrencyLimitingHttpClient( + HttpClient& inner, uint maxConcurrentRequests, + kj::Function countChangedCallback) { + return kj::heap(inner, maxConcurrentRequests, + kj::mv(countChangedCallback)); +} + +// ======================================================================================= + +namespace { + +class HttpClientAdapter final: public HttpClient { +public: + HttpClientAdapter(HttpService& service): service(service) {} + + Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + // We have to clone the URL and headers because HttpService implementation are allowed to + // assume that they remain valid until the service handler completes whereas HttpClient callers + // are allowed to destroy them immediately after the call. + auto urlCopy = kj::str(url); + auto headersCopy = kj::heap(headers.clone()); + + auto pipe = newOneWayPipe(expectedBodySize); + + // TODO(cleanup): The ownership relationships here are a mess. Can we do something better + // involving a PromiseAdapter, maybe? + auto paf = kj::newPromiseAndFulfiller(); + auto responder = kj::refcounted(method, kj::mv(paf.fulfiller)); + + auto requestPaf = kj::newPromiseAndFulfiller>(); + responder->setPromise(kj::mv(requestPaf.promise)); + + auto promise = service.request(method, urlCopy, *headersCopy, *pipe.in, *responder) + .attach(kj::mv(pipe.in), kj::mv(urlCopy), kj::mv(headersCopy)); + requestPaf.fulfiller->fulfill(kj::mv(promise)); + + return { + kj::mv(pipe.out), + paf.promise.attach(kj::mv(responder)) + }; + } + + kj::Promise openWebSocket( + kj::StringPtr url, const HttpHeaders& headers) override { + // We have to clone the URL and headers because HttpService implementation are allowed to + // assume that they remain valid until the service handler completes whereas HttpClient callers + // are allowed to destroy them immediately after the call. Also we need to add + // `Upgrade: websocket` so that headers.isWebSocket() returns true on the service side. + auto urlCopy = kj::str(url); + auto headersCopy = kj::heap(headers.clone()); + headersCopy->set(HttpHeaderId::UPGRADE, "websocket"); + KJ_DASSERT(headersCopy->isWebSocket()); + + auto paf = kj::newPromiseAndFulfiller(); + auto responder = kj::refcounted(kj::mv(paf.fulfiller)); + + auto requestPaf = kj::newPromiseAndFulfiller>(); + responder->setPromise(kj::mv(requestPaf.promise)); + + auto in = kj::heap(); + auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder) + .attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy)); + requestPaf.fulfiller->fulfill(kj::mv(promise)); + + return paf.promise.attach(kj::mv(responder)); + } + + ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings) override { + // We have to clone the host and the headers because HttpServer implementation are allowed to + // assusme that they remain valid until the service handler completes whereas HttpClient callers + // are allowed to destroy them immediately after the call. + auto hostCopy = kj::str(host); + auto headersCopy = kj::heap(headers.clone()); + + // 1. Create a new TwoWayPipe, one will be returned with the ConnectRequest, + // the other will be held by the ConnectResponseImpl. + auto pipe = kj::newTwoWayPipe(); + + // 2. Create a promise/fulfiller pair for the status. The promise will be + // returned with the ConnectResponse, the fulfiller will be held by the + // ConnectResponseImpl. + auto paf = kj::newPromiseAndFulfiller(); + + // 3. Create the ConnectResponseImpl + auto response = kj::refcounted(kj::mv(paf.fulfiller), + kj::mv(pipe.ends[0])); + + // 5. Call service.connect, passing in the tunnel. + // The call to tunnel->getConnectStream() returns a guarded stream that will buffer + // writes until the status is indicated by calling accept/reject. + auto connectStream = response->getConnectStream(); + auto promise = service.connect(hostCopy, *headersCopy, *connectStream, *response, settings) + .eagerlyEvaluate([response=kj::mv(response), + host=kj::mv(hostCopy), + headers=kj::mv(headersCopy), + connectStream=kj::mv(connectStream)](kj::Exception&& ex) mutable { + // A few things need to happen here. + // 1. We'll log the exception. + // 2. We'll break the pipe. + // 3. We'll reject the status promise if it is still pending. + // + // We'll do all of this within the ConnectResponseImpl, however, since it + // maintains the state necessary here. + response->handleException(kj::mv(ex), kj::mv(connectStream)); + }); + + // TODO(bug): There's a challenge with attaching the service.connect promise to the + // connection stream below in that the client will likely drop the connection as soon + // as it reads EOF, but the promise representing the service connect() call may still + // be running and want to do some cleanup after it has sent EOF. That cleanup will be + // canceled. For regular HTTP calls, DelayedEofInputStream was created to address this + // exact issue but with connect() being bidirectional it's rather more difficult. We + // want a delay similar to what DelayedEofInputStream adds but only when both directions + // have been closed. That currently is not possible until we have an alternative to + // shutdownWrite() that returns a Promise (e.g. Promise end()). For now, we can + // live with the current limitation. + return ConnectRequest { + kj::mv(paf.promise), + pipe.ends[1].attach(kj::mv(promise)), + }; + } + +private: + HttpService& service; + + class DelayedEofInputStream final: public kj::AsyncInputStream { + // An AsyncInputStream wrapper that, when it reaches EOF, delays the final read until some + // promise completes. + + public: + DelayedEofInputStream(kj::Own inner, kj::Promise completionTask) + : inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return wrap(minBytes, inner->tryRead(buffer, minBytes, maxBytes)); + } + + kj::Maybe tryGetLength() override { + return inner->tryGetLength(); + } + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return wrap(amount, inner->pumpTo(output, amount)); + } + + private: + kj::Own inner; + kj::Maybe> completionTask; + + template + kj::Promise wrap(T requested, kj::Promise innerPromise) { + return innerPromise.then([this,requested](T actual) -> kj::Promise { + if (actual < requested) { + // Must have reached EOF. + KJ_IF_MAYBE(t, completionTask) { + // Delay until completion. + auto result = t->then([actual]() { return actual; }); + completionTask = nullptr; + return result; + } else { + // Must have called tryRead() again after we already signaled EOF. Fine. + return actual; + } + } else { + return actual; + } + }, [this](kj::Exception&& e) -> kj::Promise { + // The stream threw an exception, but this exception is almost certainly just complaining + // that the other end of the stream was dropped. In all likelihood, the HttpService + // request() call itself will throw a much more interesting error -- we'd rather propagate + // that one, if so. + KJ_IF_MAYBE(t, completionTask) { + auto result = t->then([e = kj::mv(e)]() mutable -> kj::Promise { + // Looks like the service didn't throw. I guess we should propagate the stream error + // after all. + return kj::mv(e); + }); + completionTask = nullptr; + return result; + } else { + // Must have called tryRead() again after we already signaled EOF or threw. Fine. + return kj::mv(e); + } + }); + } + }; + + class ResponseImpl final: public HttpService::Response, public kj::Refcounted { + public: + ResponseImpl(kj::HttpMethod method, + kj::Own> fulfiller) + : method(method), fulfiller(kj::mv(fulfiller)) {} + + void setPromise(kj::Promise promise) { + task = promise.eagerlyEvaluate([this](kj::Exception&& exception) { + if (fulfiller->isWaiting()) { + fulfiller->reject(kj::mv(exception)); + } else { + // We need to cause the response stream's read() to throw this, so we should propagate it. + kj::throwRecoverableException(kj::mv(exception)); + } + }); + } + + kj::Own send( + uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + // The caller of HttpClient is allowed to assume that the statusText and headers remain + // valid until the body stream is dropped, but the HttpService implementation is allowed to + // send values that are only valid until send() returns, so we have to copy. + auto statusTextCopy = kj::str(statusText); + auto headersCopy = kj::heap(headers.clone()); + + if (method == kj::HttpMethod::HEAD || expectedBodySize.orDefault(1) == 0) { + // We're not expecting any body. We need to delay reporting completion to the client until + // the server side has actually returned from the service method, otherwise we may + // prematurely cancel it. + + task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy), + headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable { + fulfiller->fulfill({ + statusCode, statusTextCopy, headersCopy.get(), + kj::heap(expectedBodySize) + .attach(kj::mv(statusTextCopy), kj::mv(headersCopy)) + }); + }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + return kj::heap(); + } else { + auto pipe = newOneWayPipe(expectedBodySize); + + // Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until + // the service's request promise has finished. + auto wrapper = kj::heap( + kj::mv(pipe.in), task.attach(kj::addRef(*this))); + + fulfiller->fulfill({ + statusCode, statusTextCopy, headersCopy.get(), + wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy)) + }); + return kj::mv(pipe.out); + } + } + + kj::Own acceptWebSocket(const HttpHeaders& headers) override { + KJ_FAIL_REQUIRE("a WebSocket was not requested"); + } + + private: + kj::HttpMethod method; + kj::Own> fulfiller; + kj::Promise task = nullptr; + }; + + class DelayedCloseWebSocket final: public WebSocket { + // A WebSocket wrapper that, when it reaches Close (in both directions), delays the final close + // operation until some promise completes. + + public: + DelayedCloseWebSocket(kj::Own inner, kj::Promise completionTask) + : inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {} + + kj::Promise send(kj::ArrayPtr message) override { + return inner->send(message); + } + kj::Promise send(kj::ArrayPtr message) override { + return inner->send(message); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + return inner->close(code, reason) + .then([this]() { + return afterSendClosed(); + }); + } + kj::Promise disconnect() override { + return inner->disconnect(); + } + void abort() override { + // Don't need to worry about completion task in this case -- cancelling it is reasonable. + inner->abort(); + } + kj::Promise whenAborted() override { + return inner->whenAborted(); + } + kj::Promise receive(size_t maxSize) override { + return inner->receive(maxSize).then([this](Message&& message) -> kj::Promise { + if (message.is()) { + return afterReceiveClosed() + .then([message = kj::mv(message)]() mutable { return kj::mv(message); }); + } + return kj::mv(message); + }); + } + kj::Promise pumpTo(WebSocket& other) override { + return inner->pumpTo(other).then([this]() { + return afterReceiveClosed(); + }); + } + kj::Maybe> tryPumpFrom(WebSocket& other) override { + return other.pumpTo(*inner).then([this]() { + return afterSendClosed(); + }); + } + + uint64_t sentByteCount() override { return inner->sentByteCount(); } + uint64_t receivedByteCount() override { return inner->receivedByteCount(); } + + private: + kj::Own inner; + kj::Maybe> completionTask; + + bool sentClose = false; + bool receivedClose = false; + + kj::Promise afterSendClosed() { + sentClose = true; + if (receivedClose) { + KJ_IF_MAYBE(t, completionTask) { + auto result = kj::mv(*t); + completionTask = nullptr; + return result; + } + } + return kj::READY_NOW; + } + + kj::Promise afterReceiveClosed() { + receivedClose = true; + if (sentClose) { + KJ_IF_MAYBE(t, completionTask) { + auto result = kj::mv(*t); + completionTask = nullptr; + return result; + } + } + return kj::READY_NOW; + } + }; + + class WebSocketResponseImpl final: public HttpService::Response, public kj::Refcounted { + public: + WebSocketResponseImpl(kj::Own> fulfiller) + : fulfiller(kj::mv(fulfiller)) {} + + void setPromise(kj::Promise promise) { + task = promise.eagerlyEvaluate([this](kj::Exception&& exception) { + if (fulfiller->isWaiting()) { + fulfiller->reject(kj::mv(exception)); + } else { + // We need to cause the client-side WebSocket to throw on close, so propagate the + // exception. + kj::throwRecoverableException(kj::mv(exception)); + } + }); + } + + kj::Own send( + uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + // The caller of HttpClient is allowed to assume that the statusText and headers remain + // valid until the body stream is dropped, but the HttpService implementation is allowed to + // send values that are only valid until send() returns, so we have to copy. + auto statusTextCopy = kj::str(statusText); + auto headersCopy = kj::heap(headers.clone()); + + if (expectedBodySize.orDefault(1) == 0) { + // We're not expecting any body. We need to delay reporting completion to the client until + // the server side has actually returned from the service method, otherwise we may + // prematurely cancel it. + + task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy), + headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable { + fulfiller->fulfill({ + statusCode, statusTextCopy, headersCopy.get(), + kj::Own(kj::heap(expectedBodySize) + .attach(kj::mv(statusTextCopy), kj::mv(headersCopy))) + }); + }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + return kj::heap(); + } else { + auto pipe = newOneWayPipe(expectedBodySize); + + // Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until + // the service's request promise has finished. + kj::Own wrapper = + kj::heap(kj::mv(pipe.in), task.attach(kj::addRef(*this))); + + fulfiller->fulfill({ + statusCode, statusTextCopy, headersCopy.get(), + wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy)) + }); + return kj::mv(pipe.out); + } + } + + kj::Own acceptWebSocket(const HttpHeaders& headers) override { + // The caller of HttpClient is allowed to assume that the headers remain valid until the body + // stream is dropped, but the HttpService implementation is allowed to send headers that are + // only valid until acceptWebSocket() returns, so we have to copy. + auto headersCopy = kj::heap(headers.clone()); + + auto pipe = newWebSocketPipe(); + + // Wrap the client-side WebSocket in a wrapper that delays clean close of the WebSocket until + // the service's request promise has finished. + kj::Own wrapper = + kj::heap(kj::mv(pipe.ends[0]), task.attach(kj::addRef(*this))); + fulfiller->fulfill({ + 101, "Switching Protocols", headersCopy.get(), + wrapper.attach(kj::mv(headersCopy)) + }); + return kj::mv(pipe.ends[1]); + } + + private: + kj::Own> fulfiller; + kj::Promise task = nullptr; + }; + + class ConnectResponseImpl final: public HttpService::ConnectResponse, public kj::Refcounted { + public: + ConnectResponseImpl( + kj::Own> fulfiller, + kj::Own stream) + : fulfiller(kj::mv(fulfiller)), + streamAndFulfiller(initStreamsAndFulfiller(kj::mv(stream))) {} + + ~ConnectResponseImpl() noexcept(false) { + if (fulfiller->isWaiting() || streamAndFulfiller.fulfiller->isWaiting()) { + auto ex = KJ_EXCEPTION(FAILED, + "service's connect() implementation never called accept() nor reject()"); + if (fulfiller->isWaiting()) { + fulfiller->reject(kj::cp(ex)); + } + if (streamAndFulfiller.fulfiller->isWaiting()) { + streamAndFulfiller.fulfiller->reject(kj::mv(ex)); + } + } + } + + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + KJ_REQUIRE(statusCode >= 200 && statusCode < 300, "the statusCode must be 2xx for accept"); + respond(statusCode, statusText, headers); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) override { + KJ_REQUIRE(statusCode < 200 || statusCode >= 300, + "the statusCode must not be 2xx for reject."); + auto pipe = kj::newOneWayPipe(); + respond(statusCode, statusText, headers, kj::mv(pipe.in)); + return kj::mv(pipe.out); + } + + private: + struct StreamsAndFulfiller { + // guarded is the wrapped/guarded stream that wraps a reference to + // the underlying stream but blocks reads until the connection is accepted + // or rejected. + // This will be handed off when getConnectStream() is called. + // The fulfiller is used to resolve the guard for the second stream. This will + // be fulfilled or rejected when accept/reject is called. + kj::Own guarded; + kj::Own> fulfiller; + }; + + kj::Own> fulfiller; + StreamsAndFulfiller streamAndFulfiller; + bool connectStreamDetached = false; + + StreamsAndFulfiller initStreamsAndFulfiller(kj::Own stream) { + auto paf = kj::newPromiseAndFulfiller(); + auto guarded = kj::heap( + kj::mv(stream), + kj::Maybe(nullptr), + kj::mv(paf.promise)); + return StreamsAndFulfiller { + kj::mv(guarded), + kj::mv(paf.fulfiller) + }; + } + + void handleException(kj::Exception&& ex, kj::Own connectStream) { + // Log the exception... + KJ_LOG(ERROR, "Error in HttpClientAdapter connect()", kj::cp(ex)); + // Reject the status promise if it is still pending... + if (fulfiller->isWaiting()) { + fulfiller->reject(kj::cp(ex)); + } + if (streamAndFulfiller.fulfiller->isWaiting()) { + // If the guard hasn't yet ben released, we can fail the pending reads by + // rejecting the fulfiller here. + streamAndFulfiller.fulfiller->reject(kj::mv(ex)); + } else { + // The guard has already been released at this point. + // TODO(connect): How to properly propagate the actual exception to the + // connect stream? Here we "simply" shut it down. + connectStream->abortRead(); + connectStream->shutdownWrite(); + } + } + + kj::Own getConnectStream() { + KJ_ASSERT(!connectStreamDetached, "the connect stream was already detached"); + connectStreamDetached = true; + return streamAndFulfiller.guarded.attach(kj::addRef(*this)); + } + + void respond(uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe> errorBody = nullptr) { + if (errorBody == nullptr) { + streamAndFulfiller.fulfiller->fulfill(); + } else { + streamAndFulfiller.fulfiller->reject( + KJ_EXCEPTION(DISCONNECTED, "the connect request was rejected")); + } + fulfiller->fulfill(HttpClient::ConnectRequest::Status( + statusCode, + kj::str(statusText), + kj::heap(headers.clone()), + kj::mv(errorBody))); + } + + friend class HttpClientAdapter; + }; + +}; + +} // namespace + +kj::Own newHttpClient(HttpService& service) { + return kj::heap(service); +} + +// ======================================================================================= + +namespace { + +class HttpServiceAdapter final: public HttpService { +public: + HttpServiceAdapter(HttpClient& client): client(client) {} + + kj::Promise request( + HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, + kj::AsyncInputStream& requestBody, Response& response) override { + if (!headers.isWebSocket()) { + auto innerReq = client.request(method, url, headers, requestBody.tryGetLength()); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult() + .attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr)); + + promises.add(innerReq.response + .then([&response](HttpClient::Response&& innerResponse) { + auto out = response.send( + innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers, + innerResponse.body->tryGetLength()); + auto promise = innerResponse.body->pumpTo(*out); + return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body)); + })); + + return kj::joinPromisesFailFast(promises.finish()); + } else { + return client.openWebSocket(url, headers) + .then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise { + KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) { + KJ_CASE_ONEOF(ws, kj::Own) { + auto ws2 = response.acceptWebSocket(*innerResponse.headers); + auto promises = kj::heapArrayBuilder>(2); + promises.add(ws->pumpTo(*ws2)); + promises.add(ws2->pumpTo(*ws)); + return kj::joinPromisesFailFast(promises.finish()).attach(kj::mv(ws), kj::mv(ws2)); + } + KJ_CASE_ONEOF(body, kj::Own) { + auto out = response.send( + innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers, + body->tryGetLength()); + auto promise = body->pumpTo(*out); + return promise.ignoreResult().attach(kj::mv(out), kj::mv(body)); + } + } + KJ_UNREACHABLE; + }); + } + } + + kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + HttpConnectSettings settings) override { + KJ_REQUIRE(!headers.isWebSocket(), "WebSocket upgrade headers are not permitted in a connect."); + + auto request = client.connect(host, headers, settings); + + // This operates optimistically. In order to support pipelining, we connect the + // input and outputs streams immediately, even if we're not yet certain that the + // tunnel can actually be established. + auto promises = kj::heapArrayBuilder>(2); + + // For the inbound pipe (from the clients stream to the passed in stream) + // We want to guard reads pending the acceptance of the tunnel. If the + // tunnel is not accepted, the guard will be rejected, causing pending + // reads to fail. + auto paf = kj::newPromiseAndFulfiller>(); + auto io = kj::heap( + kj::mv(request.connection), + kj::mv(paf.promise) /* read guard */, + kj::READY_NOW /* write guard */); + + // Writing from connection to io is unguarded and allowed immediately. + promises.add(connection.pumpTo(*io).then([&io=*io](uint64_t size) { + io.shutdownWrite(); + })); + + promises.add(io->pumpTo(connection).then([&connection](uint64_t size) { + connection.shutdownWrite(); + })); + + auto pumpPromise = kj::joinPromisesFailFast(promises.finish()); + + return request.status.then( + [&response,&connection,fulfiller=kj::mv(paf.fulfiller), + pumpPromise=kj::mv(pumpPromise)] + (HttpClient::ConnectRequest::Status status) mutable -> kj::Promise { + if (status.statusCode >= 200 && status.statusCode < 300) { + // Release the read guard! + fulfiller->fulfill(kj::Maybe(nullptr)); + response.accept(status.statusCode, status.statusText, *status.headers); + return kj::mv(pumpPromise); + } else { + // If the connect request is rejected, we want to shutdown the tunnel + // and pipeline the status.errorBody to the AsyncOutputStream returned by + // reject if it exists. + pumpPromise = nullptr; + connection.shutdownWrite(); + fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "the connect request was rejected")); + KJ_IF_MAYBE(errorBody, status.errorBody) { + auto out = response.reject(status.statusCode, status.statusText, *status.headers, + errorBody->get()->tryGetLength()); + return (*errorBody)->pumpTo(*out).then([](uint64_t) -> kj::Promise { + return kj::READY_NOW; + }).attach(kj::mv(out), kj::mv(*errorBody)); + } else { + response.reject(status.statusCode, status.statusText, *status.headers, (uint64_t)0); + return kj::READY_NOW; + } + } + }).attach(kj::mv(io)); + } + +private: + HttpClient& client; +}; + +} // namespace + +kj::Own newHttpService(HttpClient& client) { + return kj::heap(client); +} + +// ======================================================================================= + +kj::Promise HttpService::Response::sendError( + uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) { + auto stream = send(statusCode, statusText, headers, statusText.size()); + auto promise = stream->write(statusText.begin(), statusText.size()); + return promise.attach(kj::mv(stream)); +} + +kj::Promise HttpService::Response::sendError( + uint statusCode, kj::StringPtr statusText, const HttpHeaderTable& headerTable) { + return sendError(statusCode, statusText, HttpHeaders(headerTable)); +} + +kj::Promise HttpService::connect( + kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + kj::HttpConnectSettings settings) { + KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService"); +} + +class HttpServer::Connection final: private HttpService::Response, + private HttpService::ConnectResponse, + private HttpServerErrorHandler { +public: + Connection(HttpServer& server, kj::AsyncIoStream& stream, + SuspendableHttpServiceFactory factory, kj::Maybe suspendedRequest, + bool wantCleanDrain) + : server(server), + stream(stream), + factory(kj::mv(factory)), + httpInput(makeHttpInput(stream, server.requestHeaderTable, kj::mv(suspendedRequest))), + httpOutput(stream), + wantCleanDrain(wantCleanDrain) { + ++server.connectionCount; + } + ~Connection() noexcept(false) { + if (--server.connectionCount == 0) { + KJ_IF_MAYBE(f, server.zeroConnectionsFulfiller) { + f->get()->fulfill(); + } + } + } + +public: + kj::Promise startLoop(bool firstRequest) { + return loop(firstRequest).catch_([this](kj::Exception&& e) -> kj::Promise { + // Exception; report 5xx. + + KJ_IF_MAYBE(p, webSocketError) { + // sendWebSocketError() was called. Finish sending and close the connection. Don't log + // the exception because it's probably a side-effect of this. + auto promise = kj::mv(*p); + webSocketError = nullptr; + return kj::mv(promise); + } + + KJ_IF_MAYBE(p, tunnelRejected) { + // reject() was called to reject a CONNECT request. Finish sending and close the connection. + // Don't log the exception because it's probably a side-effect of this. + auto promise = kj::mv(*p); + tunnelRejected = nullptr; + return kj::mv(promise); + } + + return sendError(kj::mv(e)); + }); + } + + SuspendedRequest suspend(SuspendableRequest& suspendable) { + KJ_REQUIRE(httpInput.canSuspend(), + "suspend() may only be called before the request body is consumed"); + KJ_DEFER(suspended = true); + auto released = httpInput.releaseBuffer(); + return { + kj::mv(released.buffer), + released.leftover, + suspendable.method, + suspendable.url, + suspendable.headers.cloneShallow(), + }; + } private: HttpServer& server; - HttpInputStream httpInput; + kj::AsyncIoStream& stream; + + SuspendableHttpServiceFactory factory; + // Creates a new kj::Own for each request we handle on this connection. + + HttpInputStreamImpl httpInput; HttpOutputStream httpOutput; - kj::Own ownStream; - kj::Maybe currentMethod; + kj::Maybe> currentMethod; bool timedOut = false; + bool closed = false; + bool upgraded = false; + bool webSocketOrConnectClosed = false; + bool closeAfterSend = false; // True if send() should set Connection: close. + bool wantCleanDrain = false; + bool suspended = false; + kj::Maybe> webSocketError; + kj::Maybe> tunnelRejected; + kj::Maybe>> tunnelWriteGuard; + + static HttpInputStreamImpl makeHttpInput( + kj::AsyncIoStream& stream, + const kj::HttpHeaderTable& table, + kj::Maybe suspendedRequest) { + // Constructor helper function to create our HttpInputStreamImpl. + + KJ_IF_MAYBE(sr, suspendedRequest) { + return HttpInputStreamImpl(stream, + sr->buffer.releaseAsChars(), + sr->leftover.asChars(), + sr->method, + sr->url, + kj::mv(sr->headers)); + } + return HttpInputStreamImpl(stream, table); + } + + kj::Promise loop(bool firstRequest) { + if (!firstRequest && server.draining && httpInput.isCleanDrain()) { + // Don't call awaitNextMessage() in this case because that will initiate a read() which will + // immediately be canceled, losing data. + return true; + } + + auto firstByte = httpInput.awaitNextMessage(); + + if (!firstRequest) { + // For requests after the first, require that the first byte arrive before the pipeline + // timeout, otherwise treat it like the connection was simply closed. + auto timeoutPromise = server.timer.afterDelay(server.settings.pipelineTimeout); + + if (httpInput.isCleanDrain()) { + // If we haven't buffered any data, then we can safely drain here, so allow the wait to + // be canceled by the onDrain promise. + auto cleanDrainPromise = server.onDrain.addBranch() + .then([this]() -> kj::Promise { + // This is a little tricky... drain() has been called, BUT we could have read some data + // into the buffer in the meantime, and we don't want to lose that. If any data has + // arrived, then we have no choice but to read the rest of the request and respond to + // it. + if (!httpInput.isCleanDrain()) { + return kj::NEVER_DONE; + } + + // OK... As far as we know, no data has arrived in the buffer. However, unfortunately, + // we don't *really* know that, because read() is asynchronous. It may have already + // delivered some bytes, but we just haven't received the notification yet, because it's + // still queued on the event loop. As a horrible hack, we use evalLast(), so that any + // such pending notifications get a chance to be delivered. + // TODO(someday): Does this actually work on Windows, where the notification could also + // be queued on the IOCP? + return kj::evalLast([this]() -> kj::Promise { + if (httpInput.isCleanDrain()) { + return kj::READY_NOW; + } else { + return kj::NEVER_DONE; + } + }); + }); + timeoutPromise = timeoutPromise.exclusiveJoin(kj::mv(cleanDrainPromise)); + } + + firstByte = firstByte.exclusiveJoin(timeoutPromise.then([this]() -> bool { + timedOut = true; + return false; + })); + } + + auto receivedHeaders = firstByte + .then([this,firstRequest](bool hasData) + -> kj::Promise { + if (hasData) { + auto readHeaders = httpInput.readRequestHeaders(); + if (!firstRequest) { + // On requests other than the first, the header timeout starts ticking when we receive + // the first byte of a pipeline response. + readHeaders = readHeaders.exclusiveJoin( + server.timer.afterDelay(server.settings.headerTimeout) + .then([this]() -> HttpHeaders::RequestConnectOrProtocolError { + timedOut = true; + return HttpHeaders::ProtocolError { + 408, "Request Timeout", + "Timed out waiting for next request headers.", nullptr + }; + })); + } + return kj::mv(readHeaders); + } else { + // Client closed connection or pipeline timed out with no bytes received. This is not an + // error, so don't report one. + this->closed = true; + return HttpHeaders::RequestConnectOrProtocolError(HttpHeaders::ProtocolError { + 408, "Request Timeout", + "Client closed connection or connection timeout " + "while waiting for request headers.", nullptr + }); + } + }); + + if (firstRequest) { + // On the first request, the header timeout starts ticking immediately upon request opening. + // NOTE: Since we assume that the client wouldn't have formed a connection if they did not + // intend to send a request, we immediately treat this connection as having an active + // request, i.e. we do NOT cancel it if drain() is called. + auto timeoutPromise = server.timer.afterDelay(server.settings.headerTimeout) + .then([this]() -> HttpHeaders::RequestConnectOrProtocolError { + timedOut = true; + return HttpHeaders::ProtocolError { + 408, "Request Timeout", + "Timed out waiting for initial request headers.", nullptr + }; + }); + receivedHeaders = receivedHeaders.exclusiveJoin(kj::mv(timeoutPromise)); + } + + return receivedHeaders + .then([this](HttpHeaders::RequestConnectOrProtocolError&& requestOrProtocolError) + -> kj::Promise { + if (timedOut) { + // Client took too long to send anything, so we're going to close the connection. In + // theory, we should send back an HTTP 408 error -- it is designed exactly for this + // purpose. Alas, in practice, Google Chrome does not have any special handling for 408 + // errors -- it will assume the error is a response to the next request it tries to send, + // and will happily serve the error to the user. OTOH, if we simply close the connection, + // Chrome does the "right thing", apparently. (Though I'm not sure what happens if a + // request is in-flight when we close... if it's a GET, the browser should retry. But if + // it's a POST, retrying may be dangerous. This is why 408 exists -- it unambiguously + // tells the client that it should retry.) + // + // Also note that if we ever decide to send 408 again, we might want to send some other + // error in the case that the server is draining, which also sets timedOut = true; see + // above. + + return httpOutput.flush().then([this]() { + return server.draining && httpInput.isCleanDrain(); + }); + } + + if (closed) { + // Client closed connection. Close our end too. + return httpOutput.flush().then([]() { return false; }); + } + + KJ_SWITCH_ONEOF(requestOrProtocolError) { + KJ_CASE_ONEOF(request, HttpHeaders::ConnectRequest) { + auto& headers = httpInput.getHeaders(); + + currentMethod = HttpConnectMethod(); + + // The HTTP specification says that CONNECT requests have no meaningful payload + // but stops short of saying that CONNECT *cannot* have a payload. Implementations + // can choose to either accept payloads or reject them. We choose to reject it. + // Specifically, if there are Content-Length or Transfer-Encoding headers in the + // request headers, we'll automatically reject the CONNECT request. + // + // The key implication here is that any data that immediately follows the headers + // block of the CONNECT request is considered to be part of the tunnel if it is + // established. + + KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { + return sendError(HttpHeaders::ProtocolError { + 400, + "Bad Request"_kj, + "Bad Request"_kj, + nullptr, + }); + } + KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) { + return sendError(HttpHeaders::ProtocolError { + 400, + "Bad Request"_kj, + "Bad Request"_kj, + nullptr, + }); + } + + SuspendableRequest suspendable(*this, HttpConnectMethod(), request.authority, headers); + auto maybeService = factory(suspendable); + + if (suspended) { + return false; + } + + auto service = KJ_ASSERT_NONNULL(kj::mv(maybeService), + "SuspendableHttpServiceFactory did not suspend, but returned nullptr."); + auto connectStream = getConnectStream(); + auto promise = service->connect( + request.authority, headers, *connectStream, *this, {}) + .attach(kj::mv(service), kj::mv(connectStream)); + return promise.then([this]() mutable -> kj::Promise { + KJ_IF_MAYBE(p, tunnelRejected) { + // reject() was called to reject a CONNECT attempt. + // Finish sending and close the connection. + auto promise = kj::mv(*p); + tunnelRejected = nullptr; + return kj::mv(promise); + } + + if (httpOutput.isBroken()) { + return false; + } + + return httpOutput.flush().then([]() mutable -> kj::Promise { + // There is really no reasonable path to reusing a CONNECT connection. + return false; + }); + }); + } + KJ_CASE_ONEOF(request, HttpHeaders::Request) { + auto& headers = httpInput.getHeaders(); + + currentMethod = request.method; + + SuspendableRequest suspendable(*this, request.method, request.url, headers); + auto maybeService = factory(suspendable); + + if (suspended) { + return false; + } + + auto service = KJ_ASSERT_NONNULL(kj::mv(maybeService), + "SuspendableHttpServiceFactory did not suspend, but returned nullptr."); + + // TODO(perf): If the client disconnects, should we cancel the response? Probably, to + // prevent permanent deadlock. It's slightly weird in that arguably the client should + // be able to shutdown the upstream but still wait on the downstream, but I believe many + // other HTTP servers do similar things. + + auto body = httpInput.getEntityBody( + HttpInputStreamImpl::REQUEST, request.method, 0, headers); + + auto promise = service->request( + request.method, request.url, headers, *body, *this).attach(kj::mv(service)); + return promise.then([this, body = kj::mv(body)]() mutable -> kj::Promise { + // Response done. Await next request. + + KJ_IF_MAYBE(p, webSocketError) { + // sendWebSocketError() was called. Finish sending and close the connection. + auto promise = kj::mv(*p); + webSocketError = nullptr; + return kj::mv(promise); + } + + if (upgraded) { + // We've upgraded to WebSocket, and by now we should have closed the WebSocket. + if (!webSocketOrConnectClosed) { + // This is gonna segfault later so abort now instead. + KJ_LOG(FATAL, "Accepted WebSocket object must be destroyed before HttpService " + "request handler completes."); + abort(); + } + + // Once we start a WebSocket there's no going back to HTTP. + return false; + } + + if (currentMethod != nullptr) { + return sendError(); + } + + if (httpOutput.isBroken()) { + // We started a response but didn't finish it. But HttpService returns success? + // Perhaps it decided that it doesn't want to finish this response. We'll have to + // disconnect here. If the response body is not complete (e.g. Content-Length not + // reached), the client should notice. We don't want to log an error because this + // condition might be intentional on the service's part. + return false; + } + + return httpOutput.flush().then( + [this, body = kj::mv(body)]() mutable -> kj::Promise { + if (httpInput.canReuse()) { + // Things look clean. Go ahead and accept the next request. + + if (closeAfterSend) { + // We sent Connection: close, so drop the connection now. + return false; + } else { + // Note that we don't have to handle server.draining here because we'll take care + // of it the next time around the loop. + return loop(false); + } + } else { + // Apparently, the application did not read the request body. Maybe this is a bug, + // or maybe not: maybe the client tried to upload too much data and the application + // legitimately wants to cancel the upload without reading all it it. + // + // We have a problem, though: We did send a response, and we didn't send + // `Connection: close`, so the client may expect that it can send another request. + // Perhaps the client has even finished sending the previous request's body, in + // which case the moment it finishes receiving the response, it could be completely + // within its rights to start a new request. If we close the socket now, we might + // interrupt that new request. + // + // Or maybe we did send `Connection: close`, as indicated by `closeAfterSend` being + // true. Even in that case, we should still try to read and ignore the request, + // otherwise when we close the connection the client may get a "connection reset" + // error before they get a chance to actually read the response body that we sent + // them. + // + // There's no way we can get out of this perfectly cleanly. HTTP just isn't good + // enough at connection management. The best we can do is give the client some grace + // period and then abort the connection. + + auto dummy = kj::heap(); + auto lengthGrace = kj::evalNow([&]() { + return body->pumpTo(*dummy, server.settings.canceledUploadGraceBytes); + }).catch_([](kj::Exception&& e) -> uint64_t { + // Reading from the input failed in some way. This may actually be the whole + // reason we got here in the first place so don't propagate this error, just + // give up on discarding the input. + return 0; // This zero is ignored but `canReuse()` will return false below. + }).then([this](uint64_t amount) { + if (httpInput.canReuse()) { + // Success, we can continue. + return true; + } else { + // Still more data. Give up. + return false; + } + }); + lengthGrace = lengthGrace.attach(kj::mv(dummy), kj::mv(body)); + + auto timeGrace = server.timer.afterDelay(server.settings.canceledUploadGracePeriod) + .then([]() { return false; }); + + return lengthGrace.exclusiveJoin(kj::mv(timeGrace)) + .then([this](bool clean) -> kj::Promise { + if (clean && !closeAfterSend) { + // We recovered. Continue loop. + return loop(false); + } else { + // Client still not done, or we sent Connection: close and so want to drop the + // connection anyway. Return broken. + return false; + } + }); + } + }); + }); + } + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + // Bad request. + + // sendError() uses Response::send(), which requires that we have a currentMethod, but we + // never read one. GET seems like the correct choice here. + currentMethod = HttpMethod::GET; + return sendError(kj::mv(protocolError)); + } + } + + KJ_UNREACHABLE; + }); + } kj::Own send( uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, kj::Maybe expectedBodySize) override { - auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()"); + auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); currentMethod = nullptr; - HttpHeaders::ConnectionHeaders connectionHeaders; + kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT]; kj::String lengthStr; - if (statusCode == 204 || statusCode == 205 || statusCode == 304) { + if (!closeAfterSend) { + // Check if application wants us to close connections. + // + // If the application used listenHttpClientDrain() to listen, then it expects that after a + // clean drain, the connection is still open and can receive more requests. Otherwise, after + // receiving drain(), we will close the connection, so we should send a `Connection: close` + // header. + if (server.draining && !wantCleanDrain) { + closeAfterSend = true; + } else KJ_IF_MAYBE(c, server.settings.callbacks) { + // The application has registered its own callback to decide whether to send + // `Connection: close`. + if (c->shouldClose()) { + closeAfterSend = true; + } + } + } + + if (closeAfterSend) { + connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "close"; + } + + bool isHeadRequest = method.tryGet().map([](auto& m) { + return m == HttpMethod::HEAD; + }).orDefault(false); + + if (statusCode == 204 || statusCode == 304) { // No entity-body. + } else if (statusCode == 205) { + // Status code 205 also has no body, but unlike 204 and 304, it must explicitly encode an + // empty body, e.g. using content-length: 0. I'm guessing this is one of those things, + // where some early clients expected an explicit body while others assumed an empty body, + // and so the standard had to choose the common denominator. + // + // Spec: https://tools.ietf.org/html/rfc7231#section-6.3.6 + connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = "0"; } else KJ_IF_MAYBE(s, expectedBodySize) { - lengthStr = kj::str(*s); - connectionHeaders.contentLength = lengthStr; + // HACK: We interpret a zero-length expected body length on responses to HEAD requests to + // mean "don't set a Content-Length header at all." This provides a way to omit a body + // header on HEAD responses with non-null-body status codes. This is a hack that *only* + // makes sense for HEAD responses. + if (!isHeadRequest || *s > 0) { + lengthStr = kj::str(*s); + connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr; + } } else { - connectionHeaders.transferEncoding = "chunked"; + connectionHeaders[HttpHeaders::BuiltinIndices::TRANSFER_ENCODING] = "chunked"; + } + + // For HEAD requests, if the application specified a Content-Length or Transfer-Encoding + // header, use that instead of whatever we decided above. + kj::ArrayPtr connectionHeadersArray = connectionHeaders; + if (isHeadRequest) { + if (headers.get(HttpHeaderId::CONTENT_LENGTH) != nullptr || + headers.get(HttpHeaderId::TRANSFER_ENCODING) != nullptr) { + connectionHeadersArray = connectionHeadersArray + .slice(0, HttpHeaders::HEAD_RESPONSE_CONNECTION_HEADERS_COUNT); + } } - httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText, connectionHeaders)); + httpOutput.writeHeaders(headers.serializeResponse( + statusCode, statusText, connectionHeadersArray)); kj::Own bodyStream; - if (method == HttpMethod::HEAD) { + if (isHeadRequest) { // Ignore entity-body. httpOutput.finishBody(); return heap(); @@ -1751,32 +7546,239 @@ private: } } - kj::Promise sendError(uint statusCode, kj::StringPtr statusText, kj::String body) { - auto bodySize = kj::str(body.size()); + kj::Own acceptWebSocket(const HttpHeaders& headers) override { + auto& requestHeaders = httpInput.getHeaders(); + KJ_REQUIRE(requestHeaders.isWebSocket(), + "can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket"); - HttpHeaders failed(server.requestHeaderTable); - HttpHeaders::ConnectionHeaders connHeaders; - connHeaders.connection = "close"; - connHeaders.contentLength = bodySize; + auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); + KJ_REQUIRE(method.tryGet().map([](auto& m) { + return m == HttpMethod::GET; + }).orDefault(false), "WebSocket must be initiated with a GET request."); - failed.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); + if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") { + return sendWebSocketError("The requested WebSocket version is not supported."); + } - httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText, connHeaders)); - httpOutput.writeBodyData(kj::mv(body)); - httpOutput.finishBody(); - return httpOutput.flush(); // loop ends after flush + kj::String key; + KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) { + key = kj::str(*k); + } else { + return sendWebSocketError("Missing Sec-WebSocket-Key"); + } + + kj::Maybe acceptedParameters; + kj::String agreedParameters; + auto compressionMode = server.settings.webSocketCompressionMode; + if (compressionMode == HttpServerSettings::AUTOMATIC_COMPRESSION) { + // If AUTOMATIC_COMPRESSION is enabled, we ignore the `headers` passed by the application and + // strictly refer to the `requestHeaders` from the client. + KJ_IF_MAYBE(value, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // Perform compression parameter negotiation. + KJ_IF_MAYBE(config, _::tryParseExtensionOffers(*value)) { + acceptedParameters = kj::mv(*config); + } + } + } else if (compressionMode == HttpServerSettings::MANUAL_COMPRESSION) { + // If MANUAL_COMPRESSION is enabled, we use the `headers` passed in by the application, and + // try to find a configuration that respects both the server's preferred configuration, + // as well as the client's requested configuration. + KJ_IF_MAYBE(value, headers.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // First, we get the manual configuration using `headers`. + KJ_IF_MAYBE(manualConfig, _::tryParseExtensionOffers(*value)) { + KJ_IF_MAYBE(requestOffers, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_EXTENSIONS)) { + // Next, we to find a configuration that both the client and server can accept. + acceptedParameters = _::tryParseAllExtensionOffers(*requestOffers, *manualConfig); + } + } + } + } + + auto websocketAccept = generateWebSocketAccept(key); + + kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT]; + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept; + connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket"; + connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade"; + KJ_IF_MAYBE(parameters, acceptedParameters) { + agreedParameters = _::generateExtensionResponse(*parameters); + connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_EXTENSIONS] = agreedParameters; + } + + // Since we're about to write headers, we should nullify `currentMethod`. This tells + // `sendError(kj::Exception)` (called from `HttpServer::Connection::startLoop()`) not to expose + // the `HttpService::Response&` reference to the HttpServer's error `handleApplicationError()` + // callback. This prevents the error handler from inadvertently trying to send another error on + // the connection. + currentMethod = nullptr; + + httpOutput.writeHeaders(headers.serializeResponse( + 101, "Switching Protocols", connectionHeaders)); + + upgraded = true; + // We need to give the WebSocket an Own, but we only have a reference. This is + // safe because the application is expected to drop the WebSocket object before returning + // from the request handler. For some extra safety, we check that webSocketOrConnectClosed has + // been set true when the handler returns. + auto deferNoteClosed = kj::defer([this]() { webSocketOrConnectClosed = true; }); + kj::Own ownStream(&stream, kj::NullDisposer::instance); + return upgradeToWebSocket(ownStream.attach(kj::mv(deferNoteClosed)), + httpInput, httpOutput, nullptr, kj::mv(acceptedParameters), + server.settings.webSocketErrorHandler); + } + + kj::Promise sendError(HttpHeaders::ProtocolError protocolError) { + closeAfterSend = true; + + // Client protocol errors always happen on request headers parsing, before we call into the + // HttpService, meaning no response has been sent and we can provide a Response object. + auto promise = server.settings.errorHandler.orDefault(*this).handleClientProtocolError( + kj::mv(protocolError), *this); + return finishSendingError(kj::mv(promise)); + } + + kj::Promise sendError(kj::Exception&& exception) { + closeAfterSend = true; + + // We only provide the Response object if we know we haven't already sent a response. + auto promise = server.settings.errorHandler.orDefault(*this).handleApplicationError( + kj::mv(exception), currentMethod.map([this](auto&&) -> Response& { return *this; })); + return finishSendingError(kj::mv(promise)); + } + + kj::Promise sendError() { + closeAfterSend = true; + + // We can provide a Response object, since none has already been sent. + auto promise = server.settings.errorHandler.orDefault(*this).handleNoResponse(*this); + return finishSendingError(kj::mv(promise)); + } + + kj::Promise finishSendingError(kj::Promise promise) { + return promise.then([this]() -> kj::Promise { + if (httpOutput.isBroken()) { + // Skip flush for broken streams, since it will throw an exception that may be worse than + // the one we just handled. + return kj::READY_NOW; + } else { + return httpOutput.flush(); + } + }).then([]() { return false; }); // loop ends after flush + } + + kj::Own sendWebSocketError(StringPtr errorMessage) { + kj::Exception exception = KJ_EXCEPTION(FAILED, + "received bad WebSocket handshake", errorMessage); + webSocketError = sendError( + HttpHeaders::ProtocolError { 400, "Bad Request", errorMessage, nullptr }); + kj::throwRecoverableException(kj::mv(exception)); + + // Fallback path when exceptions are disabled. + class BrokenWebSocket final: public WebSocket { + public: + BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {} + + kj::Promise send(kj::ArrayPtr message) override { + return kj::cp(exception); + } + kj::Promise send(kj::ArrayPtr message) override { + return kj::cp(exception); + } + kj::Promise close(uint16_t code, kj::StringPtr reason) override { + return kj::cp(exception); + } + kj::Promise disconnect() override { + return kj::cp(exception); + } + void abort() override { + kj::throwRecoverableException(kj::cp(exception)); + } + kj::Promise whenAborted() override { + return kj::cp(exception); + } + kj::Promise receive(size_t maxSize) override { + return kj::cp(exception); + } + + uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } + uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } + + private: + kj::Exception exception; + }; + + return kj::heap(KJ_EXCEPTION(FAILED, + "received bad WebSocket handshake", errorMessage)); + } + + kj::Own getConnectStream() { + // Returns an AsyncIoStream over the internal stream but that waits for a Promise to be + // resolved to allow writes after either accept or reject are called. Reads are allowed + // immediately. + KJ_REQUIRE(tunnelWriteGuard == nullptr, "the tunnel stream was already retrieved"); + auto paf = kj::newPromiseAndFulfiller(); + tunnelWriteGuard = kj::mv(paf.fulfiller); + + kj::Own ownStream(&stream, kj::NullDisposer::instance); + auto releasedBuffer = httpInput.releaseBuffer(); + auto deferNoteClosed = kj::defer([this]() { webSocketOrConnectClosed = true; }); + return kj::heap( + kj::heap( + kj::mv(ownStream), + kj::mv(releasedBuffer.buffer), + releasedBuffer.leftover).attach(kj::mv(deferNoteClosed)), + kj::Maybe(nullptr), + kj::mv(paf.promise)); + } + + void accept(uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) override { + auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); + currentMethod = nullptr; + KJ_ASSERT(method.is(), "only use accept() with CONNECT requests"); + KJ_REQUIRE(statusCode >= 200 && statusCode < 300, "the statusCode must be 2xx for accept"); + tunnelRejected = nullptr; + + auto& fulfiller = KJ_ASSERT_NONNULL(tunnelWriteGuard, "the tunnel stream was not initialized"); + httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText)); + auto promise = httpOutput.flush().then([&fulfiller]() { + fulfiller->fulfill(); + }).eagerlyEvaluate(nullptr); + fulfiller = fulfiller.attach(kj::mv(promise)); + } + + kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize) override { + auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); + KJ_REQUIRE(method.is(), "Only use reject() with CONNECT requests."); + KJ_REQUIRE(statusCode < 200 || statusCode >= 300, "the statusCode must not be 2xx for reject."); + tunnelRejected = Maybe>(true); + + auto& fulfiller = KJ_ASSERT_NONNULL(tunnelWriteGuard, "the tunnel stream was not initialized"); + fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "the tunnel request was rejected")); + closeAfterSend = true; + return send(statusCode, statusText, headers, expectedBodySize); } }; -HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service, - Settings settings) - : HttpServer(timer, requestHeaderTable, service, settings, +HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, + HttpService& service, Settings settings) + : HttpServer(timer, requestHeaderTable, &service, settings, kj::newPromiseAndFulfiller()) {} -HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service, +HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, + HttpServiceFactory serviceFactory, Settings settings) + : HttpServer(timer, requestHeaderTable, kj::mv(serviceFactory), settings, + kj::newPromiseAndFulfiller()) {} + +HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, + kj::OneOf service, Settings settings, kj::PromiseFulfillerPair paf) - : timer(timer), requestHeaderTable(requestHeaderTable), service(service), settings(settings), - onDrain(paf.promise.fork()), drainFulfiller(kj::mv(paf.fulfiller)), tasks(*this) {} + : timer(timer), requestHeaderTable(requestHeaderTable), service(kj::mv(service)), + settings(settings), onDrain(paf.promise.fork()), drainFulfiller(kj::mv(paf.fulfiller)), + tasks(*this) {} kj::Promise HttpServer::drain() { KJ_REQUIRE(!draining, "you can only call drain() once"); @@ -1800,27 +7802,194 @@ kj::Promise HttpServer::listenHttp(kj::ConnectionReceiver& port) { kj::Promise HttpServer::listenLoop(kj::ConnectionReceiver& port) { return port.accept() .then([this,&port](kj::Own&& connection) -> kj::Promise { - if (draining) { - // Can get here if we *just* started draining. - return kj::READY_NOW; - } - - tasks.add(listenHttp(kj::mv(connection))); + tasks.add(kj::evalNow([&]() { return listenHttp(kj::mv(connection)); })); return listenLoop(port); }); } kj::Promise HttpServer::listenHttp(kj::Own connection) { - auto obj = heap(*this, kj::mv(connection)); - auto promise = obj->loop(); + auto promise = listenHttpImpl(*connection, false /* wantCleanDrain */).ignoreResult(); + + // eagerlyEvaluate() to maintain historical guarantee that this method eagerly closes the + // connection when done. + return promise.attach(kj::mv(connection)).eagerlyEvaluate(nullptr); +} + +kj::Promise HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection) { + return listenHttpImpl(connection, true /* wantCleanDrain */); +} + +kj::Promise HttpServer::listenHttpImpl(kj::AsyncIoStream& connection, bool wantCleanDrain) { + kj::Own srv; + + KJ_SWITCH_ONEOF(service) { + KJ_CASE_ONEOF(ptr, HttpService*) { + // Fake Own okay because we can assume the HttpService outlives this HttpServer, and we can + // assume `this` HttpServer outlives the returned `listenHttpCleanDrain()` promise, which will + // own the fake Own. + srv = kj::Own(ptr, kj::NullDisposer::instance); + } + KJ_CASE_ONEOF(func, HttpServiceFactory) { + srv = func(connection); + } + } + + KJ_ASSERT(srv.get() != nullptr); + + return listenHttpImpl(connection, [srv = kj::mv(srv)](SuspendableRequest&) mutable { + // This factory function will be owned by the Connection object, meaning the Connection object + // will own the HttpService. We also know that the Connection object outlives all + // service.request() promises (service.request() is called from a Connection member function). + // The Owns we return from this function are attached to the service.request() promises, + // meaning this factory function will outlive all Owns we return. So, it's safe to return a fake + // Own. + return kj::Own(srv.get(), kj::NullDisposer::instance); + }, nullptr /* suspendedRequest */, wantCleanDrain); +} + +kj::Promise HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest) { + // Don't close on drain, because a "clean drain" means we return the connection to the + // application still-open between requests so that it can continue serving future HTTP requests + // on it. + return listenHttpImpl(connection, kj::mv(factory), kj::mv(suspendedRequest), + true /* wantCleanDrain */); +} + +kj::Promise HttpServer::listenHttpImpl(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest, + bool wantCleanDrain) { + auto obj = heap(*this, connection, kj::mv(factory), kj::mv(suspendedRequest), + wantCleanDrain); + + // Start reading requests and responding to them, but immediately cancel processing if the client + // disconnects. + auto promise = obj->startLoop(true) + .exclusiveJoin(connection.whenWriteDisconnected().then([]() {return false;})); // Eagerly evaluate so that we drop the connection when the promise resolves, even if the caller // doesn't eagerly evaluate. return promise.attach(kj::mv(obj)).eagerlyEvaluate(nullptr); } -void HttpServer::taskFailed(kj::Exception&& exception) { +namespace { +void defaultHandleListenLoopException(kj::Exception&& exception) { KJ_LOG(ERROR, "unhandled exception in HTTP server", exception); } +} // namespace + +void HttpServer::taskFailed(kj::Exception&& exception) { + KJ_IF_MAYBE(handler, settings.errorHandler) { + handler->handleListenLoopException(kj::mv(exception)); + } else { + defaultHandleListenLoopException(kj::mv(exception)); + } +} + +HttpServer::SuspendedRequest::SuspendedRequest( + kj::Array bufferParam, kj::ArrayPtr leftoverParam, + kj::OneOf method, + kj::StringPtr url, HttpHeaders headers) + : buffer(kj::mv(bufferParam)), + leftover(leftoverParam), + method(method), + url(url), + headers(kj::mv(headers)) { + if (leftover.size() > 0) { + // We have a `leftover`; make sure it is a slice of `buffer`. + KJ_ASSERT(leftover.begin() >= buffer.begin() && leftover.begin() <= buffer.end()); + KJ_ASSERT(leftover.end() >= buffer.begin() && leftover.end() <= buffer.end()); + } else { + // We have no `leftover`, but we still expect it to point into `buffer` somewhere. This is + // important so that `messageHeaderEnd` is initialized correctly in HttpInputStreamImpl's + // constructor. + KJ_ASSERT(leftover.begin() >= buffer.begin() && leftover.begin() <= buffer.end()); + } +} + +HttpServer::SuspendedRequest HttpServer::SuspendableRequest::suspend() { + return connection.suspend(*this); +} + +kj::Promise HttpServerErrorHandler::handleClientProtocolError( + HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) { + // Default error handler implementation. + + HttpHeaderTable headerTable {}; + HttpHeaders headers(headerTable); + headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); + + auto errorMessage = kj::str("ERROR: ", protocolError.description); + auto body = response.send(protocolError.statusCode, protocolError.statusMessage, + headers, errorMessage.size()); + + return body->write(errorMessage.begin(), errorMessage.size()) + .attach(kj::mv(errorMessage), kj::mv(body)); +} + +kj::Promise HttpServerErrorHandler::handleApplicationError( + kj::Exception exception, kj::Maybe response) { + // Default error handler implementation. + + if (exception.getType() == kj::Exception::Type::DISCONNECTED) { + // How do we tell an HTTP client that there was a transient network error, and it should + // try again immediately? There's no HTTP status code for this (503 is meant for "try + // again later, not now"). Here's an idea: Don't send any response; just close the + // connection, so that it looks like the connection between the HTTP client and server + // was dropped. A good client should treat this exactly the way we want. + // + // We also bail here to avoid logging the disconnection, which isn't very interesting. + return kj::READY_NOW; + } + + KJ_IF_MAYBE(r, response) { + KJ_LOG(INFO, "threw exception while serving HTTP response", exception); + + HttpHeaderTable headerTable {}; + HttpHeaders headers(headerTable); + headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); + + kj::String errorMessage; + kj::Own body; + + if (exception.getType() == kj::Exception::Type::OVERLOADED) { + errorMessage = kj::str( + "ERROR: The server is temporarily unable to handle your request. Details:\n\n", exception); + body = r->send(503, "Service Unavailable", headers, errorMessage.size()); + } else if (exception.getType() == kj::Exception::Type::UNIMPLEMENTED) { + errorMessage = kj::str( + "ERROR: The server does not implement this operation. Details:\n\n", exception); + body = r->send(501, "Not Implemented", headers, errorMessage.size()); + } else { + errorMessage = kj::str( + "ERROR: The server threw an exception. Details:\n\n", exception); + body = r->send(500, "Internal Server Error", headers, errorMessage.size()); + } + + return body->write(errorMessage.begin(), errorMessage.size()) + .attach(kj::mv(errorMessage), kj::mv(body)); + } + + KJ_LOG(ERROR, "HttpService threw exception after generating a partial response", + "too late to report error to client", exception); + return kj::READY_NOW; +} + +void HttpServerErrorHandler::handleListenLoopException(kj::Exception&& exception) { + defaultHandleListenLoopException(kj::mv(exception)); +} + +kj::Promise HttpServerErrorHandler::handleNoResponse(kj::HttpService::Response& response) { + HttpHeaderTable headerTable {}; + HttpHeaders headers(headerTable); + headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); + + constexpr auto errorMessage = "ERROR: The HttpService did not generate a response."_kj; + auto body = response.send(500, "Internal Server Error", headers, errorMessage.size()); + + return body->write(errorMessage.begin(), errorMessage.size()).attach(kj::mv(body)); +} } // namespace kj diff --git a/c++/src/kj/compat/http.h b/c++/src/kj/compat/http.h index 8d455cc258..151222a562 100644 --- a/c++/src/kj/compat/http.h +++ b/c++/src/kj/compat/http.h @@ -19,8 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_COMPAT_HTTP_H_ -#define KJ_COMPAT_HTTP_H_ +#pragma once // The KJ HTTP client/server library. // // This is a simple library which can be used to implement an HTTP client or server. Properties @@ -40,6 +39,9 @@ #include #include #include +#include + +KJ_BEGIN_HEADER namespace kj { @@ -55,7 +57,7 @@ namespace kj { MACRO(TRACE) \ /* standard methods */ \ /* */ \ - /* (CONNECT is intentionally omitted since it is handled specially in HttpHandler) */ \ + /* (CONNECT is intentionally omitted since it should be handled specially in HttpServer) */ \ \ MACRO(COPY) \ MACRO(LOCK) \ @@ -65,6 +67,7 @@ namespace kj { MACRO(PROPPATCH) \ MACRO(SEARCH) \ MACRO(UNLOCK) \ + MACRO(ACL) \ /* WebDAV */ \ \ MACRO(REPORT) \ @@ -79,15 +82,6 @@ namespace kj { MACRO(UNSUBSCRIBE) /* UPnP */ -#define KJ_HTTP_FOR_EACH_CONNECTION_HEADER(MACRO) \ - MACRO(connection, "Connection") \ - MACRO(contentLength, "Content-Length") \ - MACRO(keepAlive, "Keep-Alive") \ - MACRO(te, "TE") \ - MACRO(trailer, "Trailer") \ - MACRO(transferEncoding, "Transfer-Encoding") \ - MACRO(upgrade, "Upgrade") - enum class HttpMethod { // Enum of known HTTP methods. // @@ -99,8 +93,17 @@ KJ_HTTP_FOR_EACH_METHOD(DECLARE_METHOD) #undef DECLARE_METHOD }; +struct HttpConnectMethod {}; +// CONNECT is handled specially and separately from the other HttpMethods. + kj::StringPtr KJ_STRINGIFY(HttpMethod method); +kj::StringPtr KJ_STRINGIFY(HttpConnectMethod method); kj::Maybe tryParseHttpMethod(kj::StringPtr name); +kj::Maybe> tryParseHttpMethodAllowingConnect( + kj::StringPtr name); +// Like tryParseHttpMethod but, as the name suggests, explicitly allows for the CONNECT +// method. Added as a separate function instead of modifying tryParseHttpMethod to avoid +// breaking API changes in existing uses of tryParseHttpMethod. class HttpHeaderTable; @@ -126,25 +129,44 @@ class HttpHeaderId { inline bool operator>=(const HttpHeaderId& other) const { return id >= other.id; } inline size_t hashCode() const { return id; } + // Returned value is guaranteed to be small and never collide with other headers on the same + // table. kj::StringPtr toString() const; - void requireFrom(HttpHeaderTable& table) const; + void requireFrom(const HttpHeaderTable& table) const; // In debug mode, throws an exception if the HttpHeaderId is not from the given table. // // In opt mode, no-op. #define KJ_HTTP_FOR_EACH_BUILTIN_HEADER(MACRO) \ + /* Headers that are always read-only. */ \ + MACRO(CONNECTION, "Connection") \ + MACRO(KEEP_ALIVE, "Keep-Alive") \ + MACRO(TE, "TE") \ + MACRO(TRAILER, "Trailer") \ + MACRO(UPGRADE, "Upgrade") \ + \ + /* Headers that are read-only except in the case of a response to a HEAD request. */ \ + MACRO(CONTENT_LENGTH, "Content-Length") \ + MACRO(TRANSFER_ENCODING, "Transfer-Encoding") \ + \ + /* Headers that are read-only for WebSocket handshakes. */ \ + MACRO(SEC_WEBSOCKET_KEY, "Sec-WebSocket-Key") \ + MACRO(SEC_WEBSOCKET_VERSION, "Sec-WebSocket-Version") \ + MACRO(SEC_WEBSOCKET_ACCEPT, "Sec-WebSocket-Accept") \ + MACRO(SEC_WEBSOCKET_EXTENSIONS, "Sec-WebSocket-Extensions") \ + \ + /* Headers that you can write. */ \ MACRO(HOST, "Host") \ MACRO(DATE, "Date") \ MACRO(LOCATION, "Location") \ MACRO(CONTENT_TYPE, "Content-Type") - // For convenience, these very-common headers are valid for all HttpHeaderTables. You can refer - // to them like: + // For convenience, these headers are valid for all HttpHeaderTables. You can refer to them like: // // HttpHeaderId::HOST // - // TODO(0.7): Fill this out with more common headers. + // TODO(someday): Fill this out with more common headers. #define DECLARE_HEADER(id, name) \ static const HttpHeaderId id; @@ -154,10 +176,11 @@ class HttpHeaderId { #undef DECLARE_HEADER private: - HttpHeaderTable* table; + const HttpHeaderTable* table; uint id; - inline explicit constexpr HttpHeaderId(HttpHeaderTable* table, uint id): table(table), id(id) {} + inline explicit constexpr HttpHeaderId(const HttpHeaderTable* table, uint id) + : table(table), id(id) {} friend class HttpHeaderTable; friend class HttpHeaders; }; @@ -208,24 +231,35 @@ class HttpHeaderTable { kj::Own table; }; - KJ_DISALLOW_COPY(HttpHeaderTable); // Can't copy because HttpHeaderId points to the table. + KJ_DISALLOW_COPY_AND_MOVE(HttpHeaderTable); // Can't copy because HttpHeaderId points to the table. ~HttpHeaderTable() noexcept(false); - uint idCount(); + uint idCount() const; // Return the number of IDs in the table. - kj::Maybe stringToId(kj::StringPtr name); + kj::Maybe stringToId(kj::StringPtr name) const; // Try to find an ID for the given name. The matching is case-insensitive, per the HTTP spec. // // Note: if `name` contains characters that aren't allowed in HTTP header names, this may return // a bogus value rather than null, due to optimizations used in case-insensitive matching. - kj::StringPtr idToString(HttpHeaderId id); + kj::StringPtr idToString(HttpHeaderId id) const; // Get the canonical string name for the given ID. + bool isReady() const; + // Returns true if this HttpHeaderTable either was default constructed or its Builder has + // invoked `build()` and released it. + private: kj::Vector namesById; kj::Own idsByName; + + enum class BuildStatus { + UNSTARTED = 0, + BUILDING = 1, + FINISHED = 2, + }; + BuildStatus buildStatus = BuildStatus::UNSTARTED; }; class HttpHeaders { @@ -236,12 +270,26 @@ class HttpHeaders { // exception. public: - explicit HttpHeaders(HttpHeaderTable& table); + explicit HttpHeaders(const HttpHeaderTable& table); + + static bool isValidHeaderValue(kj::StringPtr value); + // This returns whether the value is a valid parameter to the set call. While the HTTP spec + // suggests that only printable ASCII characters are allowed in header values, in practice that + // turns out to not be the case. We follow the browser's lead in disallowing \r and \n. + // https://github.com/httpwg/http11bis/issues/19 + // Use this if you want to validate the value before supplying it to set() if you want to avoid + // an exception being thrown (e.g. you have custom error reporting). NOTE that set will still + // validate the value. If performance is a problem this API needs to be adjusted to a + // `validateHeaderValue` function that returns a special type that set can be confident has + // already passed through the validation routine. KJ_DISALLOW_COPY(HttpHeaders); HttpHeaders(HttpHeaders&&) = default; HttpHeaders& operator=(HttpHeaders&&) = default; + size_t size() const; + // Returns the number of headers that forEach() would iterate over. + void clear(); // Clears all contents, as if the object was freshly-allocated. However, calling this rather // than actually re-allocating the object may avoid re-allocation of internal objects. @@ -253,14 +301,33 @@ class HttpHeaders { // Creates a shallow clone of the HttpHeaders. The returned object references the same strings // as the original, owning none of them. + bool isWebSocket() const; + // Convenience method that checks for the presence of the header `Upgrade: websocket`. + // + // Note that this does not actually validate that the request is a complete WebSocket handshake + // with the correct version number -- such validation will occur if and when you call + // acceptWebSocket(). + kj::Maybe get(HttpHeaderId id) const; // Read a header. + // + // Note that there is intentionally no method to look up a header by string name rather than + // header ID. The intent is that you should always allocate a header ID for any header that you + // care about, so that you can get() it by ID. Headers with registered IDs are stored in an array + // indexed by ID, making lookup fast. Headers without registered IDs are stored in a separate list + // that is optimized for re-transmission of the whole list, but not for lookup. template void forEach(Func&& func) const; // Calls `func(name, value)` for each header in the set -- including headers that aren't mapped // to IDs in the header table. Both inputs are of type kj::StringPtr. + template + void forEach(Func1&& func1, Func2&& func2) const; + // Calls `func1(id, value)` for each header in the set that has a registered HttpHeaderId, and + // `func2(name, value)` for each header that does not. All calls to func1() precede all calls to + // func2(). + void set(HttpHeaderId id, kj::StringPtr value); void set(HttpHeaderId id, kj::String&& value); // Sets a header value, overwriting the existing value. @@ -293,53 +360,110 @@ class HttpHeaders { void takeOwnership(kj::String&& string); void takeOwnership(kj::Array&& chars); void takeOwnership(HttpHeaders&& otherHeaders); - // Takes overship of a string so that it lives until the HttpHeaders object is destroyed. Useful + // Takes ownership of a string so that it lives until the HttpHeaders object is destroyed. Useful // when you've passed a dynamic value to set() or add() or parse*(). - struct ConnectionHeaders { - // These headers govern details of the specific HTTP connection or framing of the content. - // Hence, they are managed internally within the HTTP library, and never appear in an - // HttpHeaders structure. - -#define DECLARE_HEADER(id, name) \ - kj::StringPtr id; - KJ_HTTP_FOR_EACH_CONNECTION_HEADER(DECLARE_HEADER) -#undef DECLARE_HEADER - }; - struct Request { HttpMethod method; kj::StringPtr url; - ConnectionHeaders connectionHeaders; + }; + struct ConnectRequest { + kj::StringPtr authority; }; struct Response { uint statusCode; kj::StringPtr statusText; - ConnectionHeaders connectionHeaders; }; - kj::Maybe tryParseRequest(kj::ArrayPtr content); - kj::Maybe tryParseResponse(kj::ArrayPtr content); + struct ProtocolError { + // Represents a protocol error, such as a bad request method or invalid headers. Debugging such + // errors is difficult without a copy of the data which we tried to parse, but this data is + // sensitive, so we can't just lump it into the error description directly. ProtocolError + // provides this sensitive data separate from the error description. + // + // TODO(cleanup): Should maybe not live in HttpHeaders? HttpServerErrorHandler::ProtocolError? + // Or HttpProtocolError? Or maybe we need a more general way of attaching sensitive context to + // kj::Exceptions? + + uint statusCode; + // Suggested HTTP status code that should be used when returning an error to the client. + // + // Most errors are 400. An unrecognized method will be 501. + + kj::StringPtr statusMessage; + // HTTP status message to go with `statusCode`, e.g. "Bad Request". + + kj::StringPtr description; + // An error description safe for all the world to see. + + kj::ArrayPtr rawContent; + // Unredacted data which led to the error condition. This may contain anything transported over + // HTTP, to include sensitive PII, so you must take care to sanitize this before using it in any + // error report that may leak to unprivileged eyes. + // + // This ArrayPtr is merely a copy of the `content` parameter passed to `tryParseRequest()` / + // `tryParseResponse()`, thus it remains valid for as long as a successfully-parsed HttpHeaders + // object would remain valid. + }; + + using RequestOrProtocolError = kj::OneOf; + using ResponseOrProtocolError = kj::OneOf; + using RequestConnectOrProtocolError = kj::OneOf; + + RequestOrProtocolError tryParseRequest(kj::ArrayPtr content); + RequestConnectOrProtocolError tryParseRequestOrConnect(kj::ArrayPtr content); + ResponseOrProtocolError tryParseResponse(kj::ArrayPtr content); + // Parse an HTTP header blob and add all the headers to this object. // - // `content` should be all text from the start of the request to the first occurrance of two + // `content` should be all text from the start of the request to the first occurrence of two // newlines in a row -- including the first of these two newlines, but excluding the second. // // The parse is performed with zero copies: The callee clobbers `content` with '\0' characters // to split it into a bunch of shorter strings. The caller must keep `content` valid until the // `HttpHeaders` is destroyed, or pass it to `takeOwnership()`. + bool tryParse(kj::ArrayPtr content); + // Like tryParseRequest()/tryParseResponse(), but don't expect any request/response line. + kj::String serializeRequest(HttpMethod method, kj::StringPtr url, - const ConnectionHeaders& connectionHeaders) const; + kj::ArrayPtr connectionHeaders = nullptr) const; + kj::String serializeConnectRequest(kj::StringPtr authority, + kj::ArrayPtr connectionHeaders = nullptr) const; kj::String serializeResponse(uint statusCode, kj::StringPtr statusText, - const ConnectionHeaders& connectionHeaders) const; + kj::ArrayPtr connectionHeaders = nullptr) const; + // **Most applications will not use these methods; they are called by the HTTP client and server + // implementations.** + // // Serialize the headers as a complete request or response blob. The blob uses '\r\n' newlines // and includes the double-newline to indicate the end of the headers. + // + // `connectionHeaders`, if provided, contains connection-level headers supplied by the HTTP + // implementation, in the order specified by the KJ_HTTP_FOR_EACH_BUILTIN_HEADER macro. These + // headers values override any corresponding header value in the HttpHeaders object. The + // CONNECTION_HEADERS_COUNT constants below can help you construct this `connectionHeaders` array. + + enum class BuiltinIndicesEnum { + #define HEADER_ID(id, name) id, + KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID) + #undef HEADER_ID + }; + + struct BuiltinIndices { + #define HEADER_ID(id, name) static constexpr uint id = static_cast(BuiltinIndicesEnum::id); + KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID) + #undef HEADER_ID + }; + + static constexpr uint HEAD_RESPONSE_CONNECTION_HEADERS_COUNT = BuiltinIndices::CONTENT_LENGTH; + static constexpr uint CONNECTION_HEADERS_COUNT = BuiltinIndices::SEC_WEBSOCKET_KEY; + static constexpr uint WEBSOCKET_CONNECTION_HEADERS_COUNT = BuiltinIndices::HOST; + // Constants for use with HttpHeaders::serialize*(). kj::String toString() const; private: - HttpHeaderTable* table; + const HttpHeaderTable* table; kj::Array indexedHeaders; // Size is always table->idCount(). @@ -352,29 +476,292 @@ class HttpHeaders { kj::Vector> ownedStrings; - kj::Maybe addNoCheck(kj::StringPtr name, kj::StringPtr value); + void addNoCheck(kj::StringPtr name, kj::StringPtr value); kj::StringPtr cloneToOwn(kj::StringPtr str); kj::String serialize(kj::ArrayPtr word1, kj::ArrayPtr word2, kj::ArrayPtr word3, - const ConnectionHeaders& connectionHeaders) const; + kj::ArrayPtr connectionHeaders) const; - bool parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders); + bool parseHeaders(char* ptr, char* end); // TODO(perf): Arguably we should store a map, but header sets are never very long // TODO(perf): We could optimize for common headers by storing them directly as fields. We could // also add direct accessors for those headers. }; +class HttpInputStream { + // Low-level interface to receive HTTP-formatted messages (headers followed by body) from an + // input stream, without a paired output stream. + // + // Most applications will not use this. Regular HTTP clients and servers don't need this. This + // is mainly useful for apps implementing various protocols that look like HTTP but aren't + // really. + +public: + struct Request { + HttpMethod method; + kj::StringPtr url; + const HttpHeaders& headers; + kj::Own body; + }; + virtual kj::Promise readRequest() = 0; + // Reads one HTTP request from the input stream. + // + // The returned struct contains pointers directly into a buffer that is invalidated on the next + // message read. + + struct Connect { + kj::StringPtr authority; + const HttpHeaders& headers; + kj::Own body; + }; + virtual kj::Promise> readRequestAllowingConnect() = 0; + // Reads one HTTP request from the input stream. + // + // The returned struct contains pointers directly into a buffer that is invalidated on the next + // message read. + + struct Response { + uint statusCode; + kj::StringPtr statusText; + const HttpHeaders& headers; + kj::Own body; + }; + virtual kj::Promise readResponse(HttpMethod requestMethod) = 0; + // Reads one HTTP response from the input stream. + // + // You must provide the request method because responses to HEAD requests require special + // treatment. + // + // The returned struct contains pointers directly into a buffer that is invalidated on the next + // message read. + + struct Message { + const HttpHeaders& headers; + kj::Own body; + }; + virtual kj::Promise readMessage() = 0; + // Reads an HTTP header set followed by a body, with no request or response line. This is not + // useful for HTTP but may be useful for other protocols that make the unfortunate choice to + // mimic HTTP message format, such as Visual Studio Code's JSON-RPC transport. + // + // The returned struct contains pointers directly into a buffer that is invalidated on the next + // message read. + + virtual kj::Promise awaitNextMessage() = 0; + // Waits until more data is available, but doesn't consume it. Returns false on EOF. +}; + +class EntropySource { + // Interface for an object that generates entropy. Typically, cryptographically-random entropy + // is expected. + // + // TODO(cleanup): Put this somewhere more general. + +public: + virtual void generate(kj::ArrayPtr buffer) = 0; +}; + +struct CompressionParameters { + // These are the parameters for `Sec-WebSocket-Extensions` permessage-deflate extension. + // Since we cannot distinguish the client/server in `upgradeToWebSocket`, we use the prefixes + // `inbound` and `outbound` instead. + bool outboundNoContextTakeover = false; + bool inboundNoContextTakeover = false; + kj::Maybe outboundMaxWindowBits = nullptr; + kj::Maybe inboundMaxWindowBits = nullptr; +}; + class WebSocket { + // Interface representincg an open WebSocket session. + // + // Each side can send and receive data and "close" messages. + // + // Ping/Pong and message fragmentation are not exposed through this interface. These features of + // the underlying WebSocket protocol are not exposed by the browser-level JavaScript API either, + // and thus applications typically need to implement these features at the application protocol + // level instead. The implementation is, however, expected to reply to Ping messages it receives. + public: - WebSocket(kj::Own stream); - // Create a WebSocket wrapping the given I/O stream. + virtual kj::Promise send(kj::ArrayPtr message) = 0; + virtual kj::Promise send(kj::ArrayPtr message) = 0; + // Send a message (binary or text). The underlying buffer must remain valid, and you must not + // call send() again, until the returned promise resolves. + + virtual kj::Promise close(uint16_t code, kj::StringPtr reason) = 0; + // Send a Close message. + // + // Note that the returned Promise resolves once the message has been sent -- it does NOT wait + // for the other end to send a Close reply. The application should await a reply before dropping + // the WebSocket object. + + virtual kj::Promise disconnect() = 0; + // Sends EOF on the underlying connection without sending a "close" message. This is NOT a clean + // shutdown, but is sometimes useful when you want the other end to trigger whatever behavior + // it normally triggers when a connection is dropped. + + virtual void abort() = 0; + // Forcefully close this WebSocket, such that the remote end should get a DISCONNECTED error if + // it continues to write. This differs from disconnect(), which only closes the sending + // direction, but still allows receives. + + virtual kj::Promise whenAborted() = 0; + // Resolves when the remote side aborts the connection such that send() would throw DISCONNECTED, + // if this can be detected without actually writing a message. (If not, this promise never + // resolves, but send() or receive() will throw DISCONNECTED when appropriate. See also + // kj::AsyncOutputStream::whenWriteDisconnected().) + + struct ProtocolError { + // Represents a protocol error, such as a bad opcode or oversize message. - kj::Promise send(kj::ArrayPtr message); - kj::Promise send(kj::ArrayPtr message); + uint statusCode; + // Suggested WebSocket status code that should be used when returning an error to the client. + // + // Most errors are 1002; an oversize message will be 1009. + + kj::StringPtr description; + // An error description safe for all the world to see. This should be at most 123 bytes so that + // it can be used as the body of a Close frame (RFC 6455 sections 5.5 and 5.5.1). + }; + + struct Close { + uint16_t code; + kj::String reason; + }; + + typedef kj::OneOf, Close> Message; + + static constexpr size_t SUGGESTED_MAX_MESSAGE_SIZE = 1u << 20; // 1MB + + virtual kj::Promise receive(size_t maxSize = SUGGESTED_MAX_MESSAGE_SIZE) = 0; + // Read one message from the WebSocket and return it. Can only call once at a time. Do not call + // again after Close is received. + + virtual kj::Promise pumpTo(WebSocket& other); + // Continuously receives messages from this WebSocket and send them to `other`. + // + // On EOF, calls other.disconnect(), then resolves. + // + // On other read errors, calls other.close() with the error, then resolves. + // + // On write error, rejects with the error. + + virtual kj::Maybe> tryPumpFrom(WebSocket& other); + // Either returns null, or performs the equivalent of other.pumpTo(*this). Only returns non-null + // if this WebSocket implementation is able to perform the pump in an optimized way, better than + // the default implementation of pumpTo(). The default implementation of pumpTo() always tries + // calling this first, and the default implementation of tryPumpFrom() always returns null. + + virtual uint64_t sentByteCount() = 0; + virtual uint64_t receivedByteCount() = 0; + + enum ExtensionsContext { + // Indicate whether a Sec-WebSocket-Extension header should be rendered for use in request + // headers or response headers. + REQUEST, + RESPONSE + }; + virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) { return nullptr; } + // If pumpTo() / tryPumpFrom() is able to be optimized only if the other WebSocket is using + // certain extensions (e.g. compression settings), then this method returns what those extensions + // are. For example, matching extensions between standard WebSockets allows pumping to be + // implemented by pumping raw bytes between network connections, without reading individual frames. + // + // A null return value indicates that there is no preference. A non-null return value containing + // an empty string indicates a preference for no extensions to be applied. +}; + +using TlsStarterCallback = kj::Maybe(kj::StringPtr)>>; +struct HttpConnectSettings { + bool useTls = false; + // Requests to automatically establish a TLS session over the connection. The remote party + // will be expected to present a valid certificate matching the requested hostname. + kj::Maybe tlsStarter; + // This is an output parameter. It doesn't need to be set. But if it is set, then it may get + // filled with a callback function. It will get filled with `nullptr` if any of the following + // are true: + // + // * kj is not built with TLS support + // * the underlying HttpClient does not support the startTls mechanism + // * `useTls` has been set to `true` and so TLS has already been started + // + // The callback function itself can be called to initiate a TLS handshake on the connection in + // between write() operations. It is not allowed to initiate a TLS handshake while a write + // operation or a pump operation to the connection exists. Read operations are not subject to + // the same constraint, however: implementations are required to be able to handle TLS + // initiation while a read operation or pump operation from the connection exists. Once the + // promise returned from the callback is fulfilled, the connection has become a secure stream, + // and write operations are once again permitted. The StringPtr parameter to the callback, + // expectedServerHostname may be dropped after the function synchronously returns. + // + // The PausableReadAsyncIoStream class defined below can be used to ensure that read operations + // are not pending when the tlsStarter is invoked. + // + // This mechanism is required for certain protocols, more info can be found on + // https://en.wikipedia.org/wiki/Opportunistic_TLS. +}; + + +class PausableReadAsyncIoStream final: public kj::AsyncIoStream { + // A custom AsyncIoStream which can pause pending reads. This is used by startTls to pause a + // a read before TLS is initiated. + // + // TODO(cleanup): this class should be rewritten to use a CRTP mixin approach so that pumps + // can be optimised once startTls is invoked. + class PausableRead; +public: + PausableReadAsyncIoStream(kj::Own stream) + : inner(kj::mv(stream)), currentlyWriting(false), currentlyReading(false) {} + + _::Deferred> trackRead(); + + _::Deferred> trackWrite(); + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + + kj::Promise tryReadImpl(void* buffer, size_t minBytes, size_t maxBytes); + + kj::Maybe tryGetLength() override; + + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override; + + kj::Promise write(const void* buffer, size_t size) override; + + kj::Promise write(kj::ArrayPtr> pieces) override; + + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override; + + kj::Promise whenWriteDisconnected() override; + + void shutdownWrite() override; + + void abortRead() override; + + kj::Maybe getFd() const override; + + void pause(); + + void unpause(); + + bool getCurrentlyReading(); + + bool getCurrentlyWriting(); + + kj::Own takeStream(); + + void replaceStream(kj::Own stream); + + void reject(kj::Exception&& exc); + +private: + kj::Own inner; + kj::Maybe maybePausableRead; + bool currentlyWriting; + bool currentlyReading; }; class HttpClient { @@ -392,7 +779,7 @@ class HttpClient { kj::StringPtr statusText; const HttpHeaders* headers; kj::Own body; - // `statusText` and `headers` remain valid until `body` is dropped. + // `statusText` and `headers` remain valid until `body` is dropped or read from. }; struct Request { @@ -403,7 +790,7 @@ class HttpClient { // Content-Length: 0. kj::Promise response; - // Promise for the eventual respnose. + // Promise for the eventual response. }; virtual Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, @@ -424,18 +811,55 @@ class HttpClient { uint statusCode; kj::StringPtr statusText; const HttpHeaders* headers; - kj::OneOf, kj::Own> upstreamOrBody; - // `statusText` and `headers` remain valid until `upstreamOrBody` is dropped. + kj::OneOf, kj::Own> webSocketOrBody; + // `statusText` and `headers` remain valid until `webSocketOrBody` is dropped or read from. }; virtual kj::Promise openWebSocket( - kj::StringPtr url, const HttpHeaders& headers, kj::Own downstream); + kj::StringPtr url, const HttpHeaders& headers); // Tries to open a WebSocket. Default implementation calls send() and never returns a WebSocket. // - // `url` and `headers` are invalidated when the returned promise resolves. + // `url` and `headers` need only remain valid until `openWebSocket()` returns (they can be + // stack-allocated). + + struct ConnectRequest { + struct Status { + uint statusCode; + kj::String statusText; + kj::Own headers; + kj::Maybe> errorBody; + // If the connect request is rejected, the statusCode can be any HTTP status code + // outside the 200-299 range and errorBody *may* be specified if there is a rejection + // payload. + + // TODO(perf): Having Status own the statusText and headers is a bit unfortunate. + // Ideally we could have these be non-owned so that the headers object could just + // point directly into HttpOutputStream's buffer and not be copied. That's a bit + // more difficult to with CONNECT since the lifetimes of the buffers are a little + // different than with regular HTTP requests. It should still be possible but for + // now copying and owning the status text and headers is easier. + + Status(uint statusCode, + kj::String statusText, + kj::Own headers, + kj::Maybe> errorBody = nullptr) + : statusCode(statusCode), + statusText(kj::mv(statusText)), + headers(kj::mv(headers)), + errorBody(kj::mv(errorBody)) {} + }; + + kj::Promise status; + kj::Own connection; + }; - virtual kj::Promise> connect(kj::String host); - // Handles CONNECT requests. Only relevant for proxy clients. Default implementation throws - // UNIMPLEMENTED. + virtual ConnectRequest connect( + kj::StringPtr host, const HttpHeaders& headers, HttpConnectSettings settings); + // Handles CONNECT requests. + // + // `host` must specify both the host and port (e.g. "example.org:1234"). + // + // The `host` and `headers` need only remain valid until `connect()` returns (it can be + // stack-allocated). }; class HttpService { @@ -463,6 +887,33 @@ class HttpService { // // `statusText` and `headers` need only remain valid until send() returns (they can be // stack-allocated). + // + // `send()` may only be called a single time. Calling it a second time will cause an exception + // to be thrown. + + virtual kj::Own acceptWebSocket(const HttpHeaders& headers) = 0; + // If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send(). + // + // If the request is an invalid WebSocket request (e.g., it has an Upgrade: websocket header, + // but other WebSocket-related headers are invalid), `acceptWebSocket()` will throw an + // exception, and the HttpServer will return a 400 Bad Request response and close the + // connection. In this circumstance, the HttpServer will ignore any exceptions which propagate + // from the `HttpService::request()` promise. `HttpServerErrorHandler::handleApplicationError()` + // will not be invoked, and the HttpServer's listen task will be fulfilled normally. + // + // `acceptWebSocket()` may only be called a single time. Calling it a second time will cause an + // exception to be thrown. + + kj::Promise sendError(uint statusCode, kj::StringPtr statusText, + const HttpHeaders& headers); + kj::Promise sendError(uint statusCode, kj::StringPtr statusText, + const HttpHeaderTable& headerTable); + // Convenience wrapper around send() which sends a basic error. A generic error page specifying + // the error code is sent as the body. + // + // You must provide headers or a header table because downstream service wrappers may be + // expecting response headers built with a particular table so that they can insert additional + // headers. }; virtual kj::Promise request( @@ -475,40 +926,151 @@ class HttpService { // // `url` and `headers` are invalidated on the first read from `requestBody` or when the returned // promise resolves, whichever comes first. - - class WebSocketResponse: public Response { + // + // Request processing can be canceled by dropping the returned promise. HttpServer may do so if + // the client disconnects prematurely. + // + // The implementation of `request()` should usually not try to use `response` in any way in + // exception-handling code, because it is often not possible to tell whether `Response::send()` or + // `Response::acceptWebSocket()` has already been called. Instead, to generate error HTTP + // responses for the client, implement an HttpServerErrorHandler and pass it to the HttpServer via + // HttpServerSettings. If the `HttpService::request()` promise rejects and no response has yet + // been sent, `HttpServerErrorHandler::handleApplicationError()` will be passed a non-null + // `Maybe` parameter. + + class ConnectResponse { public: - kj::Own startWebSocket( - uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, - WebSocket& upstream); - // Begin the response. - // - // `statusText` and `headers` need only remain valid until startWebSocket() returns (they can - // be stack-allocated). + virtual void accept( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers) = 0; + // Signals acceptance of the CONNECT tunnel. + + virtual kj::Own reject( + uint statusCode, + kj::StringPtr statusText, + const HttpHeaders& headers, + kj::Maybe expectedBodySize = nullptr) = 0; + // Signals rejection of the CONNECT tunnel. }; - virtual kj::Promise openWebSocket( - kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response); - // Tries to open a WebSocket. Default implementation calls request() and never returns a - // WebSocket. + virtual kj::Promise connect(kj::StringPtr host, + const HttpHeaders& headers, + kj::AsyncIoStream& connection, + ConnectResponse& response, + HttpConnectSettings settings); + // Handles CONNECT requests. + // + // The `host` must include host and port. + // + // `host` and `headers` are invalidated when accept or reject is called on the ConnectResponse + // or when the returned promise resolves, whichever comes first. + // + // The connection is provided to support pipelining. Writes to the connection will be blocked + // until one of either accept() or reject() is called on tunnel. Reads from the connection are + // permitted at any time. + // + // Request processing can be canceled by dropping the returned promise. HttpServer may do so if + // the client disconnects prematurely. +}; + +class HttpClientErrorHandler { +public: + virtual HttpClient::Response handleProtocolError(HttpHeaders::ProtocolError protocolError); + // Override this function to customize error handling when the client receives an HTTP message + // that fails to parse. The default implementations throws an exception. // - // `url` and `headers` are invalidated when the returned promise resolves. + // There are two main use cases for overriding this: + // 1. `protocolError` contains the actual header content that failed to parse, giving you the + // opportunity to log it for debugging purposes. The default implementation throws away this + // content. + // 2. You could potentially convert protocol errors into HTTP error codes, e.g. 502 Bad Gateway. + // + // Note that `protocolError` may contain pointers into buffers that are no longer valid once + // this method returns; you will have to make copies if you want to keep them. + + virtual HttpClient::WebSocketResponse handleWebSocketProtocolError( + HttpHeaders::ProtocolError protocolError); + // Like handleProtocolError() but for WebSocket requests. The default implementation calls + // handleProtocolError() and converts the Response to WebSocketResponse. There is probably very + // little reason to override this. +}; + +struct HttpClientSettings { + kj::Duration idleTimeout = 5 * kj::SECONDS; + // For clients which automatically create new connections, any connection idle for at least this + // long will be closed. Set this to 0 to prevent connection reuse entirely. + + kj::Maybe entropySource = nullptr; + // Must be provided in order to use `openWebSocket`. If you don't need WebSockets, this can be + // omitted. The WebSocket protocol uses random values to avoid triggering flaws (including + // security flaws) in certain HTTP proxy software. Specifically, entropy is used to generate the + // `Sec-WebSocket-Key` header and to generate frame masks. If you know that there are no broken + // or vulnerable proxies between you and the server, you can provide a dummy entropy source that + // doesn't generate real entropy (e.g. returning the same value every time). Otherwise, you must + // provide a cryptographically-random entropy source. + + kj::Maybe errorHandler = nullptr; + // Customize how protocol errors are handled by the HttpClient. If null, HttpClientErrorHandler's + // default implementation will be used. + + enum WebSocketCompressionMode { + NO_COMPRESSION, + MANUAL_COMPRESSION, // Lets the application decide the compression configuration (if any). + AUTOMATIC_COMPRESSION, // Automatically includes the compression header in the WebSocket request. + }; + WebSocketCompressionMode webSocketCompressionMode = NO_COMPRESSION; - virtual kj::Promise> connect(kj::String host); - // Handles CONNECT requests. Only relevant for proxy services. Default implementation throws - // UNIMPLEMENTED. + kj::Maybe tlsContext; + // A reference to a TLS context that will be used when tlsStarter is invoked. +}; + +class WebSocketErrorHandler { +public: + virtual kj::Exception handleWebSocketProtocolError(WebSocket::ProtocolError protocolError); + // Handles low-level protocol errors in received WebSocket data. + // + // This is called when the WebSocket peer sends us bad data *after* a successful WebSocket + // upgrade, e.g. a continuation frame without a preceding start frame, a frame with an unknown + // opcode, or similar. + // + // You would override this method in order to customize the exception. You cannot prevent the + // exception from being thrown. }; -kj::Own newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Network& network, - kj::Maybe tlsNetwork = nullptr); -// Creates a proxy HttpClient that connects to hosts over the given network. +kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, + kj::Network& network, kj::Maybe tlsNetwork, + HttpClientSettings settings = HttpClientSettings()); +// Creates a proxy HttpClient that connects to hosts over the given network. The URL must always +// be an absolute URL; the host is parsed from the URL. This implementation will automatically +// add an appropriate Host header (and convert the URL to just a path) once it has connected. +// +// Note that if you wish to route traffic through an HTTP proxy server rather than connect to +// remote hosts directly, you should use the form of newHttpClient() that takes a NetworkAddress, +// and supply the proxy's address. // // `responseHeaderTable` is used when parsing HTTP responses. Requests can use any header table. // -// `tlsNetwork` is required to support HTTPS destination URLs. Otherwise, only HTTP URLs can be +// `tlsNetwork` is required to support HTTPS destination URLs. If null, only HTTP URLs can be // fetched. -kj::Own newHttpClient(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream); +kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, + kj::NetworkAddress& addr, + HttpClientSettings settings = HttpClientSettings()); +// Creates an HttpClient that always connects to the given address no matter what URL is requested. +// The client will open and close connections as needed. It will attempt to reuse connections for +// multiple requests but will not send a new request before the previous response on the same +// connection has completed, as doing so can result in head-of-line blocking issues. The client may +// be used as a proxy client or a host client depending on whether the peer is operating as +// a proxy. (Hint: This is the best kind of client to use when routing traffic through an HTTP +// proxy. `addr` should be the address of the proxy, and the proxy itself will resolve remote hosts +// based on the URLs passed to it.) +// +// `responseHeaderTable` is used when parsing HTTP responses. Requests can use any header table. + +kj::Own newHttpClient(const HttpHeaderTable& responseHeaderTable, + kj::AsyncIoStream& stream, + HttpClientSettings settings = HttpClientSettings()); // Creates an HttpClient that speaks over the given pre-established connection. The client may // be used as a proxy client or a host client depending on whether the peer is operating as // a proxy. @@ -519,10 +1081,63 @@ kj::Own newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Asyn // subsequent requests will fail. If a response takes a long time, it blocks subsequent responses. // If a WebSocket is opened successfully, all subsequent requests fail. +kj::Own newConcurrencyLimitingHttpClient( + HttpClient& inner, uint maxConcurrentRequests, + kj::Function countChangedCallback); +// Creates an HttpClient that is limited to a maximum number of concurrent requests. Additional +// requests are queued, to be opened only after an open request completes. `countChangedCallback` +// is called when a new connection is opened or enqueued and when an open connection is closed, +// passing the number of open and pending connections. + kj::Own newHttpClient(HttpService& service); kj::Own newHttpService(HttpClient& client); // Adapts an HttpClient to an HttpService and vice versa. +kj::Own newHttpInputStream( + kj::AsyncInputStream& input, const HttpHeaderTable& headerTable); +// Create an HttpInputStream on top of the given stream. Normally applications would not call this +// directly, but it can be useful for implementing protocols that aren't quite HTTP but use similar +// message delimiting. +// +// The HttpInputStream implementation does read-ahead buffering on `input`. Therefore, when the +// HttpInputStream is destroyed, some data read from `input` may be lost, so it's not possible to +// continue reading from `input` in a reliable way. + +kj::Own newWebSocket(kj::Own stream, + kj::Maybe maskEntropySource, + kj::Maybe compressionConfig = nullptr, + kj::Maybe errorHandler = nullptr); +// Create a new WebSocket on top of the given stream. It is assumed that the HTTP -> WebSocket +// upgrade handshake has already occurred (or is not needed), and messages can immediately be +// sent and received on the stream. Normally applications would not call this directly. +// +// `maskEntropySource` is used to generate cryptographically-random frame masks. If null, outgoing +// frames will not be masked. Servers are required NOT to mask their outgoing frames, but clients +// ARE required to do so. So, on the client side, you MUST specify an entropy source. The mask +// must be crytographically random if the data being sent on the WebSocket may be malicious. The +// purpose of the mask is to prevent badly-written HTTP proxies from interpreting "things that look +// like HTTP requests" in a message as being actual HTTP requests, which could result in cache +// poisoning. See RFC6455 section 10.3. +// +// `compressionConfig` is an optional argument that allows us to specify how the WebSocket should +// compress and decompress messages. The configuration is determined by the +// `Sec-WebSocket-Extensions` header during WebSocket negotiation. +// +// `errorHandler` is an optional argument that lets callers throw custom exceptions for WebSocket +// protocol errors. + +struct WebSocketPipe { + kj::Own ends[2]; +}; + +WebSocketPipe newWebSocketPipe(); +// Create a WebSocket pipe. Messages written to one end of the pipe will be readable from the other +// end. No buffering occurs -- a message send does not complete until a corresponding receive +// accepts the message. + +class HttpServerErrorHandler; +class HttpServerCallbacks; + struct HttpServerSettings { kj::Duration headerTimeout = 15 * kj::SECONDS; // After initial connection open, or after receiving the first byte of a pipelined request, @@ -531,20 +1146,106 @@ struct HttpServerSettings { kj::Duration pipelineTimeout = 5 * kj::SECONDS; // After one request/response completes, we'll wait up to this long for a pipelined request to // arrive. + + kj::Duration canceledUploadGracePeriod = 1 * kj::SECONDS; + size_t canceledUploadGraceBytes = 65536; + // If the HttpService sends a response and returns without having read the entire request body, + // then we have to decide whether to close the connection or wait for the client to finish the + // request so that it can pipeline the next one. We'll give them a grace period defined by the + // above two values -- if they hit either one, we'll close the socket, but if the request + // completes, we'll let the connection stay open to handle more requests. + + kj::Maybe errorHandler = nullptr; + // Customize how client protocol errors and service application exceptions are handled by the + // HttpServer. If null, HttpServerErrorHandler's default implementation will be used. + + kj::Maybe callbacks = nullptr; + // Additional optional callbacks used to control some server behavior. + + kj::Maybe webSocketErrorHandler = nullptr; + // Customize exceptions thrown on WebSocket protocol errors. + + enum WebSocketCompressionMode { + NO_COMPRESSION, + MANUAL_COMPRESSION, // Gives the application more control when considering whether to compress. + AUTOMATIC_COMPRESSION, // Will perform compression parameter negotiation if client requests it. + }; + WebSocketCompressionMode webSocketCompressionMode = NO_COMPRESSION; +}; + +class HttpServerErrorHandler { +public: + virtual kj::Promise handleClientProtocolError( + HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response); + virtual kj::Promise handleApplicationError( + kj::Exception exception, kj::Maybe response); + virtual kj::Promise handleNoResponse(kj::HttpService::Response& response); + // Override these functions to customize error handling during the request/response cycle. + // + // Client protocol errors arise when the server receives an HTTP message that fails to parse. As + // such, HttpService::request() will not have been called yet, and the handler is always + // guaranteed an opportunity to send a response. The default implementation of + // handleClientProtocolError() replies with a 400 Bad Request response. + // + // Application errors arise when HttpService::request() throws an exception. The default + // implementation of handleApplicationError() maps the following exception types to HTTP statuses, + // and generates bodies from the stringified exceptions: + // + // - OVERLOADED: 503 Service Unavailable + // - UNIMPLEMENTED: 501 Not Implemented + // - DISCONNECTED: (no response) + // - FAILED: 500 Internal Server Error + // + // No-response errors occur when HttpService::request() allows its promise to settle before + // sending a response. The default implementation of handleNoResponse() replies with a 500 + // Internal Server Error response. + // + // Unlike `HttpService::request()`, when calling `response.send()` in the context of one of these + // functions, a "Connection: close" header will be added, and the connection will be closed. + // + // Also unlike `HttpService::request()`, it is okay to return kj::READY_NOW without calling + // `response.send()`. In this case, no response will be sent, and the connection will be closed. + + virtual void handleListenLoopException(kj::Exception&& exception); + // Override this function to customize error handling for individual connections in the + // `listenHttp()` overload which accepts a ConnectionReceiver reference. + // + // The default handler uses KJ_LOG() to log the exception as an error. }; -class HttpServer: private kj::TaskSet::ErrorHandler { +class HttpServerCallbacks { +public: + virtual bool shouldClose() { return false; } + // Whenever the HttpServer begins response headers, it will check `shouldClose()` to decide + // whether to send a `Connection: close` header and close the connection. + // + // This can be useful e.g. if the server has too many connections open and wants to shed some + // of them. Note that to implement graceful shutdown of a server, you should use + // `HttpServer::drain()` instead. +}; + +class HttpServer final: private kj::TaskSet::ErrorHandler { // Class which listens for requests on ports or connections and sends them to an HttpService. public: typedef HttpServerSettings Settings; + typedef kj::Function(kj::AsyncIoStream&)> HttpServiceFactory; + class SuspendableRequest; + typedef kj::Function>(SuspendableRequest&)> + SuspendableHttpServiceFactory; - HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service, + HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, HttpService& service, Settings settings = Settings()); // Set up an HttpServer that directs incoming connections to the given service. The service // may be a host service or a proxy service depending on whether you are intending to implement // an HTTP server or an HTTP proxy. + HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, + HttpServiceFactory serviceFactory, Settings settings = Settings()); + // Like the other constructor, but allows a new HttpService object to be used for each + // connection, based on the connection object. This is particularly useful for capturing the + // client's IP address and injecting it as a header. + kj::Promise drain(); // Stop accepting new connections or new requests on existing connections. Finish any requests // that are already executing, then close the connections. Returns once no more requests are @@ -561,15 +1262,73 @@ class HttpServer: private kj::TaskSet::ErrorHandler { // Reads HTTP requests from the given connection and directs them to the handler. A successful // completion of the promise indicates that all requests received on the connection resulted in // a complete response, and the client closed the connection gracefully or drain() was called. - // The promise throws if an unparseable request is received or if some I/O error occurs. Dropping + // The promise throws if an unparsable request is received or if some I/O error occurs. Dropping // the returned promise will cancel all I/O on the connection and cancel any in-flight requests. + kj::Promise listenHttpCleanDrain(kj::AsyncIoStream& connection); + // Like listenHttp(), but allows you to potentially drain the server without closing connections. + // The returned promise resolves to `true` if the connection has been left in a state where a + // new HttpServer could potentially accept further requests from it. If `false`, then the + // connection is either in an inconsistent state or already completed a closing handshake; the + // caller should close it without any further reads/writes. Note this only ever returns `true` + // if you called `drain()` -- otherwise this server would keep handling the connection. + + class SuspendedRequest { + // SuspendedRequest is a representation of a request immediately after parsing the method line and + // headers. You can obtain one of these by suspending a request by calling + // SuspendableRequest::suspend(), then later resume the request with another call to + // listenHttpCleanDrain(). + + public: + // Nothing, this is an opaque type. + + private: + SuspendedRequest(kj::Array, kj::ArrayPtr, kj::OneOf, kj::StringPtr, HttpHeaders); + + kj::Array buffer; + // A buffer containing at least the request's method, URL, and headers, and possibly content + // thereafter. + + kj::ArrayPtr leftover; + // Pointer to the end of the request headers. If this has a non-zero length, then our buffer + // contains additional content, presumably the head of the request body. + + kj::OneOf method; + kj::StringPtr url; + HttpHeaders headers; + // Parsed request front matter. `url` and `headers` both store pointers into `buffer`. + + friend class HttpServer; + }; + + kj::Promise listenHttpCleanDrain(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest = nullptr); + // Like listenHttpCleanDrain(), but allows you to suspend requests. + // + // When this overload is in use, the HttpServer's default HttpService or HttpServiceFactory is not + // used. Instead, the HttpServer reads the request method line and headers, then calls `factory` + // with a SuspendableRequest representing the request parsed so far. The factory may then return + // a kj::Own for that specific request, or it may call SuspendableRequest::suspend() + // and return nullptr. (It is an error for the factory to return nullptr without also calling + // suspend(); this will result in a rejected listenHttpCleanDrain() promise.) + // + // If the factory chooses to suspend, the listenHttpCleanDrain() promise is resolved with false + // at the earliest opportunity. + // + // SuspendableRequest::suspend() returns a SuspendedRequest. You can resume this request later by + // calling this same listenHttpCleanDrain() overload with the original connection stream, and the + // SuspendedRequest in question. + // + // This overload of listenHttpCleanDrain() implements draining, as documented above. Note that the + // returned promise will resolve to false (not clean) if a request is suspended. + private: class Connection; kj::Timer& timer; - HttpHeaderTable& requestHeaderTable; - HttpService& service; + const HttpHeaderTable& requestHeaderTable; + kj::OneOf service; Settings settings; bool draining = false; @@ -581,28 +1340,73 @@ class HttpServer: private kj::TaskSet::ErrorHandler { kj::TaskSet tasks; - HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service, + HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, + kj::OneOf service, Settings settings, kj::PromiseFulfillerPair paf); kj::Promise listenLoop(kj::ConnectionReceiver& port); void taskFailed(kj::Exception&& exception) override; + + kj::Promise listenHttpImpl(kj::AsyncIoStream& connection, bool wantCleanDrain); + kj::Promise listenHttpImpl(kj::AsyncIoStream& connection, + SuspendableHttpServiceFactory factory, + kj::Maybe suspendedRequest, + bool wantCleanDrain); +}; + +class HttpServer::SuspendableRequest { + // Interface passed to the SuspendableHttpServiceFactory parameter of listenHttpCleanDrain(). + +public: + kj::OneOf method; + kj::StringPtr url; + const HttpHeaders& headers; + // Parsed request front matter, so the implementer can decide whether to suspend the request. + + SuspendedRequest suspend(); + // Signal to the HttpServer that the current request loop should be exited. Return a + // SuspendedRequest, containing HTTP method, URL, and headers access, along with the actual header + // buffer. The request can be later resumed with a call to listenHttpCleanDrain() using the same + // connection. + +private: + explicit SuspendableRequest( + Connection& connection, kj::OneOf method, kj::StringPtr url, const HttpHeaders& headers) + : method(method), url(url), headers(headers), connection(connection) {} + KJ_DISALLOW_COPY_AND_MOVE(SuspendableRequest); + + Connection& connection; + + friend class Connection; }; // ======================================================================================= // inline implementation -inline void HttpHeaderId::requireFrom(HttpHeaderTable& table) const { +inline void HttpHeaderId::requireFrom(const HttpHeaderTable& table) const { KJ_IREQUIRE(this->table == nullptr || this->table == &table, "the provided HttpHeaderId is from the wrong HttpHeaderTable"); } -inline kj::Own HttpHeaderTable::Builder::build() { return kj::mv(table); } +inline kj::Own HttpHeaderTable::Builder::build() { + table->buildStatus = BuildStatus::FINISHED; + return kj::mv(table); +} inline HttpHeaderTable& HttpHeaderTable::Builder::getFutureTable() { return *table; } -inline uint HttpHeaderTable::idCount() { return namesById.size(); } +inline uint HttpHeaderTable::idCount() const { return namesById.size(); } +inline bool HttpHeaderTable::isReady() const { + switch (buildStatus) { + case BuildStatus::UNSTARTED: return true; + case BuildStatus::BUILDING: return false; + case BuildStatus::FINISHED: return true; + } + + KJ_UNREACHABLE; +} -inline kj::StringPtr HttpHeaderTable::idToString(HttpHeaderId id) { +inline kj::StringPtr HttpHeaderTable::idToString(HttpHeaderId id) const { id.requireFrom(*this); return namesById[id.id]; } @@ -631,6 +1435,70 @@ inline void HttpHeaders::forEach(Func&& func) const { } } +template +inline void HttpHeaders::forEach(Func1&& func1, Func2&& func2) const { + for (auto i: kj::indices(indexedHeaders)) { + if (indexedHeaders[i] != nullptr) { + func1(HttpHeaderId(table, i), indexedHeaders[i]); + } + } + + for (auto& header: unindexedHeaders) { + func2(header.name, header.value); + } +} + +// ======================================================================================= +namespace _ { // private implementation details for WebSocket compression + +kj::ArrayPtr splitNext(kj::ArrayPtr& cursor, char delimiter); + +void stripLeadingAndTrailingSpace(ArrayPtr& str); + +kj::Vector> splitParts(kj::ArrayPtr input, char delim); + +struct KeyMaybeVal { + ArrayPtr key; + kj::Maybe> val; +}; + +kj::Array toKeysAndVals(const kj::ArrayPtr>& params); + +struct UnverifiedConfig { + // An intermediate representation of the final `CompressionParameters` struct; used during parsing. + // We use it to ensure the structure of an offer is generally correct, see + // `populateUnverifiedConfig()` for details. + bool clientNoContextTakeover = false; + bool serverNoContextTakeover = false; + kj::Maybe> clientMaxWindowBits = nullptr; + kj::Maybe> serverMaxWindowBits = nullptr; +}; + +kj::Maybe populateUnverifiedConfig(kj::Array& params); + +kj::Maybe validateCompressionConfig(UnverifiedConfig&& config, + bool isAgreement); + +kj::Vector findValidExtensionOffers(StringPtr offers); + +kj::String generateExtensionRequest(const ArrayPtr& extensions); + +kj::Maybe tryParseExtensionOffers(StringPtr offers); + +kj::Maybe tryParseAllExtensionOffers(StringPtr offers, + CompressionParameters manualConfig); + +kj::Maybe compareClientAndServerConfigs(CompressionParameters requestConfig, + CompressionParameters manualConfig); + +kj::String generateExtensionResponse(const CompressionParameters& parameters); + +kj::OneOf tryParseExtensionAgreement( + const Maybe& clientOffer, + StringPtr agreedParameters); + +}; // namespace _ (private) + } // namespace kj -#endif // KJ_COMPAT_HTTP_H_ +KJ_END_HEADER diff --git a/c++/src/kj/compat/make-test-certs.sh b/c++/src/kj/compat/make-test-certs.sh new file mode 100755 index 0000000000..33725affb3 --- /dev/null +++ b/c++/src/kj/compat/make-test-certs.sh @@ -0,0 +1,162 @@ +#! /bin/bash +# Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +# Licensed under the MIT License: +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +# This script generates the test keys and certificates used in tls-test.c++. + +set -euxo pipefail + +mkdir -p tmp/test-certs +cd tmp/test-certs + +# Clean up from previous runs. +rm -rf demoCA *.key *.csr *.crt + +# Function to fake out OpenSSL CA configuration. Pass base name of files as parameter. +setup_ca_dir() { + rm -rf demoCA + mkdir -p demoCA/private demoCA/newcerts + ln -s ../../$1.key demoCA/private/cakey.pem + ln -s ../$1.crt demoCA/cacert.pem + touch demoCA/index.txt + echo 1000 > demoCA/serial +} + +# Create CA key and root cert +openssl genrsa -out ca.key 4096 +openssl req -key ca.key -new -x509 -days 36500 -sha256 -extensions v3_ca -out ca.crt << EOF +US +California +Palo Alto +Sandstorm.io +Testing Department +ca.example.com +garply@sandstorm.io +EOF +echo + +# Create intermediate certificate and CSR. +openssl genrsa -out int.key 4096 +openssl req -new -sha256 -key int.key -out int.csr << EOF +US +California +Palo Alto +Sandstorm.io +Testing Department +int-ca.example.com +garply@sandstorm.io + + +EOF +echo + +# Sign the intermediate cert with the CA key. +setup_ca_dir ca +openssl ca -extensions v3_ca -days 36500 -notext -md sha256 -in int.csr -out int.crt << EOF +y +y +EOF +cat ca.crt int.crt > ca-chain.crt + +# Create host key and CSR +openssl genrsa -out example.key 4096 +openssl req -new -sha256 -key example.key -out example.csr << EOF +US +California +Palo Alto +Sandstorm.io +Testing Department +example.com +garply@sandstorm.io + + +EOF +echo + +# Sign valid host certificate with intermediate CA. +setup_ca_dir int +openssl ca -extensions v3_ca -days 36524 -notext -md sha256 -in example.csr -out valid.crt << EOF +y +y +EOF + +# Sign expired host certificate with intermediate CA. +setup_ca_dir int +openssl ca -extensions v3_ca -startdate 160101000000Z -enddate 160101000000Z -notext -md sha256 -in example.csr -out expired.crt << EOF +y +y +EOF + +# Create alternate host key and CSR +openssl genrsa -out example2.key 4096 +openssl req -new -sha256 -key example2.key -out example2.csr << EOF +US +California +Palo Alto +Sandstorm.io +Testing Department +example.net +garply@sandstorm.io + + +EOF +echo + +# Sign valid host certificate with intermediate CA. +setup_ca_dir int +openssl ca -extensions v3_ca -days 36524 -notext -md sha256 -in example2.csr -out valid2.crt << EOF +y +y +EOF + +# Create self-signed host certificate. +openssl req -key example.key -new -x509 -days 36524 -sha256 -out self.crt << EOF +US +California +Palo Alto +Sandstorm.io +Testing Department +example.com +garply@sandstorm.io +EOF +echo + +# Cleanup +rm -rf demoCA + +# Output code. +write_constant() { + echo "static constexpr char $1[] =" + sed -e 's/^.*$/ "\0\\n"/g;s/--END .*$/\0;/g' $2 + echo +} + +echo "Writing code to: tmp/test-certs/test-keys.h" + +exec 1> test-keys.h +write_constant CA_CERT ca.crt +write_constant INTERMEDIATE_CERT int.crt +write_constant HOST_KEY example.key +write_constant VALID_CERT valid.crt +write_constant HOST_KEY2 example2.key +write_constant VALID_CERT2 valid2.crt +write_constant EXPIRED_CERT expired.crt +write_constant SELF_SIGNED_CERT self.crt diff --git a/c++/src/kj/compat/readiness-io-test.c++ b/c++/src/kj/compat/readiness-io-test.c++ new file mode 100644 index 0000000000..4db1d7cb46 --- /dev/null +++ b/c++/src/kj/compat/readiness-io-test.c++ @@ -0,0 +1,307 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "readiness-io.h" +#include +#include + +namespace kj { +namespace { + +KJ_TEST("readiness IO: write small") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + char buf[4]; + auto readPromise = pipe.in->read(buf, 3, 4); + + ReadyOutputStreamWrapper out(*pipe.out); + KJ_ASSERT(KJ_ASSERT_NONNULL(out.write(kj::StringPtr("foo").asBytes())) == 3); + + KJ_ASSERT(readPromise.wait(io.waitScope) == 3); + buf[3] = '\0'; + KJ_ASSERT(kj::StringPtr(buf) == "foo"); +} + +KJ_TEST("readiness IO: write many odd") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + ReadyOutputStreamWrapper out(*pipe.out); + + size_t totalWritten = 0; + for (;;) { + KJ_IF_MAYBE(n, out.write(kj::StringPtr("bar").asBytes())) { + totalWritten += *n; + if (*n < 3) { + break; + } + } else { + KJ_FAIL_ASSERT("pipe buffer is divisible by 3? really?"); + } + } + + auto buf = kj::heapArray(totalWritten + 1); + size_t n = pipe.in->read(buf.begin(), totalWritten, buf.size()).wait(io.waitScope); + KJ_ASSERT(n == totalWritten); + for (size_t i = 0; i < totalWritten; i++) { + KJ_ASSERT(buf[i] == "bar"[i%3]); + } +} + +KJ_TEST("readiness IO: write even") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + ReadyOutputStreamWrapper out(*pipe.out); + + size_t totalWritten = 0; + for (;;) { + KJ_IF_MAYBE(n, out.write(kj::StringPtr("ba").asBytes())) { + totalWritten += *n; + if (*n < 2) { + KJ_FAIL_ASSERT("pipe buffer is not divisible by 2? really?"); + } + } else { + break; + } + } + + auto buf = kj::heapArray(totalWritten + 1); + size_t n = pipe.in->read(buf.begin(), totalWritten, buf.size()).wait(io.waitScope); + KJ_ASSERT(n == totalWritten); + for (size_t i = 0; i < totalWritten; i++) { + KJ_ASSERT(buf[i] == "ba"[i%2]); + } +} + +KJ_TEST("readiness IO: write while corked") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + char buf[7]; + auto readPromise = pipe.in->read(buf, 3, 7); + + ReadyOutputStreamWrapper out(*pipe.out); + auto cork = out.cork(); + KJ_ASSERT(KJ_ASSERT_NONNULL(out.write(kj::StringPtr("foo").asBytes())) == 3); + + // Data hasn't been written yet. + KJ_ASSERT(!readPromise.poll(io.waitScope)); + + // Write some more, and observe it still isn't flushed out yet. + KJ_ASSERT(KJ_ASSERT_NONNULL(out.write(kj::StringPtr("bar").asBytes())) == 3); + KJ_ASSERT(!readPromise.poll(io.waitScope)); + + // After reenabling pumping, the full read should succeed. + // We start this block with `if (true) {` instead of just `{` to avoid g++-8 compiler warnings + // telling us that this block isn't treated as part of KJ_ASSERT's internal `for` loop. + if (true) { + auto tmp = kj::mv(cork); + } + KJ_ASSERT(readPromise.wait(io.waitScope) == 6); + buf[6] = '\0'; + KJ_ASSERT(kj::StringPtr(buf) == "foobar"); +} + +KJ_TEST("readiness IO: write many odd while corked") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + // The even/odd tests should work just as before even with automatic pumping + // corked, since we should still pump when the buffer fills up. + ReadyOutputStreamWrapper out(*pipe.out); + auto cork = out.cork(); + + size_t totalWritten = 0; + for (;;) { + KJ_IF_MAYBE(n, out.write(kj::StringPtr("bar").asBytes())) { + totalWritten += *n; + if (*n < 3) { + break; + } + } else { + KJ_FAIL_ASSERT("pipe buffer is divisible by 3? really?"); + } + } + + auto buf = kj::heapArray(totalWritten + 1); + size_t n = pipe.in->read(buf.begin(), totalWritten, buf.size()).wait(io.waitScope); + KJ_ASSERT(n == totalWritten); + for (size_t i = 0; i < totalWritten; i++) { + KJ_ASSERT(buf[i] == "bar"[i%3]); + } + + // Eager pumping should still be corked. + KJ_ASSERT(KJ_ASSERT_NONNULL(out.write(kj::StringPtr("bar").asBytes())) == 3); + auto readPromise = pipe.in->read(buf.begin(), 3, buf.size()); + KJ_ASSERT(!readPromise.poll(io.waitScope)); +} + +KJ_TEST("readiness IO: write many even while corked") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + ReadyOutputStreamWrapper out(*pipe.out); + auto cork = out.cork(); + + size_t totalWritten = 0; + for (;;) { + KJ_IF_MAYBE(n, out.write(kj::StringPtr("ba").asBytes())) { + totalWritten += *n; + if (*n < 2) { + KJ_FAIL_ASSERT("pipe buffer is not divisible by 2? really?"); + } + } else { + break; + } + } + + auto buf = kj::heapArray(totalWritten + 1); + size_t n = pipe.in->read(buf.begin(), totalWritten, buf.size()).wait(io.waitScope); + KJ_ASSERT(n == totalWritten); + for (size_t i = 0; i < totalWritten; i++) { + KJ_ASSERT(buf[i] == "ba"[i%2]); + } + + // Eager pumping should still be corked. + KJ_ASSERT(KJ_ASSERT_NONNULL(out.write(kj::StringPtr("ba").asBytes())) == 2); + auto readPromise = pipe.in->read(buf.begin(), 2, buf.size()); + KJ_ASSERT(!readPromise.poll(io.waitScope)); +} + +KJ_TEST("readiness IO: read small") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + ReadyInputStreamWrapper in(*pipe.in); + char buf[4]; + KJ_ASSERT(in.read(kj::ArrayPtr(buf).asBytes()) == nullptr); + + pipe.out->write("foo", 3).wait(io.waitScope); + + in.whenReady().wait(io.waitScope); + KJ_ASSERT(KJ_ASSERT_NONNULL(in.read(kj::ArrayPtr(buf).asBytes())) == 3); + buf[3] = '\0'; + KJ_ASSERT(kj::StringPtr(buf) == "foo"); + + pipe.out = nullptr; + + kj::Maybe finalRead; + for (;;) { + finalRead = in.read(kj::ArrayPtr(buf).asBytes()); + KJ_IF_MAYBE(n, finalRead) { + KJ_ASSERT(*n == 0); + break; + } else { + in.whenReady().wait(io.waitScope); + } + } +} + +KJ_TEST("readiness IO: read many odd") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + char dummy[8192]; + for (auto i: kj::indices(dummy)) { + dummy[i] = "bar"[i%3]; + } + auto writeTask = pipe.out->write(dummy, sizeof(dummy)).then([&]() { + // shutdown + pipe.out = nullptr; + }).eagerlyEvaluate(nullptr); + + ReadyInputStreamWrapper in(*pipe.in); + char buf[3]; + + for (;;) { + auto result = in.read(kj::ArrayPtr(buf).asBytes()); + KJ_IF_MAYBE(n, result) { + for (size_t i = 0; i < *n; i++) { + KJ_ASSERT(buf[i] == "bar"[i]); + } + KJ_ASSERT(*n != 0, "ended at wrong spot"); + if (*n < 3) { + break; + } + } else { + in.whenReady().wait(io.waitScope); + } + } + + kj::Maybe finalRead; + for (;;) { + finalRead = in.read(kj::ArrayPtr(buf).asBytes()); + KJ_IF_MAYBE(n, finalRead) { + KJ_ASSERT(*n == 0); + break; + } else { + in.whenReady().wait(io.waitScope); + } + } +} + +KJ_TEST("readiness IO: read many even") { + auto io = setupAsyncIo(); + auto pipe = io.provider->newOneWayPipe(); + + char dummy[8192]; + for (auto i: kj::indices(dummy)) { + dummy[i] = "ba"[i%2]; + } + auto writeTask = pipe.out->write(dummy, sizeof(dummy)).then([&]() { + // shutdown + pipe.out = nullptr; + }).eagerlyEvaluate(nullptr); + + ReadyInputStreamWrapper in(*pipe.in); + char buf[2]; + + for (;;) { + auto result = in.read(kj::ArrayPtr(buf).asBytes()); + KJ_IF_MAYBE(n, result) { + for (size_t i = 0; i < *n; i++) { + KJ_ASSERT(buf[i] == "ba"[i]); + } + if (*n == 0) { + break; + } + KJ_ASSERT(*n == 2, "ended at wrong spot"); + } else { + in.whenReady().wait(io.waitScope); + } + } + + kj::Maybe finalRead; + for (;;) { + finalRead = in.read(kj::ArrayPtr(buf).asBytes()); + KJ_IF_MAYBE(n, finalRead) { + KJ_ASSERT(*n == 0); + break; + } else { + in.whenReady().wait(io.waitScope); + } + } +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/compat/readiness-io.c++ b/c++/src/kj/compat/readiness-io.c++ new file mode 100644 index 0000000000..ca85feba35 --- /dev/null +++ b/c++/src/kj/compat/readiness-io.c++ @@ -0,0 +1,159 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "readiness-io.h" + +namespace kj { + +static size_t copyInto(kj::ArrayPtr dst, kj::ArrayPtr& src) { + size_t n = kj::min(dst.size(), src.size()); + memcpy(dst.begin(), src.begin(), n); + src = src.slice(n, src.size()); + return n; +} + +// ======================================================================================= + +ReadyInputStreamWrapper::ReadyInputStreamWrapper(AsyncInputStream& input): input(input) {} +ReadyInputStreamWrapper::~ReadyInputStreamWrapper() noexcept(false) {} + +kj::Maybe ReadyInputStreamWrapper::read(kj::ArrayPtr dst) { + if (eof || dst.size() == 0) return size_t(0); + + if (content.size() == 0) { + // No data available. Try to read more. + if (!isPumping) { + isPumping = true; + pumpTask = kj::evalNow([&]() { + return input.tryRead(buffer, 1, sizeof(buffer)).then([this](size_t n) { + if (n == 0) { + eof = true; + } else { + content = kj::arrayPtr(buffer, n); + } + isPumping = false; + }); + }).fork(); + } + + return nullptr; + } + + return copyInto(dst, content); +} + +kj::Promise ReadyInputStreamWrapper::whenReady() { + return pumpTask.addBranch(); +} + +// ======================================================================================= + +ReadyOutputStreamWrapper::ReadyOutputStreamWrapper(AsyncOutputStream& output): output(output) {} +ReadyOutputStreamWrapper::~ReadyOutputStreamWrapper() noexcept(false) {} + +kj::Maybe ReadyOutputStreamWrapper::write(kj::ArrayPtr data) { + if (data.size() == 0) return size_t(0); + + if (filled == sizeof(buffer)) { + // No space. + return nullptr; + } + + uint end = start + filled; + size_t result = 0; + if (end < sizeof(buffer)) { + // The filled part of the buffer is somewhere in the middle. + + // Copy into space after filled space. + result += copyInto(kj::arrayPtr(buffer + end, buffer + sizeof(buffer)), data); + + // Copy into space before filled space. + result += copyInto(kj::arrayPtr(buffer, buffer + start), data); + } else { + // Fill currently loops, so we only have one segment of empty space to copy into. + + // Copy into the space between the fill's end and the fill's start. + result += copyInto(kj::arrayPtr(buffer + end % sizeof(buffer), buffer + start), data); + } + + filled += result; + + if (!isPumping && (!corked || filled == sizeof(buffer))) { + isPumping = true; + pumpTask = kj::evalNow([&]() { + return pump(); + }).fork(); + } + + return result; +} + +kj::Promise ReadyOutputStreamWrapper::whenReady() { + return pumpTask.addBranch(); +} + +ReadyOutputStreamWrapper::Cork ReadyOutputStreamWrapper::cork() { + corked = true; + return Cork(*this); +} + +void ReadyOutputStreamWrapper::uncork() { + corked = false; + if (!isPumping && filled > 0) { + isPumping = true; + pumpTask = kj::evalNow([&]() { + return pump(); + }).fork(); + } +} + +kj::Promise ReadyOutputStreamWrapper::pump() { + uint oldFilled = filled; + uint end = start + filled; + + kj::Promise promise = nullptr; + if (end <= sizeof(buffer)) { + promise = output.write(buffer + start, filled); + } else { + end = end % sizeof(buffer); + segments[0] = kj::arrayPtr(buffer + start, buffer + sizeof(buffer)); + segments[1] = kj::arrayPtr(buffer, buffer + end); + promise = output.write(segments); + } + + return promise.then([this,oldFilled,end]() -> kj::Promise { + filled -= oldFilled; + start = end; + + if (filled > 0) { + return pump(); + } else { + isPumping = false; + // As a small optimization, reset to the start of the buffer when it's empty so we can provide + // the underlying layer just one contiguous chunk of memory instead of two when possible. + start = 0; + return kj::READY_NOW; + } + }); +} + +} // namespace kj + diff --git a/c++/src/kj/compat/readiness-io.h b/c++/src/kj/compat/readiness-io.h new file mode 100644 index 0000000000..2ed9468416 --- /dev/null +++ b/c++/src/kj/compat/readiness-io.h @@ -0,0 +1,131 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include + +KJ_BEGIN_HEADER + +namespace kj { + +class ReadyInputStreamWrapper { + // Provides readiness-based Async I/O as a wrapper around KJ's standard completion-based API, for + // compatibility with libraries that use readiness-based abstractions (e.g. OpenSSL). + // + // Unfortunately this requires buffering, so is not very efficient. + +public: + ReadyInputStreamWrapper(AsyncInputStream& input); + ~ReadyInputStreamWrapper() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(ReadyInputStreamWrapper); + + kj::Maybe read(kj::ArrayPtr dst); + // Reads bytes into `dst`, returning the number of bytes read. Returns zero only at EOF. Returns + // nullptr if not ready. + + kj::Promise whenReady(); + // Returns a promise that resolves when read() will return non-null. + + bool isAtEnd() { return eof; } + // Returns true if read() would return zero. + +private: + AsyncInputStream& input; + kj::ForkedPromise pumpTask = nullptr; + bool isPumping = false; + bool eof = false; + + kj::ArrayPtr content = nullptr; // Points to currently-valid part of `buffer`. + byte buffer[8192]; +}; + +class ReadyOutputStreamWrapper { + // Provides readiness-based Async I/O as a wrapper around KJ's standard completion-based API, for + // compatibility with libraries that use readiness-based abstractions (e.g. OpenSSL). + // + // Unfortunately this requires buffering, so is not very efficient. + +public: + ReadyOutputStreamWrapper(AsyncOutputStream& output); + ~ReadyOutputStreamWrapper() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(ReadyOutputStreamWrapper); + + kj::Maybe write(kj::ArrayPtr src); + // Writes bytes from `src`, returning the number of bytes written. Never returns zero for + // a non-empty `src`. Returns nullptr if not ready. + + kj::Promise whenReady(); + // Returns a promise that resolves when write() will return non-null. + + class Cork; + // An object that, when destructed, will uncork its parent stream. + + Cork cork(); + // After calling, data won't be pumped until either the internal buffer fills up or the returned + // object is destructed. Use this if you know multiple small write() calls will be happening in + // the near future and want to flush them all at once. + // Once the returned object is destructed, behavior goes back to normal. The returned object + // must be destructed before the ReadyOutputStreamWrapper. + // TODO(perf): This is an ugly hack to avoid sending lots of tiny packets when using TLS, which + // has to work around OpenSSL's readiness-based I/O layer. We could certainly do better here. + +private: + AsyncOutputStream& output; + ArrayPtr segments[2]; + kj::ForkedPromise pumpTask = nullptr; + bool isPumping = false; + bool corked = false; + + uint start = 0; // index of first byte + uint filled = 0; // number of bytes currently in buffer + + byte buffer[8192]; + + void uncork(); + + kj::Promise pump(); + // Asynchronously push the buffer out to the underlying stream. +}; + +class ReadyOutputStreamWrapper::Cork { + // An object that, when destructed, will uncork its parent stream. +public: + ~Cork() { + KJ_IF_MAYBE(p, parent) { + p->uncork(); + } + } + Cork(Cork&& other) : parent(kj::mv(other.parent)) { + other.parent = nullptr; + } + KJ_DISALLOW_COPY(Cork); + +private: + Cork(ReadyOutputStreamWrapper& parent) : parent(parent) {} + + kj::Maybe parent; + friend class ReadyOutputStreamWrapper; +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/compat/tls-test.c++ b/c++/src/kj/compat/tls-test.c++ new file mode 100644 index 0000000000..dddefa5747 --- /dev/null +++ b/c++/src/kj/compat/tls-test.c++ @@ -0,0 +1,1292 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_OPENSSL + +#if _WIN32 +#include +#endif + +#include "tls.h" + +#include "http.h" + +#include + +#include + +#if _WIN32 +#include + +#include +#else +#include +#endif + +#include +#include + +namespace kj { +namespace { + +// ======================================================================================= +// test data +// +// made with make-test-certs.sh +static constexpr char CA_CERT[] = + "-----BEGIN CERTIFICATE-----\n" + "MIIGMzCCBBugAwIBAgIUDxGXACZeJ0byrswV8gyWskZF2Q8wDQYJKoZIhvcNAQEL\n" + "BQAwgacxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQH\n" + "DAlQYWxvIEFsdG8xFTATBgNVBAoMDFNhbmRzdG9ybS5pbzEbMBkGA1UECwwSVGVz\n" + "dGluZyBEZXBhcnRtZW50MRcwFQYDVQQDDA5jYS5leGFtcGxlLmNvbTEiMCAGCSqG\n" + "SIb3DQEJARYTZ2FycGx5QHNhbmRzdG9ybS5pbzAgFw0yMDA2MjcwMDQyNTJaGA8y\n" + "MTIwMDYwMzAwNDI1MlowgacxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9y\n" + "bmlhMRIwEAYDVQQHDAlQYWxvIEFsdG8xFTATBgNVBAoMDFNhbmRzdG9ybS5pbzEb\n" + "MBkGA1UECwwSVGVzdGluZyBEZXBhcnRtZW50MRcwFQYDVQQDDA5jYS5leGFtcGxl\n" + "LmNvbTEiMCAGCSqGSIb3DQEJARYTZ2FycGx5QHNhbmRzdG9ybS5pbzCCAiIwDQYJ\n" + "KoZIhvcNAQEBBQADggIPADCCAgoCggIBAKpp8VF/WDPw1V1aD36/uDWI4XRk9OaJ\n" + "i8tkAbTPutJ7NU4AWv9OzreKIR1PPhj0DtxVOj5KYRTwL1r4DsFWh6D0rBV7oz7o\n" + "zP8hWznVQBSa2BJ2E4uDD8p5oNz1+O+o4UgSBbOr83Gp5SZGw9KO7cgNql9Id/Ii\n" + "sHYxXdrYdAl6InuR6q52CJgcGqQgpFYG+KYqDiByfX52slyz5FA/dfZxsmoEVFLB\n" + "rgbeuhsJGIoasTkGIdCYJhYI7k2uWtvYNurnhgvlpfPHHuSnJ+aWVdKaQUthgbsy\n" + "T2XHuLYpWx7+B7kCF5B4eKtn3e7yzE7A8jn8Teq6yRNUh1PnM7CRMmz3g4UAxmJT\n" + "F5NyQd14IqveFuOk40Ba8wLoypll5We5tV7StUyvaOlAhi+gGPHfWKk4MGiBoaOV\n" + "52D1+/dkh/abokdKZtE59gJX0MrH6mihfc9KQs7N51IhSs4kG5zGIBdvgXtmP17H\n" + "hixUFi0Y85pqGidW4LLQ1pmK9k0U4gYlwHtqHh8an35/vp/mFhl2BDHcpuYKZ34U\n" + "ZDo9GglfCTVEsUvAnwfhuYN0e0kveSTuRCMltjGg0Fs1h9ljNNuc46W4qIx/d5ls\n" + "aMOTKc3PTtwtqgXOXRFn2U7AUXOgtEqyqpuj5ZcjH2YQ3BL24qAiYEHaBOHM+8qF\n" + "9JLZE64j5dnZAgMBAAGjUzBRMB0GA1UdDgQWBBTmqsbDUpi5hgPbcPESYR9t8jsD\n" + "7jAfBgNVHSMEGDAWgBTmqsbDUpi5hgPbcPESYR9t8jsD7jAPBgNVHRMBAf8EBTAD\n" + "AQH/MA0GCSqGSIb3DQEBCwUAA4ICAQADdVBYClYWqNk1s2gamjGsyQ2r88TWTD6X\n" + "RySVnyQPATWuEctrr6+8qTrbqBP4bTPKE+uTzwk+o5SirdJJAkrcwEsSCFw7J/qf\n" + "5U/mXN+EUuqyiMHOS/vLe5X1enj0I6fqJY2mCGFD7Jr/+el1XXjzRZsLZHmqSxww\n" + "T+UjJP+ffvtq3pq4nMQowxXm+Wub0gFHj5wkKMTIDyqnbjzB9bdVd0crtA+EpYIi\n" + "f8y5WB46g1CngRnMzRQvg5FCmxg57i+mVgiUjUe54VenwK9aeeHIuOdLCZ0RmiNH\n" + "KHPUBct+S/AXx8DCoAdm51EahwMBnlUyISpwJ+LVMWA2R9DOxdhEF0tv5iBsD9rn\n" + "oKIWoa0t/Vwnd2n8wyLhuA7N4yzm0rdBjO/rU6n0atIab5+CEDyLeyWQBVwfCUF5\n" + "XYNxOBJgGfSgJa23KUtn15pS/nSTa6sOtS/Mryc4UuNzxn+3ebNOG4UPlH6miSMK\n" + "yA+5SCyKgrn3idifzrq+XafA2WUnxdBLgJMM4OIPAGNjCCW2P1cP/NVllUTjTy2y\n" + "AIKQ/D9V/DzlbIIT6F3CNnqa9xnrBWTKF1YH/zSB7Gh2xlr0WnOWJVQbNUYet982\n" + "JL5ibRhsiqBgltgQPhKhN/rGuh7Cb28679fQqLXKgOWvV2fC4b2y0v9dG78jGCEE\n" + "LBzBUUmunw==\n" + "-----END CERTIFICATE-----\n"; + +static constexpr char INTERMEDIATE_CERT[] = + "-----BEGIN CERTIFICATE-----\n" + "MIIGETCCA/mgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwgacxCzAJBgNVBAYTAlVT\n" + "MRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQHDAlQYWxvIEFsdG8xFTATBgNV\n" + "BAoMDFNhbmRzdG9ybS5pbzEbMBkGA1UECwwSVGVzdGluZyBEZXBhcnRtZW50MRcw\n" + "FQYDVQQDDA5jYS5leGFtcGxlLmNvbTEiMCAGCSqGSIb3DQEJARYTZ2FycGx5QHNh\n" + "bmRzdG9ybS5pbzAgFw0yMDA2MjcwMDQyNTNaGA8yMTIwMDYwMzAwNDI1M1owgZcx\n" + "CzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRUwEwYDVQQKDAxTYW5k\n" + "c3Rvcm0uaW8xGzAZBgNVBAsMElRlc3RpbmcgRGVwYXJ0bWVudDEbMBkGA1UEAwwS\n" + "aW50LWNhLmV4YW1wbGUuY29tMSIwIAYJKoZIhvcNAQkBFhNnYXJwbHlAc2FuZHN0\n" + "b3JtLmlvMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAvDRmV4e7F/1K\n" + "TaySC2xq00aYZfKJmwcvBvLM60xMrwDf5KlHWoYzog72q6qNb0TpvaeOlymU1IfR\n" + "DkRIMFz6qd4QpK3Ah0vDtqEsoYE6F7gr2QAAaB7HQhczClh6FFkMCygrHDRQlJcF\n" + "VseTKxnZBUAUhnG7OI8bHMZsprg31SNG1GCXq/CO/rkKIP1Pdwevr8DFHL3LFF03\n" + "vdeo6+f/3LjlsCzVCNsCcAYIScRSl1scj2QYcwP7tTmDRAmU9EZFv9MWSqRIMT4L\n" + "/4tP9AL/pWGdx8RRvbXoLJQf+hQW6YnSorRmIH/xYsvqMaan4P0hkomfIHP4nMYa\n" + "LgI8VsNhTeDZ7IvSF2F73baluTOHhUD5eE/WDffNeslaCoMbH/B3H6ks0zYt/mHG\n" + "mDaw3OxgMYep7TIE+SABOSJV+pbtQWNyM7u2+TYHm2DaxD3quf+BoYUZT01uDtN4\n" + "BSsR7XEzF25w/4lDxqBxGAZ0DzItK0kzqMykSWvDIjpSg/UjRj05sc+5zcgE4pX1\n" + "nOLD+FuB9jVqo6zCiIkHsSI0XHnm4D6awB1UyDwSh8mivfUDT53OpOIwOI3EB/4U\n" + "iZstUKgyXNrXsE6wS/3JfdDZ9xkw4dWV+0FWKJ4W6Y8UgKvQJRChpCUtcuxfaLjX\n" + "/ZIcMRYEFNjFppka/7frNT1VNnRvJ9cCAwEAAaNTMFEwHQYDVR0OBBYEFNnHDzWZ\n" + "NC5pP6njUnf1BMXuQKnnMB8GA1UdIwQYMBaAFOaqxsNSmLmGA9tw8RJhH23yOwPu\n" + "MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggIBAG0ShYBIbwGMDUBj\n" + "9G2vxI4Nx7lip9/wO5NCC2biVxtcjNxwTuCcqCxeSGgdimo82vJTm9Wa1AkoawSv\n" + "+spoXJAeOL5dGmZX3IRD0KuDoabH7F6xbNOG3HphvwAJcKQOitPAgTt2eGXCaBqw\n" + "a1ivxPzchHVd3+LyjLMK5G3zzgW0tZhKp7IYE0eGoSGwrw10ox6ibyswe8evul1N\n" + "r7Z3zFSWqjy6mcyp7PvlImDMYzWHYYdncEaeBHd13TnTQ++AtY3Da2lu7WG5VlcV\n" + "vcRaNx8ZLIUpW7u76F9fH25GFnQdF87dgt/ufntts4NifwRRRdIDAgunQ0Nf/TDy\n" + "ow/P7dobloGl9c5xGyXk9EbYYiK/Iuga1yUKJP3UQhfOTNBsDkyUD/jLg+0bI8T+\n" + "Vv4hKNhzUObXjzA5P7+RoEqnPe8R5wjvea6OJc2MDm5GFrzZOdIEnW/iuaQTE3ol\n" + "PYZVDvgPB+l7IC0brwTvdnvXaGLFu8ICQtNoOuSkbEAysyBWCEGwloYLL3r/0ozb\n" + "k3z5mUZVBwhpQS/GEChfUKiLxk9rbDmnhKASa6FVtycSVfoW/Oh5M65ARf9CFIUv\n" + "cvDYLpQl7Ao8QiDqlQBvitNi665uucsZ/zOdFlTKKAellLASpB/q2Y6J4wkogFNa\n" + "dGRV0WpQkcuMjXK7bzSSLMo0nMCm\n" + "-----END CERTIFICATE-----\n"; + +static constexpr char HOST_KEY[] = + "-----BEGIN RSA PRIVATE KEY-----\n" + "MIIJJwIBAAKCAgEAzO6/HGJ46PIogZJPxfSE6eHuqE1JvC8eaSCOKpqbfGsIHPfh\n" + "8d+TFatDe4GF4B6YYYmxpJUN85wZC+6GTA7xeM9UubssiwfACE/se5lJA0SSbXo9\n" + "SXkE5sI/LM7NA6+h3PTmoLlwduWHh6twHbcfcJUtSOSJemCLUqZOiSQ+CGW2RzEF\n" + "01SM/eOmpN9gtmVejD7wOxItbjX/DDE/IKv6xQ2G1RYxEpggCnNQjZj4R0uroN6L\n" + "tM7hiQebueXs3KfVIEcEV9oXfTf2tI4G3NTn9MJUvZQNIMUVOtjl8UmLHfk3T9Hc\n" + "eL/RM0svWIPWNoUJxA6m4u4g9D0eED2W0chKe86tF65RRYHvPNQXZf+tG7wyvRKP\n" + "ePWIlEzIAgXW2WwXEvzFvvRnfX1mYeWA7KhyRaVcCp9H0i5ZP0L8E2RsvYd+2QMm\n" + "zrxkeM01KMJSTAtUhqydNaH5dzlp8UXWvoTFLAh4F744OjdOyjHqoTZi1oh+iqsc\n" + "3o4tWQF3YzpjharP3mngGj+YDp/ZN6F8QdFQU+iNFKx6aDRgnKm+ri0/yTDMQ3G1\n" + "4bQ9FNmLRwB+W5UI9Oy931JxxcW7LxncbHgA326++bD1CdLXyxOvpHK40B5f9zsg\n" + "m980xdjBcLr2o0nGAsZ9V79a0BNqjvDA3DjmXBGTYxJSpMGfipP5KbP2hl0CAwEA\n" + "AQKCAgBSqV63EVVaCQujsCOzYn0WZgbBJmO+n3bxyqrtrm1XU0jzfl1KFfebPvi6\n" + "YbVhgJXQihz4mRMGl4lW0cCj/0cRhvfS7xf5gIfKEor+FAdqZQd3V15PO5xphCK9\n" + "bTEu8nIk0TgRzpr5qn3vkIxpwArTe6jHhT+a+ERaczCsiszm0DglITYLV0iDxIbc\n" + "bCnziJIJmf2Gpj9i/C7DeT3QbO56+4jOfOQQbwJFlNwCMZi8EV7KRdoudWBtyH7d\n" + "DkxreNsz6NFsqlDdNmyxybQk8VAa3yQVUBm3hSeaFBE0MYkG7xaLgMgggKbevM39\n" + "Mzh9x034IjzYvlrWiayNunoSZmr8KhYQUsgAE5F+khTLfheiGvXLkjTdBLIGG5q7\n" + "nb3G1la/jx0/0YpQvawNriovwii76cgHjmnEiNBJH685fvVP1y3bM/dJ62xdzFpw\n" + "7/1xrii1D9rSvrns037WOrv6VtdtNpJoLbUlNXU8CX9Oc7umbBLCLEp00+0DDctk\n" + "3KVX+sQc37w9KRURu8LqyFFx4VvyIQ+sJQKwVqFpaIkR7MzN6apkBCXwF+WtgtCz\n" + "7RGcu8V00yA0Rqqm1RVhPzbBU5UII0sRTYDacZFqzku8dVqEH0dUGlvtG2oIu0ed\n" + "OOv93q2EIyxmA88xcA2YFLI7P/qjYuhcUCd4QZb8S6/c71V+lQKCAQEA+fhQHE8j\n" + "Pp45zNZ9ITBTWqf0slC9iafpZ1qbnI60Sm1UigOD6Qzc5p6jdKggNr7VcQK+QyOo\n" + "pnM3qePHXgDzJZfdN7cXB9O7jZSTWUyTDc1C59b3H4ney/tIIj4FxCXBolkPrUOW\n" + "PE96i2PRQRGI1Pp5TFcFcaT5ZhX9WTPo/0fYBFEpof62mxGYRbQ8cEmRQYAXmlES\n" + "mqE9RwhwqTKy3VET0qDV5eq1EihMkWlDleyIUHF7Ra1ECzh8cbTx8NpS2Iu3BL5v\n" + "Sk6fGv8aegmf4DZcDU4eoZXBA9veAigT0X5cQH4HsAuppIOfRD8pFaJEe2K5EXXu\n" + "lDSlJk5BjSrARwKCAQEA0eBORzH+g2kB/I8JIMYqy4urI7mGU9xNSwF34HhlPURQ\n" + "NxPIByjUKrKZnCVaN2rEk8vRzY8EdsxRpKeGBYW1DZ/NL8nM7RHMKjGopf/n2JRk\n" + "7m1Mn6f4mLRIVzOID/eR2iGPvWNcKrJxgkGXNXlWEi0wasPwbVksV05uGd/5K/sT\n" + "cVVKkYLPhcGkYd7VvCfgF5x+5kPRLTrG8EKbePclUAVhs9huSRR9FAjSThpLyJZo\n" + "0/3WxPNm9FoPYGyAevvdPiF5f6yp1IdPDNZHx46WSVfkgqA5XICjanl5WCNHezDp\n" + "93r/Ix7zJ0Iyu9BxI2wJH3wGGqmsGveKOfZl90MaOwKCAQBCSGfluc5cslQdTtrL\n" + "TCcuKM8n4WUA9XdcopgUwXppKeh62EfIKlMBDBvHuTUhjyTF3LZa0z/LM04VTIL3\n" + "GEVhOI2+UlxXBPv8pOMVkMqFpGITW9sXj9V2PWF5Qv0AcAqSZA9WIE/cGi8iewtn\n" + "t6CS6P/1EDYvVlGTkk0ltDAaURCkxGjHveTp5ZZ9FTfZhohv1+lqUAkg25SGG2TU\n" + "WM85BGC/P0q4tq3g7LKw9DqprJjQy+amKTWbzBSjihmFhj7lkNas+VpFV+e0nuSE\n" + "a7zrFT7/gDF7I1yVC14pMDthF6Kar1CWi+El8Ijw7daVF/wUw67TRHRI9FS+fY3A\n" + "Qw/NAoIBAFbE7LgElFwSEu8u17BEHbdPhC7d6gpLv2zuK3iTbg+5aYyL0hwbpjQM\n" + "6PMkgjr9Gk6cap4YrdjLuklftUodMHB0i+lg/idZP1aGd1pCBcGGAICOkapEUMQZ\n" + "bPsYY/1t9k//piS/qoBAjCs1IOXLx2j2Y9kQLxuWTX2/AEgUUDj9sdkeURj9wvxi\n" + "xaps7WK//ablXZWnnhib/1mfwBVv4G5H+0/WgCoYnWmmCASgXIqOnMJgZOXCV+NY\n" + "RJkx4qB19s9UGZ5ObVxfoLAG+2AmtD2YZ/IVegGjcWx40lE9LLVi0KgvosILbq3h\n" + "cYYytEPXy6HHreJiGbSAeRZjp15l0LcCggEAWjeKaf61ogHq0uCwwipVdAGY4lNG\n" + "dAh2IjAWX0IvjBuRVuUNZyB8GApH3qilDsuzPLYtIhHtbuPrFjPyxMBoppk7hLJ+\n" + "ubzoHxMKeyKwlIi4jY+CXLaHtIV/hZFVYcdvr5A2zUrMgzEnDEEFn96+Dz0IdGrF\n" + "a37oDKDYpfK8EFk/jb2aAhIgYSHSUg4KKlQRuPfu0Vt/O8aasjmAQvfEmbx6c2C5\n" + "4q16Ky/ZM7mCqQNJAXgyPfOeJnX1PwCKntdEs2xtzXb7dEP9Yy9uAgmmVe7EzPCj\n" + "ml57PWxIJKtok73AHEat6qboncHvW1RPDAiQYmXdbe40v4wlPDrnWjUUiA==\n" + "-----END RSA PRIVATE KEY-----\n"; + +static constexpr char VALID_CERT[] = + "-----BEGIN CERTIFICATE-----\n" + "MIIF+jCCA+KgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwgZcxCzAJBgNVBAYTAlVT\n" + "MRMwEQYDVQQIDApDYWxpZm9ybmlhMRUwEwYDVQQKDAxTYW5kc3Rvcm0uaW8xGzAZ\n" + "BgNVBAsMElRlc3RpbmcgRGVwYXJ0bWVudDEbMBkGA1UEAwwSaW50LWNhLmV4YW1w\n" + "bGUuY29tMSIwIAYJKoZIhvcNAQkBFhNnYXJwbHlAc2FuZHN0b3JtLmlvMCAXDTIw\n" + "MDYyNzAwNDI1M1oYDzIxMjAwNjI3MDA0MjUzWjCBkDELMAkGA1UEBhMCVVMxEzAR\n" + "BgNVBAgMCkNhbGlmb3JuaWExFTATBgNVBAoMDFNhbmRzdG9ybS5pbzEbMBkGA1UE\n" + "CwwSVGVzdGluZyBEZXBhcnRtZW50MRQwEgYDVQQDDAtleGFtcGxlLmNvbTEiMCAG\n" + "CSqGSIb3DQEJARYTZ2FycGx5QHNhbmRzdG9ybS5pbzCCAiIwDQYJKoZIhvcNAQEB\n" + "BQADggIPADCCAgoCggIBAMzuvxxieOjyKIGST8X0hOnh7qhNSbwvHmkgjiqam3xr\n" + "CBz34fHfkxWrQ3uBheAemGGJsaSVDfOcGQvuhkwO8XjPVLm7LIsHwAhP7HuZSQNE\n" + "km16PUl5BObCPyzOzQOvodz05qC5cHblh4ercB23H3CVLUjkiXpgi1KmTokkPghl\n" + "tkcxBdNUjP3jpqTfYLZlXow+8DsSLW41/wwxPyCr+sUNhtUWMRKYIApzUI2Y+EdL\n" + "q6Dei7TO4YkHm7nl7Nyn1SBHBFfaF3039rSOBtzU5/TCVL2UDSDFFTrY5fFJix35\n" + "N0/R3Hi/0TNLL1iD1jaFCcQOpuLuIPQ9HhA9ltHISnvOrReuUUWB7zzUF2X/rRu8\n" + "Mr0Sj3j1iJRMyAIF1tlsFxL8xb70Z319ZmHlgOyockWlXAqfR9IuWT9C/BNkbL2H\n" + "ftkDJs68ZHjNNSjCUkwLVIasnTWh+Xc5afFF1r6ExSwIeBe+ODo3Tsox6qE2YtaI\n" + "foqrHN6OLVkBd2M6Y4Wqz95p4Bo/mA6f2TehfEHRUFPojRSsemg0YJypvq4tP8kw\n" + "zENxteG0PRTZi0cAfluVCPTsvd9SccXFuy8Z3Gx4AN9uvvmw9QnS18sTr6RyuNAe\n" + "X/c7IJvfNMXYwXC69qNJxgLGfVe/WtATao7wwNw45lwRk2MSUqTBn4qT+Smz9oZd\n" + "AgMBAAGjUzBRMB0GA1UdDgQWBBRdpPVLjMnJFs7lCtMW6x39Wnwt3TAfBgNVHSME\n" + "GDAWgBTZxw81mTQuaT+p41J39QTF7kCp5zAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\n" + "SIb3DQEBCwUAA4ICAQA2kid2fGqjeDRVuclfDRr0LhbFYfJJXxW7SPgcUpJYXeAz\n" + "LXotBm/Cc+K01nNtl0JYfJy4IkaQUYgfVsA5/FqTGnbRmpEd5XidiGE6PfkXZSNj\n" + "02v6Uv2bAs8NnJirS5F0JhWZ45xAOMbl04QPUdkISF1JzioCcCqWOggbIV7kzwrB\n" + "MTcsx8vuh04S9vB318pKli4uIjNdwu7HnoqbrqhSgTUK1aXS6sDQNN/nvR6F5TRL\n" + "MC0cCtIA6n04c09WRfHxl/YPQwayxGD23eQ9UC7Noe/R8B3K/+6XUYIEmXx6HpnP\n" + "yt/79iBPwLnaVjGDwKfI8EPuSo7AkDSO3uxMjf7eCL2sCzWlgsne9yOfYxGn+q9K\n" + "h3KTOR1b7EVU/G4h7JxlSHqf3Ii9qFba/HsUo1yMjVEraMNpxXCijGsN30fqUpLg\n" + "2g9lNKmIdyHuYdlZET082b1dvb7cfYChlqHvrPv5awbsc1Ka2pOFOMwnUi2w3cHc\n" + "TLq3SyipI+sgJg3HHSJ3zVHrgKUeDoQi52i5WedvIBPfHC4Ik4zEzlkCMDd3hoLe\n" + "THAsBGBBEICK+lLSo0Mst2kxhTHHK+PxmhorXXnGOpla2wbwzVZgxur45axCWVqH\n" + "cbdcVhzR2w4XhjaS74WGiHHn5mHb/uYZJiZGpFCNefU2WOBlRCX7hgAzlPa4ZQ==\n" + "-----END CERTIFICATE-----\n"; + +static constexpr char HOST_KEY2[] = + "-----BEGIN RSA PRIVATE KEY-----\n" + "MIIJKgIBAAKCAgEAtpULbBVceP5UTP8Aq/qZ4zuws0mgfRazlFFDzn0SpxKRgfUR\n" + "OrDB8EMcffL+IxWYdzszYnm7R4p8udQHtqdX1m+JpWPIcEyOGuKEjEGGVBbfteiG\n" + "vCZaHmmhSGFuBuRQnsmOMN2sX4ATPgISeUpKz3YcEw5zbGV9XveQBCiCZYJOEY2R\n" + "qzuzfwmO76Nf/0pQtFaN6vjHOGOp5e6xEWUNMruliw83/BYmtOE0CH9QmLSi5d/s\n" + "OMsppXVduwUshHv2gwocXFik4FUhDMKjfzp71uvRLqAnpsf6u5uShXwqgamohQct\n" + "h9D0x5KMoTwlf7LnV52dJ4Fp0np4FYkNhJJxqMXjJuzW4HqHyt/zzelhXavKyPyf\n" + "XGmkRHzQLAaOzDN+3qXBvArubYapuy/CF1n/Dh9OvcGJ8vK2wtahGig2Wrwh5k0e\n" + "ZEiQfkFtXKhVmxNgcEEr6coIPAPe882F0HrWgM5h/huS50MC5OWyjnAySZ7L/Qj0\n" + "7jDfNcij1yajmv4ahsL8FI4hal8k3DSXe3MDnAwmBxKg+b/KUWNiucdInZUZ17c5\n" + "765aQoeIPZFVBoAQrgFFLPE31wC7SwrMuKMhy2UKbgXjcZ5MMkbS2WBSaqBFLSys\n" + "zHY0cFPCRh6K7d7vDvSG7lZT16lNFfagbvcO1uusQBD22gGvNOF2zZe377ECAwEA\n" + "AQKCAgA7Dm6JYUdt42XFGd5PwlkwRMhc1X3RuBwR5081ZQM5gyoJjQkroKy6WBrJ\n" + "KmXFV2DfgAiY26MV+tdpDAoKrIoe1CkDlAjrOffk/ku9Shx26oclwbaC+SzBFY2T\n" + "aeA63nKtSahyaeEtarHOpsDu9nbIL/3YtB3le9ZXd1/f2HKE/ubdipsJdeATQTY4\n" + "kPGmE5WTH0P8MsfNl38G3nPrmnHwbP2Ywy1qnoeajhVUgknBevwNuqYfoKcx24qb\n" + "yYqit63+qLCPtiRuY1qzU+mqZ3JTDCe3Gxp4OcsCD8oO3yConAXkMXQqsA3c16wh\n" + "IuFGMsndbx+7/YILEI3y+UekD/IvBmLAtX0X5fQEhHm/8SowHAe5XIOTWGnTjXrF\n" + "JhGwsRuQtSXJPhTVAeR07IrPAASVp0BeppBdHv0pUkOte/usYH3PhYGLuQJIbOHi\n" + "AvDZDE/6CVszXFU7Ry3QpIZ4aQQMGGOJg1LVhNnt7br7ZZJ2mU5BOOvgdcNACcR8\n" + "+sCD2DE3dD6Nxy5rGESHSqYDyedd3KJ7wPlBO6p6KUttgwWtZDXh940kGxcnEQtQ\n" + "HAnHOxCJU+cxj9ekxoNHcTjqCEGkL7jy013soG4yrw62vMsVKHxEUqo/uaBSI7Nd\n" + "h+JiHb/mQ/sGvQcaAlJTnX7kUfy+oPi0nL3Hz97CxwcWSFFmIQKCAQEA8AiKO13n\n" + "0FjMGZJHXeu73Rg/OvQBSSE8Fco/watnEKKzRPIXUdYgjxL8fH7jFwV7Ix941ddG\n" + "2RcY7zchbMAP95xfeYUlrHyxbgNjWWIFezaattrCn8Sw52SXNcsT2bGuaNmyVmHs\n" + "gwyIDCl+1cPArhcuCun8MszsSy5W0EPwsaCqDnPinTE0lSsj0MWpR4K7+m4OOmA1\n" + "zwqQ9j/pgvP8is7YeTEb1a7JtVAD+4nCd7XnzUuun2Qw5jaCGFJKZGEwqeSUQZS6\n" + "NAuQ0OTaw8m3qb7incS9tZahWLACLfT4Jrrfh9Is2pFVgQstzin9dpfLrWKeBRGP\n" + "D3ItA73haE/xjQKCAQEAwrova9Epg5ylKReKCTMZREDO1PTbHXRS7GrgNN1Sy1Ne\n" + "Ke3UjEowJMMINekYEJimuGLixl50K0a6T5lQoUQY6dYHeZEQ4YSN3lPWtgbuZEtI\n" + "OlSrsw9/duT55gVPiRsZTiHjM1mYVEtAeUxHH/PVoSjPY4V1OKWA3HJ/TtgP1JEN\n" + "scdIdIXP1HZnjxxN4juVHyr1+iC2KXbb/OajXFMUCPkp7YrlknDyYcgj2uXRqC3k\n" + "ju3oBplcEnNrWO6RfqQ+QYv87huPXV5sHXzjNBj9ssHwn+EYxwpne+LvMfg3E5l1\n" + "o7Yl40IfHKK85ts8qwjG6tJ3TUxAUrMPSKBLbbODtQKCAQEAxJWZ8Kkl8+LltYOx\n" + "41/vilITZwr0CpqnhQkRUmI4lM1LmQnUw3dlTwgztRqOjgo1ITzjT+9x3NYn27MB\n" + "MvnRme992h6MDkpJXlp0AX5gEttTtrJPd141rC0cEjhx13bH6qNwhYLJm0KmIZ/S\n" + "euxJX8soMFQV8t0WITSgcQ1TkYaOACw0ypzD/e9I8/EOhLyzi5SbHoAxUZHLy4Ho\n" + "kxGUIXLqo8bujwEJve78dAQNOtHGOMLlDzGVQtYdkiHDP5bBrkLAkT1nirx2LD9i\n" + "U7tfKixlmOTKom/tUJ9GCbF5ku61p50gkxk4N+mZ6CFHrtr/Os9rr6cDzZiq+UeH\n" + "1lCy+QKCAQEAhjYxTQyConWq2CGjQCf5+DL624hQJYLxTIV1Nrp8wCsbsaZ8Yp0X\n" + "hZ7u38lijr3H2zo8tyCOzO0YqJgxHJWE3lZoHH/BtM3Zwizixd8NHA9PHvUQyn+a\n" + "COZU3xc19He6/0EYCWJtPVwIehH6y6kRytwH5L4tRve7UzWPTVZZwtafK7MA218H\n" + "GZbqVZbaj10lsK+5jcZSB04m3a5RVebk3jJtlY2wITi7tm1tWQghctr+twx+aV32\n" + "OblXeZokqbamOiM0FyDjtSTJO6HCLzwyT6ygHnHU1Ar1vEtzNWuw+k9A5685eeMu\n" + "8luv+yWMMQ4BnAOnup0dkGJd3F6u3lNmKQKCAQEAkZKTi4g+AKGikU9JuKOTF2UH\n" + "DdolZK/pXfWRIzpyC5cxiwnHhpqtl4jRrNSVSXWo/ChdtxhGv+8FNO+G8wQr0I0K\n" + "iWF4qWU/q7vbjWuE8mDrWfaCWM3IwEoMQ7Ub+gTf2JzQG0t0oSgJ70uaYxxm2x7U\n" + "eBnblzZ6ODG7jgq2WX9S8/JzTvqrtlVDmdPUJlsIymRiDHt+zDAh4kv0aRUJQVWy\n" + "cCd1rWSl2BnCgrLi0Ez5k5EuOW7v3TIFWI6AXAa5kshG3Q5du/TtPLSyqJZu2UCx\n" + "3LprECmh9aJmE8KZL3R5ClSqkzVjuOUj56y63vSiV1B7kTbTnT7G+sJwwgxvyA==\n" + "-----END RSA PRIVATE KEY-----\n"; + +static constexpr char VALID_CERT2[] = + "-----BEGIN CERTIFICATE-----\n" + "MIIF+jCCA+KgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwgZcxCzAJBgNVBAYTAlVT\n" + "MRMwEQYDVQQIDApDYWxpZm9ybmlhMRUwEwYDVQQKDAxTYW5kc3Rvcm0uaW8xGzAZ\n" + "BgNVBAsMElRlc3RpbmcgRGVwYXJ0bWVudDEbMBkGA1UEAwwSaW50LWNhLmV4YW1w\n" + "bGUuY29tMSIwIAYJKoZIhvcNAQkBFhNnYXJwbHlAc2FuZHN0b3JtLmlvMCAXDTIw\n" + "MDYyNzAwNDI1M1oYDzIxMjAwNjI3MDA0MjUzWjCBkDELMAkGA1UEBhMCVVMxEzAR\n" + "BgNVBAgMCkNhbGlmb3JuaWExFTATBgNVBAoMDFNhbmRzdG9ybS5pbzEbMBkGA1UE\n" + "CwwSVGVzdGluZyBEZXBhcnRtZW50MRQwEgYDVQQDDAtleGFtcGxlLm5ldDEiMCAG\n" + "CSqGSIb3DQEJARYTZ2FycGx5QHNhbmRzdG9ybS5pbzCCAiIwDQYJKoZIhvcNAQEB\n" + "BQADggIPADCCAgoCggIBALaVC2wVXHj+VEz/AKv6meM7sLNJoH0Ws5RRQ859EqcS\n" + "kYH1ETqwwfBDHH3y/iMVmHc7M2J5u0eKfLnUB7anV9ZviaVjyHBMjhrihIxBhlQW\n" + "37XohrwmWh5poUhhbgbkUJ7JjjDdrF+AEz4CEnlKSs92HBMOc2xlfV73kAQogmWC\n" + "ThGNkas7s38Jju+jX/9KULRWjer4xzhjqeXusRFlDTK7pYsPN/wWJrThNAh/UJi0\n" + "ouXf7DjLKaV1XbsFLIR79oMKHFxYpOBVIQzCo386e9br0S6gJ6bH+rubkoV8KoGp\n" + "qIUHLYfQ9MeSjKE8JX+y51ednSeBadJ6eBWJDYSScajF4ybs1uB6h8rf883pYV2r\n" + "ysj8n1xppER80CwGjswzft6lwbwK7m2GqbsvwhdZ/w4fTr3BifLytsLWoRooNlq8\n" + "IeZNHmRIkH5BbVyoVZsTYHBBK+nKCDwD3vPNhdB61oDOYf4bkudDAuTlso5wMkme\n" + "y/0I9O4w3zXIo9cmo5r+GobC/BSOIWpfJNw0l3tzA5wMJgcSoPm/ylFjYrnHSJ2V\n" + "Gde3Oe+uWkKHiD2RVQaAEK4BRSzxN9cAu0sKzLijIctlCm4F43GeTDJG0tlgUmqg\n" + "RS0srMx2NHBTwkYeiu3e7w70hu5WU9epTRX2oG73DtbrrEAQ9toBrzThds2Xt++x\n" + "AgMBAAGjUzBRMB0GA1UdDgQWBBSCVMitc7axjQ0JObyQ7SoZ15v41jAfBgNVHSME\n" + "GDAWgBTZxw81mTQuaT+p41J39QTF7kCp5zAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\n" + "SIb3DQEBCwUAA4ICAQAGqI+GGbSHkV9C16OLKgujS17zAJDuMeUZVoUvsh0oj7hK\n" + "QwuJ6M6VIWZXk0Ccs/TbtQgyUtt98HY/M5LYjvuB3jb348TvYvBg1un6DC1LNFnw\n" + "x19eUvwxhoI0I9A/heD6251plaXl0rk+wmTn+gqHNswb0LZw7l8XclOQ8s13/Ei3\n" + "fD4P5N3LiXaPfcXzFtEvWJE1ONC/PvLfwWWE2T+/LabJ4I4iumX8oAJZyx9BCE09\n" + "54/0cV1V6xjp31/CS7vkYtDMeREnydwC3PsjjzO18nM0GVw6R2eok/yvD2Rg/pqJ\n" + "CiscKswcy0OR42pCzJyAwHaXV0KZEG9E97ukiqh3ByBUfR0ZwkKv7tDaL7UQiXdF\n" + "sheJ3l8TyQNcuWljjm1MWJt9ZzZt5zE4+yes4YVDNNe9l2jIoT8641pcy2MmPOdJ\n" + "8pEE3xJ2SAdeJKVXuHoi7glzmlK1O5nSNK3GIfKRwJ2hmIXSAoMPfpwtJWdJDNGZ\n" + "N2HThXDMleMrJqwsdToRCp0nBm40cKSDk/o7SfiE7z4e1EVDAFBlWl5SAq9Pqwh5\n" + "lBlsQXd5SbzWGyVk7BjtT3ttbXru9NEINo1l9Cw74GQuW40FsQf4drZVDVtaNWPd\n" + "IvZM211bcU/zZV44rkz3nc08jSGo2qP8bEcuYAlTneDLyrAmpisUXKYm1Q1SxA==\n" + "-----END CERTIFICATE-----\n"; + +static constexpr char EXPIRED_CERT[] = + "-----BEGIN CERTIFICATE-----\n" + "MIIF+DCCA+CgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwgZcxCzAJBgNVBAYTAlVT\n" + "MRMwEQYDVQQIDApDYWxpZm9ybmlhMRUwEwYDVQQKDAxTYW5kc3Rvcm0uaW8xGzAZ\n" + "BgNVBAsMElRlc3RpbmcgRGVwYXJ0bWVudDEbMBkGA1UEAwwSaW50LWNhLmV4YW1w\n" + "bGUuY29tMSIwIAYJKoZIhvcNAQkBFhNnYXJwbHlAc2FuZHN0b3JtLmlvMB4XDTE2\n" + "MDEwMTAwMDAwMFoXDTE2MDEwMTAwMDAwMFowgZAxCzAJBgNVBAYTAlVTMRMwEQYD\n" + "VQQIDApDYWxpZm9ybmlhMRUwEwYDVQQKDAxTYW5kc3Rvcm0uaW8xGzAZBgNVBAsM\n" + "ElRlc3RpbmcgRGVwYXJ0bWVudDEUMBIGA1UEAwwLZXhhbXBsZS5jb20xIjAgBgkq\n" + "hkiG9w0BCQEWE2dhcnBseUBzYW5kc3Rvcm0uaW8wggIiMA0GCSqGSIb3DQEBAQUA\n" + "A4ICDwAwggIKAoICAQDM7r8cYnjo8iiBkk/F9ITp4e6oTUm8Lx5pII4qmpt8awgc\n" + "9+Hx35MVq0N7gYXgHphhibGklQ3znBkL7oZMDvF4z1S5uyyLB8AIT+x7mUkDRJJt\n" + "ej1JeQTmwj8szs0Dr6Hc9OaguXB25YeHq3Adtx9wlS1I5Il6YItSpk6JJD4IZbZH\n" + "MQXTVIz946ak32C2ZV6MPvA7Ei1uNf8MMT8gq/rFDYbVFjESmCAKc1CNmPhHS6ug\n" + "3ou0zuGJB5u55ezcp9UgRwRX2hd9N/a0jgbc1Of0wlS9lA0gxRU62OXxSYsd+TdP\n" + "0dx4v9EzSy9Yg9Y2hQnEDqbi7iD0PR4QPZbRyEp7zq0XrlFFge881Bdl/60bvDK9\n" + "Eo949YiUTMgCBdbZbBcS/MW+9Gd9fWZh5YDsqHJFpVwKn0fSLlk/QvwTZGy9h37Z\n" + "AybOvGR4zTUowlJMC1SGrJ01ofl3OWnxRda+hMUsCHgXvjg6N07KMeqhNmLWiH6K\n" + "qxzeji1ZAXdjOmOFqs/eaeAaP5gOn9k3oXxB0VBT6I0UrHpoNGCcqb6uLT/JMMxD\n" + "cbXhtD0U2YtHAH5blQj07L3fUnHFxbsvGdxseADfbr75sPUJ0tfLE6+kcrjQHl/3\n" + "OyCb3zTF2MFwuvajScYCxn1Xv1rQE2qO8MDcOOZcEZNjElKkwZ+Kk/kps/aGXQID\n" + "AQABo1MwUTAdBgNVHQ4EFgQUXaT1S4zJyRbO5QrTFusd/Vp8Ld0wHwYDVR0jBBgw\n" + "FoAU2ccPNZk0Lmk/qeNSd/UExe5AqecwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG\n" + "9w0BAQsFAAOCAgEAYRQazCV707BpbBo3j2PXfg4rmrm1GIA0JXFsY27CII0aSgTw\n" + "roMBnwp3sJ+UIxqddkwf/4Bn/kq8yHu8WMc1cb4bzsqgU4K2zGJcVF3i9Y4R6oE1\n" + "9Y6QM1db5HYiCdXSNW5uZUQGButyIXsUfPns0jMZfmEhsW4WrN3m2qE357FeBfCF\n" + "nP4Ij3sbUq01OoPBL6sWUbltfL5PgqKitE6UFu1/WFpBatP+8ITOOLhkGmJ70zb1\n" + "rY3jnwlEaRJVw8DmIkkabrgKu2gUZ3qKX9aeW9iJW524A2ZjEg1xoBRq9MVcfiRN\n" + "qAERaEd7MgjIFmOYR5O2juzR3eMrv+U+JsY6K5ItTwwE/MHKjYRM22CgmQahPfvo\n" + "o40qLUn/zJNJAN+1hrmOqFXpC5vmfKV9pcG7BOsuZ9V9gkssnJRosCuU8iIW3gyo\n" + "C6TFLIneStvBzokoCTv0Fxxh/vqfIWmHYr7nsFM8S/X2iDLuCoiB6qsRw5NaOIg6\n" + "QB7cEi3sgBZU+eJDmynR3waIU0HGmj9DK8Tc2TMd5wvkBpJBkqqQkICVQ/u/g7up\n" + "swvT5Iap509sI4nmKqe9meN6m3xBSJrCNPTbjUyu7PuC/rBe7lVwnP5/PN/aj6ZU\n" + "XGyLwArQ/5GgT2sy3aEQQTtb+kthnZo7NL8nmkpoTbrm84DJ4dwkD4qc+hs=\n" + "-----END CERTIFICATE-----\n"; + +static constexpr char SELF_SIGNED_CERT[] = + "-----BEGIN CERTIFICATE-----\n" + "MIIGLTCCBBWgAwIBAgIUIGB2OqfFvs22f6uTwFJwWKaLH+kwDQYJKoZIhvcNAQEL\n" + "BQAwgaQxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQH\n" + "DAlQYWxvIEFsdG8xFTATBgNVBAoMDFNhbmRzdG9ybS5pbzEbMBkGA1UECwwSVGVz\n" + "dGluZyBEZXBhcnRtZW50MRQwEgYDVQQDDAtleGFtcGxlLmNvbTEiMCAGCSqGSIb3\n" + "DQEJARYTZ2FycGx5QHNhbmRzdG9ybS5pbzAgFw0yMDA2MjcwMDQyNTNaGA8yMTIw\n" + "MDYyNzAwNDI1M1owgaQxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh\n" + "MRIwEAYDVQQHDAlQYWxvIEFsdG8xFTATBgNVBAoMDFNhbmRzdG9ybS5pbzEbMBkG\n" + "A1UECwwSVGVzdGluZyBEZXBhcnRtZW50MRQwEgYDVQQDDAtleGFtcGxlLmNvbTEi\n" + "MCAGCSqGSIb3DQEJARYTZ2FycGx5QHNhbmRzdG9ybS5pbzCCAiIwDQYJKoZIhvcN\n" + "AQEBBQADggIPADCCAgoCggIBAMzuvxxieOjyKIGST8X0hOnh7qhNSbwvHmkgjiqa\n" + "m3xrCBz34fHfkxWrQ3uBheAemGGJsaSVDfOcGQvuhkwO8XjPVLm7LIsHwAhP7HuZ\n" + "SQNEkm16PUl5BObCPyzOzQOvodz05qC5cHblh4ercB23H3CVLUjkiXpgi1KmTokk\n" + "PghltkcxBdNUjP3jpqTfYLZlXow+8DsSLW41/wwxPyCr+sUNhtUWMRKYIApzUI2Y\n" + "+EdLq6Dei7TO4YkHm7nl7Nyn1SBHBFfaF3039rSOBtzU5/TCVL2UDSDFFTrY5fFJ\n" + "ix35N0/R3Hi/0TNLL1iD1jaFCcQOpuLuIPQ9HhA9ltHISnvOrReuUUWB7zzUF2X/\n" + "rRu8Mr0Sj3j1iJRMyAIF1tlsFxL8xb70Z319ZmHlgOyockWlXAqfR9IuWT9C/BNk\n" + "bL2HftkDJs68ZHjNNSjCUkwLVIasnTWh+Xc5afFF1r6ExSwIeBe+ODo3Tsox6qE2\n" + "YtaIfoqrHN6OLVkBd2M6Y4Wqz95p4Bo/mA6f2TehfEHRUFPojRSsemg0YJypvq4t\n" + "P8kwzENxteG0PRTZi0cAfluVCPTsvd9SccXFuy8Z3Gx4AN9uvvmw9QnS18sTr6Ry\n" + "uNAeX/c7IJvfNMXYwXC69qNJxgLGfVe/WtATao7wwNw45lwRk2MSUqTBn4qT+Smz\n" + "9oZdAgMBAAGjUzBRMB0GA1UdDgQWBBRdpPVLjMnJFs7lCtMW6x39Wnwt3TAfBgNV\n" + "HSMEGDAWgBRdpPVLjMnJFs7lCtMW6x39Wnwt3TAPBgNVHRMBAf8EBTADAQH/MA0G\n" + "CSqGSIb3DQEBCwUAA4ICAQDEEnq3/Yg0UL3d8z0EWXGnj5hbc8Cf9hj8J918aift\n" + "XhqQ+EU6BV1V7IsrUZ07Wkws7hgd7HpBgmSNQu9cIoSyjiS9je8KL9TqwmBaPNg3\n" + "jE31maqLHdLQfR0USYo7wp8cE7w/tojLwyhuwVEJR4IpVlfAgmD5HMhCX4vwZTUB\n" + "bkzsRtY56JRhNDO2ExY7QPFF4FhXLf8eZGqqk09FTpQemJFwZ2+MYlSOILrP4RL3\n" + "T9LW0EgymAjDUHT047xr5xPAjRUEplqT90bEsAp5D199m/c143tq3Cke/eQDWAR7\n" + "HVYmRmuOhwyqhkKfssZZvuq7Shm0u2vuOvfGhcW7JmukalrCixjitQmbOv1EJT0F\n" + "tQN61nRTUpnC37DEgtYpV8n+GgT1hWXDzC/0UNVOFIR26eX0kxZmqU6v+Rqz/qYe\n" + "NA2TXZ4YvL081QvPFOWVpodM6LLYw2cSGBCdfAdRE1ECoqzk5EgRBH5SrZnuebMG\n" + "V8aJsIRMI011QDEz69YJFzefI9WaawcHqfTWZoCeCBDNm0pEVqRbBAQ8E7IXfQUu\n" + "WjPbENMyTTp6+uRmQkmNJjv9HcvFUu+wyhFODBrZ4LEwFP+oWBGWqru28Se7b66H\n" + "ZvzKfQVXzYpEmhHHK2n5X1hNwSr0kb3QffVpbz3/TBIgQcOyBp/RnOY38pngfRw1\n" + "SQ==\n" + "-----END CERTIFICATE-----\n"; + +// ======================================================================================= + +class ErrorNexus { + // Helper class that wraps various promises such that if one throws an exception, they all do. + +public: + ErrorNexus(): ErrorNexus(kj::newPromiseAndFulfiller()) {} + + template + kj::Promise wrap(kj::Promise&& promise) { + return promise.catch_([this](kj::Exception&& e) -> kj::Promise { + fulfiller->reject(kj::cp(e)); + return kj::mv(e); + }).exclusiveJoin(failurePromise.addBranch().then([]() -> T { KJ_UNREACHABLE; })); + } + +private: + kj::ForkedPromise failurePromise; + kj::Own> fulfiller; + + ErrorNexus(kj::PromiseFulfillerPair paf) + : failurePromise(kj::mv(paf.promise).fork()), + fulfiller(kj::mv(paf.fulfiller)) {} +}; + +struct TlsTest { + kj::AsyncIoContext io = setupAsyncIo(); + TlsContext tlsClient; + TlsContext tlsServer; + + TlsTest(TlsContext::Options clientOpts = defaultClient(), + TlsContext::Options serverOpts = defaultServer()) + : tlsClient(kj::mv(clientOpts)), + tlsServer(kj::mv(serverOpts)) {} + + static TlsContext::Options defaultServer() { + static TlsKeypair keypair = { + TlsPrivateKey(HOST_KEY), + TlsCertificate(kj::str(VALID_CERT, INTERMEDIATE_CERT)) + }; + TlsContext::Options options; + options.defaultKeypair = keypair; + return options; + } + + static TlsContext::Options defaultClient() { + static TlsCertificate caCert(CA_CERT); + TlsContext::Options options; + options.useSystemTrustStore = false; + options.trustedCertificates = kj::arrayPtr(&caCert, 1); + return options; + } + + Promise writeToServer(AsyncIoStream& client) { + return client.write("foo", 4); + } + + Promise readFromClient(AsyncIoStream& server) { + auto buf = heapArray(4); + + auto readPromise = server.read(buf.begin(), buf.size()); + + auto checkBuffer = [buf = kj::mv(buf)]() { + KJ_ASSERT(kj::StringPtr(buf.begin(), buf.end()-1) == kj::StringPtr("foo")); + }; + + return readPromise.then(kj::mv(checkBuffer)); + } + + void testConnection(AsyncIoStream& client, AsyncIoStream& server) { + auto writePromise = writeToServer(client); + auto readPromise = readFromClient(server); + + writePromise.wait(io.waitScope); + readPromise.wait(io.waitScope); + }; +}; + +KJ_TEST("TLS basics") { + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + test.testConnection(*client, *server); + + // Test clean shutdown. + { + auto eofPromise = server->readAllText(); + KJ_EXPECT(!eofPromise.poll(test.io.waitScope)); + client->shutdownWrite(); + KJ_ASSERT(eofPromise.poll(test.io.waitScope)); + KJ_EXPECT(eofPromise.wait(test.io.waitScope) == ""_kj); + } + + // Test UNCLEAN shutdown in other direction. + { + auto eofPromise = client->readAllText(); + KJ_EXPECT(!eofPromise.poll(test.io.waitScope)); + { auto drop = kj::mv(server); } + KJ_EXPECT(eofPromise.poll(test.io.waitScope)); + KJ_EXPECT_THROW(DISCONNECTED, eofPromise.wait(test.io.waitScope)); + } +} + +KJ_TEST("TLS half-duplex") { + // Test shutting down one direction of a connection but continuing to stream in the other + // direction. + + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + client->shutdownWrite(); + KJ_EXPECT(server->readAllText().wait(test.io.waitScope) == ""); + + for (uint i = 0; i < 100; i++) { + char buffer[7]; + auto writePromise = server->write("foobar", 6); + auto readPromise = client->read(buffer, 6); + writePromise.wait(test.io.waitScope); + readPromise.wait(test.io.waitScope); + buffer[6] = '\0'; + KJ_ASSERT(kj::StringPtr(buffer, 6) == "foobar"); + } + + server->shutdownWrite(); + KJ_EXPECT(client->readAllText().wait(test.io.waitScope) == ""); +} + +KJ_TEST("TLS peer identity") { + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto innerClientId = kj::LocalPeerIdentity::newInstance({}); + auto& innerClientIdRef = *innerClientId; + auto clientPromise = e.wrap(test.tlsClient.wrapClient( + kj::AuthenticatedStream { kj::mv(pipe.ends[0]), kj::mv(innerClientId) }, + "example.com")); + + auto innerServerId = kj::LocalPeerIdentity::newInstance({}); + auto& innerServerIdRef = *innerServerId; + auto serverPromise = e.wrap(test.tlsServer.wrapServer( + kj::AuthenticatedStream { kj::mv(pipe.ends[1]), kj::mv(innerServerId) })); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + { + auto id = client.peerIdentity.downcast(); + KJ_ASSERT(id->hasCertificate()); + KJ_EXPECT(id->getCommonName() == "example.com"); + KJ_EXPECT(&id->getNetworkIdentity() == &innerClientIdRef); + KJ_EXPECT(id->toString() == "example.com"); + } + + { + auto id = server.peerIdentity.downcast(); + KJ_EXPECT(!id->hasCertificate()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + "client did not provide a certificate", id->getCommonName()); + KJ_EXPECT(&id->getNetworkIdentity() == &innerServerIdRef); + KJ_EXPECT(id->toString() == "(anonymous client)"); + } + + test.testConnection(*client.stream, *server.stream); +} + +KJ_TEST("TLS multiple messages") { + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + auto writePromise = client->write("foo", 3) + .then([&]() { return client->write("bar", 3); }); + + char buf[4]; + buf[3] = '\0'; + + server->read(&buf, 3).wait(test.io.waitScope); + KJ_ASSERT(kj::StringPtr(buf) == "foo"); + + writePromise = writePromise + .then([&]() { return client->write("baz", 3); }); + + server->read(&buf, 3).wait(test.io.waitScope); + KJ_ASSERT(kj::StringPtr(buf) == "bar"); + + server->read(&buf, 3).wait(test.io.waitScope); + KJ_ASSERT(kj::StringPtr(buf) == "baz"); + + auto readPromise = server->read(&buf, 3); + KJ_EXPECT(!readPromise.poll(test.io.waitScope)); + + writePromise = writePromise + .then([&]() { return client->write("qux", 3); }); + + readPromise.wait(test.io.waitScope); + KJ_ASSERT(kj::StringPtr(buf) == "qux"); +} + +KJ_TEST("TLS zero-sized write") { + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + char buf[7]; + auto readPromise = server->read(&buf, 6); + + client->write("", 0).wait(test.io.waitScope); + client->write("foo", 3).wait(test.io.waitScope); + client->write("", 0).wait(test.io.waitScope); + client->write("bar", 3).wait(test.io.waitScope); + + readPromise.wait(test.io.waitScope); + buf[6] = '\0'; + + KJ_ASSERT(kj::StringPtr(buf) == "foobar"); +} + +kj::Promise writeN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) { + if (count == 0) return kj::READY_NOW; + --count; + return stream.write(text.begin(), text.size()) + .then([&stream, text, count]() { + return writeN(stream, text, count); + }); +} + +kj::Promise readN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) { + if (count == 0) return kj::READY_NOW; + --count; + auto buf = kj::heapString(text.size()); + auto promise = stream.read(buf.begin(), buf.size()); + return promise.then([&stream, text, buf=kj::mv(buf), count]() { + KJ_ASSERT(buf == text, buf, text, count); + return readN(stream, text, count); + }); +} + +KJ_TEST("TLS full duplex") { + TlsTest test; + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + +#if _WIN32 + // On Windows we observe that `writeUp`, below, completes before the other end has started + // reading, failing the `!writeUp.poll()` expectation. I guess Windows has big buffers. We can + // fix this by requesting small buffers here. (Worth keeping in mind that Windows doesn't have + // socketpairs, so `newTwoWayPipe()` is implemented in terms of loopback TCP, ugh.) + uint small = 256; + pipe.ends[0]->setsockopt(SOL_SOCKET, SO_SNDBUF, &small, sizeof(small)); + pipe.ends[0]->setsockopt(SOL_SOCKET, SO_RCVBUF, &small, sizeof(small)); +#endif + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + auto writeUp = writeN(*client, "foo", 10000); + auto readDown = readN(*client, "bar", 10000); +#if !(_WIN32 && __clang__) + // TODO(someday): work out why this expectation fails even with the above fix + KJ_EXPECT(!writeUp.poll(test.io.waitScope)); +#endif + KJ_EXPECT(!readDown.poll(test.io.waitScope)); + + auto writeDown = writeN(*server, "bar", 10000); + auto readUp = readN(*server, "foo", 10000); + + readUp.wait(test.io.waitScope); + readDown.wait(test.io.waitScope); + writeUp.wait(test.io.waitScope); + writeDown.wait(test.io.waitScope); +} + +class TestSniCallback: public TlsSniCallback { +public: + kj::Maybe getKey(kj::StringPtr hostname) override { + ++callCount; + + KJ_ASSERT(hostname == "example.com"); + return TlsKeypair { + TlsPrivateKey(HOST_KEY), + TlsCertificate(kj::str(VALID_CERT, INTERMEDIATE_CERT)) + }; + } + + uint callCount = 0; +}; + +KJ_TEST("TLS SNI") { + TlsContext::Options serverOptions; + TestSniCallback callback; + serverOptions.sniCallback = callback; + + TlsTest test(TlsTest::defaultClient(), kj::mv(serverOptions)); + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + test.testConnection(*client, *server); + + KJ_ASSERT(callback.callCount == 1); +} + +void expectInvalidCert(kj::StringPtr hostname, TlsCertificate cert, + kj::StringPtr message, kj::Maybe altMessage = nullptr) { + TlsKeypair keypair = { TlsPrivateKey(HOST_KEY), kj::mv(cert) }; + TlsContext::Options serverOpts; + serverOpts.defaultKeypair = keypair; + TlsTest test(TlsTest::defaultClient(), kj::mv(serverOpts)); + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), hostname)); + auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1]))); + + clientPromise.then([](kj::Own) { + KJ_FAIL_EXPECT("expected exception"); + }, [message, altMessage](kj::Exception&& e) { + if (kj::_::hasSubstring(e.getDescription(), message)) { + return; + } + + KJ_IF_MAYBE(a, altMessage) { + if (kj::_::hasSubstring(e.getDescription(), *a)) { + return; + } + } + + KJ_FAIL_EXPECT("exception didn't contain expected message", message, + altMessage.orDefault(nullptr), e); + }).wait(test.io.waitScope); +} + +KJ_TEST("TLS certificate validation") { + // Where we've given two possible error texts below, it's because OpenSSL v1 produces the former + // text while v3 produces the latter. Note that as of this writing, our Windows CI build claims + // to be v3 but produces v1 text, for reasons I don't care to investigate. + expectInvalidCert("wrong.com", TlsCertificate(kj::str(VALID_CERT, INTERMEDIATE_CERT)), + "Hostname mismatch"_kj, "hostname mismatch"_kj); + expectInvalidCert("example.com", TlsCertificate(VALID_CERT), + "unable to get local issuer certificate"_kj); + expectInvalidCert("example.com", TlsCertificate(kj::str(EXPIRED_CERT, INTERMEDIATE_CERT)), + "certificate has expired"_kj); + expectInvalidCert("example.com", TlsCertificate(SELF_SIGNED_CERT), + "self signed certificate"_kj, "self-signed certificate"_kj); +} + +// BoringSSL seems to print error messages differently. +#ifdef OPENSSL_IS_BORINGSSL +#define SSL_MESSAGE_DIFFERENT_IN_BORINGSSL(interesting, boring) boring +#else +#define SSL_MESSAGE_DIFFERENT_IN_BORINGSSL(interesting, boring) interesting +#endif + +KJ_TEST("TLS client certificate verification") { + enum class VerifyClients { + YES, + NO + }; + auto makeServerOptionsForClient = []( + const TlsContext::Options& clientOptions, + VerifyClients verifyClients + ) { + TlsContext::Options serverOptions = TlsTest::defaultServer(); + serverOptions.verifyClients = verifyClients == VerifyClients::YES; + + // Share the certs between the client and server. + serverOptions.trustedCertificates = clientOptions.trustedCertificates; + + return serverOptions; + }; + + TlsKeypair selfSignedKeypair = { TlsPrivateKey(HOST_KEY), TlsCertificate(SELF_SIGNED_CERT) }; + TlsKeypair altKeypair = { + TlsPrivateKey(HOST_KEY2), + TlsCertificate(kj::str(VALID_CERT2, INTERMEDIATE_CERT)), + }; + + { + // No certificate loaded in the client: fail + auto clientOptions = TlsTest::defaultClient(); + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::YES); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com") + .then([](kj::Own stream) { + auto promise = stream->readAllBytes(); + return promise.attach(kj::mv(stream)); + }); + auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL("peer did not return a certificate", + "PEER_DID_NOT_RETURN_A_CERTIFICATE"), + serverPromise.ignoreResult().wait(test.io.waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL( + "alert", // "alert handshake failure" or "alert certificate required" + "ALERT"), // "ALERT_HANDSHAKE_FAILURE" or "ALERT_CERTIFICATE_REQUIRED" + clientPromise.ignoreResult().wait(test.io.waitScope)); + } + + { + // Self-signed certificate loaded in the client: fail + auto clientOptions = TlsTest::defaultClient(); + clientOptions.defaultKeypair = selfSignedKeypair; + + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::YES); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com") + .then([](kj::Own stream) { + auto promise = stream->readAllBytes(); + return promise.attach(kj::mv(stream)); + }); + auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL("certificate verify failed", + "CERTIFICATE_VERIFY_FAILED"), + serverPromise.ignoreResult().wait(test.io.waitScope)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE( + SSL_MESSAGE_DIFFERENT_IN_BORINGSSL("alert unknown ca", + "TLSV1_ALERT_UNKNOWN_CA"), + clientPromise.ignoreResult().wait(test.io.waitScope)); + } + + { + // Trusted certificate loaded in the client: success. + auto clientOptions = TlsTest::defaultClient(); + clientOptions.defaultKeypair = altKeypair; + + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::YES); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer( + kj::AuthenticatedStream { kj::mv(pipe.ends[1]), kj::UnknownPeerIdentity::newInstance() })); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + auto id = server.peerIdentity.downcast(); + KJ_ASSERT(id->hasCertificate()); + KJ_EXPECT(id->getCommonName() == "example.net"); + + test.testConnection(*client, *server.stream); + } + + { + // If verifyClients is off, client certificate is ignored, even if trusted. + auto clientOptions = TlsTest::defaultClient(); + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::NO); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer( + kj::AuthenticatedStream { kj::mv(pipe.ends[1]), kj::UnknownPeerIdentity::newInstance() })); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + auto id = server.peerIdentity.downcast(); + KJ_EXPECT(!id->hasCertificate()); + } + + { + // Non-trusted keys are ignored too (not errors). + auto clientOptions = TlsTest::defaultClient(); + clientOptions.defaultKeypair = selfSignedKeypair; + + auto serverOptions = makeServerOptionsForClient(clientOptions, VerifyClients::NO); + TlsTest test(kj::mv(clientOptions), kj::mv(serverOptions)); + + ErrorNexus e; + + auto pipe = test.io.provider->newTwoWayPipe(); + + auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com")); + auto serverPromise = e.wrap(test.tlsServer.wrapServer( + kj::AuthenticatedStream { kj::mv(pipe.ends[1]), kj::UnknownPeerIdentity::newInstance() })); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + auto id = server.peerIdentity.downcast(); + KJ_EXPECT(!id->hasCertificate()); + } +} + +class MockConnectionReceiver final: public ConnectionReceiver { + // This connection receiver allows mocked async connection establishment without the network. + + struct ClientRequest { + Maybe maybeException; + Own>> clientFulfiller; + }; + +public: + MockConnectionReceiver(AsyncIoProvider& provider): provider(provider) {} + + Promise> accept() override { + return acceptImpl(); + } + + Promise acceptAuthenticated() override { + return acceptImpl().then([](auto stream) -> AuthenticatedStream { + return { kj::mv(stream), LocalPeerIdentity::newInstance({}) }; + }); + } + + uint getPort() override { + return 0; + } + void getsockopt(int, int, void*, uint*) override {} + void setsockopt(int, int, const void*, uint) override {} + + Promise> connect() { + // Mock a new successful connection to our receiver. + return connectImpl(); + } + + Promise badConnect() { + // Mock a new failed connection to our receiver. + return connectImpl(KJ_EXCEPTION(DISCONNECTED, "Pipes are leaky")).ignoreResult(); + } + +private: + Promise> acceptImpl() { + if (clientRequests.empty()) { + KJ_ASSERT(!serverFulfiller); + auto paf = newPromiseAndFulfiller(); + serverFulfiller = kj::mv(paf.fulfiller); + return paf.promise.then([this] { + return acceptImpl(); + }); + } + + // This is accepting in FILO order, it shouldn't matter in practice. + auto request = kj::mv(clientRequests.back()); + clientRequests.removeLast(); + + KJ_IF_MAYBE(exception, kj::mv(request.maybeException)) { + request.clientFulfiller = nullptr; // The other end had an issue, break the promise. + return kj::mv(*exception); + } else { + auto pipe = provider.newTwoWayPipe(); + request.clientFulfiller->fulfill(kj::mv(pipe.ends[0])); + return kj::mv(pipe.ends[1]); + } + } + + Promise> connectImpl(Maybe maybeException = nullptr) { + auto paf = newPromiseAndFulfiller>(); + clientRequests.add(ClientRequest{ kj::mv(maybeException), kj::mv(paf.fulfiller) }); + + if (auto fulfiller = kj::mv(serverFulfiller)) { + fulfiller->fulfill(); + } + + return kj::mv(paf.promise); + } + + AsyncIoProvider& provider; + + Own> serverFulfiller; + Vector clientRequests; +}; + +class TlsReceiverTest final: public TlsTest { + // TlsReceiverTest augments TlsTest to test TlsConnectionReceiver. +public: + TlsReceiverTest(): TlsTest() { + auto baseReceiverPtr = kj::heap(*io.provider); + baseReceiver = baseReceiverPtr.get(); + receiver = tlsServer.wrapPort(kj::mv(baseReceiverPtr)); + } + + TlsReceiverTest(TlsReceiverTest&&) = delete; + TlsReceiverTest(const TlsReceiverTest&) = delete; + TlsReceiverTest& operator=(TlsReceiverTest&&) = delete; + TlsReceiverTest& operator=(const TlsReceiverTest&) = delete; + + Own receiver; + MockConnectionReceiver* baseReceiver; +}; + +KJ_TEST("TLS receiver basics") { + TlsReceiverTest test; + + auto clientPromise = test.baseReceiver->connect().then([&](auto stream) { + return test.tlsClient.wrapClient(kj::mv(stream), "example.com"); + }); + auto serverPromise = test.receiver->accept(); + + auto client = clientPromise.wait(test.io.waitScope); + auto server = serverPromise.wait(test.io.waitScope); + + test.testConnection(*client, *server); +} + +KJ_TEST("TLS receiver experiences pre-TLS error") { + TlsReceiverTest test; + + KJ_LOG(INFO, "Accepting before a bad connect"); + auto promise = test.receiver->accept(); + + KJ_LOG(INFO, "Disappointing our server"); + test.baseReceiver->badConnect(); + + // Can't use KJ_EXPECT_THROW_RECOVERABLE_MESSAGE because wait() that returns a value can't throw + // recoverable exceptions. Can't use KJ_EXPECT_THROW_MESSAGE because non-recoverable exceptions + // will fork() in -fno-exception which screws up our state. + promise.then([](auto) { + KJ_FAIL_EXPECT("expected exception"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getDescription() == "Pipes are leaky"); + }).wait(test.io.waitScope); + + KJ_LOG(INFO, "Trying to load a promise after failure"); + test.receiver->accept().then([](auto) { + KJ_FAIL_EXPECT("expected exception"); + }, [](kj::Exception&& e) { + KJ_EXPECT(e.getDescription() == "Pipes are leaky"); + }).wait(test.io.waitScope); +} + +KJ_TEST("TLS receiver accepts multiple clients") { + TlsReceiverTest test; + + auto wrapClient = [&](auto stream) { + return test.tlsClient.wrapClient(kj::mv(stream), "example.com"); + }; + + auto writeToServer = [&](auto client) { + return test.writeToServer(*client).attach(kj::mv(client)); + }; + + auto readFromClient = [&](auto server) { + return test.readFromClient(*server).attach(kj::mv(server)); + }; + + KJ_LOG(INFO, "Requesting a bunch of client connects"); + constexpr auto kClientCount = 20; + auto clientPromises = Vector>(); + for (auto i = 0; i < kClientCount; ++i) { + auto clientPromise = test.baseReceiver->connect().then(wrapClient).then(writeToServer); + clientPromises.add(kj::mv(clientPromise)); + } + + KJ_LOG(INFO, "Requesting and resolving a bunch of server accepts in sequence"); + for (auto i = 0; i < kClientCount; ++i) { + // Resolve each receive in sequence like the Supervisor/Network. + test.receiver->accept().then(readFromClient).wait(test.io.waitScope); + } + + KJ_LOG(INFO, "Resolving all of our client connects in parallel"); + joinPromises(clientPromises.releaseAsArray()).wait(test.io.waitScope); + + KJ_LOG(INFO, "Requesting one last server accept that we'll never resolve"); + auto extraAcceptPromise = test.receiver->accept().then(readFromClient); + KJ_EXPECT(!extraAcceptPromise.poll(test.io.waitScope)); +} + +KJ_TEST("TLS receiver does not stall on client that disconnects before ssl handshake") { + TlsReceiverTest test; + + auto wrapClient = [&](auto stream) { + return test.tlsClient.wrapClient(kj::mv(stream), "example.com"); + }; + + auto writeToServer = [&](auto client) { + return test.writeToServer(*client).attach(kj::mv(client)); + }; + + auto readFromClient = [&](auto server) { + return test.readFromClient(*server).attach(kj::mv(server)); + }; + + constexpr auto kClientCount = 20; + auto clientPromises = Vector>(); + + KJ_LOG(INFO, "Requesting the first batch of client connects in parallel"); + for (auto i = 0; i < kClientCount / 2; ++i) { + auto clientPromise = test.baseReceiver->connect().then(wrapClient).then(writeToServer); + clientPromises.add(kj::mv(clientPromise)); + } + + KJ_LOG(INFO, "Requesting and resolving a client connect that hangs up before ssl connect"); + KJ_ASSERT(test.baseReceiver->connect().wait(test.io.waitScope)); + + KJ_LOG(INFO, "Requesting the second batch of client connects in parallel"); + for (auto i = 0; i < kClientCount / 2; ++i) { + auto clientPromise = test.baseReceiver->connect().then(wrapClient).then(writeToServer); + clientPromises.add(kj::mv(clientPromise)); + } + + KJ_LOG(INFO, "Requesting and resolving a bunch of server accepts in sequence"); + for (auto i = 0; i < kClientCount; ++i) { + test.receiver->accept().then(readFromClient).wait(test.io.waitScope); + } + + KJ_LOG(INFO, "Resolving all of our client connects in parallel"); + joinPromises(clientPromises.releaseAsArray()).wait(test.io.waitScope); + + KJ_LOG(INFO, "Requesting one last server accept that we'll never resolve"); + auto extraAcceptPromise = test.receiver->accept().then(readFromClient); + KJ_EXPECT(!extraAcceptPromise.poll(test.io.waitScope)); +} + +KJ_TEST("TLS receiver does not stall on hung client") { + TlsReceiverTest test; + + auto wrapClient = [&](auto stream) { + return test.tlsClient.wrapClient(kj::mv(stream), "example.com"); + }; + + auto writeToServer = [&](auto client) { + return test.writeToServer(*client).attach(kj::mv(client)); + }; + + auto readFromClient = [&](auto server) { + return test.readFromClient(*server).attach(kj::mv(server)); + }; + + constexpr auto kClientCount = 20; + auto clientPromises = Vector>(); + + KJ_LOG(INFO, "Requesting the first batch of client connects in parallel"); + for (auto i = 0; i < kClientCount / 2; ++i) { + auto clientPromise = test.baseReceiver->connect().then(wrapClient).then(writeToServer); + clientPromises.add(kj::mv(clientPromise)); + } + + KJ_LOG(INFO, "Requesting and resolving a client connect that never does ssl connect"); + auto hungClient = test.baseReceiver->connect().wait(test.io.waitScope); + KJ_ASSERT(hungClient); + + KJ_LOG(INFO, "Requesting the second batch of client connects in parallel"); + for (auto i = 0; i < kClientCount / 2; ++i) { + auto clientPromise = test.baseReceiver->connect().then(wrapClient).then(writeToServer); + clientPromises.add(kj::mv(clientPromise)); + } + + KJ_LOG(INFO, "Requesting and resolving a bunch of server accepts in sequence"); + for (auto i = 0; i < kClientCount; ++i) { + test.receiver->accept().then(readFromClient).wait(test.io.waitScope); + } + + KJ_LOG(INFO, "Resolving all of our client connects in parallel"); + joinPromises(clientPromises.releaseAsArray()).wait(test.io.waitScope); + + KJ_LOG(INFO, "Releasing the hung client"); + hungClient = {}; + + KJ_LOG(INFO, "Requesting one last server accept that we'll never resolve"); + auto extraAcceptPromise = test.receiver->accept().then(readFromClient); + KJ_EXPECT(!extraAcceptPromise.poll(test.io.waitScope)); +} + +kj::Promise expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { + if (expected.size() == 0) return kj::READY_NOW; + + auto buffer = kj::heapArray(expected.size()); + + auto promise = in.tryRead(buffer.begin(), 1, buffer.size()); + return promise.then([&in,expected,buffer=kj::mv(buffer)](size_t amount) { + if (amount == 0) { + KJ_FAIL_ASSERT("expected data never sent", expected); + } + + auto actual = buffer.slice(0, amount); + if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) { + KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual); + } + + return expectRead(in, expected.slice(amount)); + }); +} + +kj::Promise expectEnd(kj::AsyncInputStream& in) { + static char buffer; + + auto promise = in.tryRead(&buffer, 1, 1); + return promise.then([](size_t amount) { + KJ_ASSERT(amount == 0, "expected EOF"); + }); +} + +KJ_TEST("NetworkHttpClient connect with tlsStarter") { + auto io = kj::setupAsyncIo(); + auto listener1 = io.provider->getNetwork().parseAddress("127.0.0.1", 0) + .wait(io.waitScope)->listen(); + + auto acceptLoop KJ_UNUSED = listener1->accept().then([](Own stream) { + return stream->pumpTo(*stream).attach(kj::mv(stream)).ignoreResult(); + }).eagerlyEvaluate(nullptr); + + HttpClientSettings clientSettings; + kj::TimerImpl clientTimer(kj::origin()); + HttpHeaderTable headerTable; + TlsContext tls; + + auto tlsNetwork = tls.wrapNetwork(io.provider->getNetwork()); + clientSettings.tlsContext = tls; + auto client = newHttpClient(clientTimer, headerTable, + io.provider->getNetwork(), *tlsNetwork, clientSettings); + kj::HttpConnectSettings httpConnectSettings = { false, nullptr }; + kj::TlsStarterCallback tlsStarter; + httpConnectSettings.tlsStarter = tlsStarter; + auto request = client->connect( + kj::str("127.0.0.1:", listener1->getPort()), HttpHeaders(headerTable), httpConnectSettings); + + KJ_ASSERT(tlsStarter != nullptr); + + auto buf = kj::heapArray(4); + + auto promises = kj::heapArrayBuilder>(2); + promises.add(request.connection->write("hello", 5)); + promises.add(expectRead(*request.connection, "hello"_kj)); + kj::joinPromisesFailFast(promises.finish()) + .then([io=kj::mv(request.connection)]() mutable { + io->shutdownWrite(); + return expectEnd(*io).attach(kj::mv(io)); + }).attach(kj::mv(listener1)).wait(io.waitScope); +} + +#ifdef KJ_EXTERNAL_TESTS +KJ_TEST("TLS to capnproto.org") { + kj::AsyncIoContext io = setupAsyncIo(); + TlsContext tls; + + auto network = tls.wrapNetwork(io.provider->getNetwork()); + auto addr = network->parseAddress("capnproto.org", 443).wait(io.waitScope); + auto stream = addr->connect().wait(io.waitScope); + + kj::StringPtr request = + "HEAD / HTTP/1.1\r\n" + "Host: capnproto.org\r\n" + "Connection: close\r\n" + "User-Agent: capnp-test/0.6\r\n" + "\r\n"; + + stream->write(request.begin(), request.size()).wait(io.waitScope); + + char buffer[4096]; + size_t n = stream->tryRead(buffer, sizeof(buffer) - 1, sizeof(buffer) - 1).wait(io.waitScope); + buffer[n] = '\0'; + kj::StringPtr response(buffer, n); + + KJ_ASSERT(response.startsWith("HTTP/1.1 200 OK\r\n")); +} +#endif + +} // namespace +} // namespace kj + +#endif // KJ_HAS_OPENSSL diff --git a/c++/src/kj/compat/tls.c++ b/c++/src/kj/compat/tls.c++ new file mode 100644 index 0000000000..6affeb1f67 --- /dev/null +++ b/c++/src/kj/compat/tls.c++ @@ -0,0 +1,1147 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if KJ_HAS_OPENSSL + +#include "tls.h" + +#include "readiness-io.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#define BIO_set_init(x,v) (x->init=v) +#define BIO_get_data(x) (x->ptr) +#define BIO_set_data(x,v) (x->ptr=v) +#endif + +namespace kj { + +// ======================================================================================= +// misc helpers + +namespace { + +kj::Exception getOpensslError() { + // Call when an OpenSSL function returns an error code to convert that into an exception. + + kj::Vector lines; + while (unsigned long long error = ERR_get_error()) { +#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING + // OpenSSL 3.0+ reports unexpected disconnects this way. + if (ERR_GET_REASON(error) == SSL_R_UNEXPECTED_EOF_WHILE_READING) { + return KJ_EXCEPTION(DISCONNECTED, + "peer disconnected without gracefully ending TLS session"); + } +#endif + + char message[1024]; + ERR_error_string_n(error, message, sizeof(message)); + lines.add(kj::heapString(message)); + } + kj::String message = kj::strArray(lines, "\n"); + return KJ_EXCEPTION(FAILED, "OpenSSL error", message); +} + +KJ_NORETURN(void throwOpensslError()); +void throwOpensslError() { + // Call when an OpenSSL function returns an error code to convert that into an exception and + // throw it. + + kj::throwFatalException(getOpensslError()); +} + +#if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL) +// Older versions of OpenSSL don't define _up_ref() functions. + +void EVP_PKEY_up_ref(EVP_PKEY* pkey) { + CRYPTO_add(&pkey->references, 1, CRYPTO_LOCK_EVP_PKEY); +} + +void X509_up_ref(X509* x509) { + CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509); +} + +#endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +class OpenSslInit { + // Initializes the OpenSSL library. +public: + OpenSslInit() { + SSL_library_init(); + SSL_load_error_strings(); + OPENSSL_config(nullptr); + } +}; + +void ensureOpenSslInitialized() { + // Initializes the OpenSSL library the first time it is called. + static OpenSslInit init; +} +#else +inline void ensureOpenSslInitialized() { + // As of 1.1.0, no initialization is needed. +} +#endif + +bool isIpAddress(kj::StringPtr addr) { + bool isPossiblyIp6 = true; + bool isPossiblyIp4 = true; + uint colonCount = 0; + uint dotCount = 0; + for (auto c: addr) { + if (c == ':') { + isPossiblyIp4 = false; + ++colonCount; + } else if (c == '.') { + isPossiblyIp6 = false; + ++dotCount; + } else if ('0' <= c && c <= '9') { + // Digit is valid for ipv4 or ipv6. + } else if (('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) { + // Hex digit could be ipv6 but not ipv4. + isPossiblyIp4 = false; + } else { + // Nope. + return false; + } + } + + // An IPv4 address has 3 dots. (Yes, I'm aware that technically IPv4 addresses can be formatted + // with fewer dots, but it's not clear that we actually want to support TLS authentication of + // non-canonical address formats, so for now I'm not. File a bug if you care.) An IPv6 address + // has at least 2 and as many as 7 colons. + return (isPossiblyIp4 && dotCount == 3) + || (isPossiblyIp6 && colonCount >= 2 && colonCount <= 7); +} + +} // namespace + +// ======================================================================================= +// Implementation of kj::AsyncIoStream that applies TLS on top of some other AsyncIoStream. +// +// TODO(perf): OpenSSL's I/O abstraction layer, "BIO", is readiness-based, but AsyncIoStream is +// completion-based. This forces us to use an intermediate buffer which wastes memory and incurs +// redundant copies. We could improve the situation by creating a way to detect if the underlying +// AsyncIoStream is simply wrapping a file descriptor (or other readiness-based stream?) and use +// that directly if so. + +class TlsConnection final: public kj::AsyncIoStream { +public: + TlsConnection(kj::Own stream, SSL_CTX* ctx) + : TlsConnection(*stream, ctx) { + ownInner = kj::mv(stream); + } + + TlsConnection(kj::AsyncIoStream& stream, SSL_CTX* ctx) + : inner(stream), readBuffer(stream), writeBuffer(stream) { + ssl = SSL_new(ctx); + if (ssl == nullptr) { + throwOpensslError(); + } + + BIO* bio = BIO_new(const_cast(getBioVtable())); + if (bio == nullptr) { + SSL_free(ssl); + throwOpensslError(); + } + + BIO_set_data(bio, this); + BIO_set_init(bio, 1); + SSL_set_bio(ssl, bio, bio); + } + + kj::Promise connect(kj::StringPtr expectedServerHostname) { + if (!SSL_set_tlsext_host_name(ssl, expectedServerHostname.cStr())) { + return getOpensslError(); + } + + X509_VERIFY_PARAM* verify = SSL_get0_param(ssl); + if (verify == nullptr) { + return getOpensslError(); + } + + if (isIpAddress(expectedServerHostname)) { + if (X509_VERIFY_PARAM_set1_ip_asc(verify, expectedServerHostname.cStr()) <= 0) { + return getOpensslError(); + } + } else { + if (X509_VERIFY_PARAM_set1_host( + verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) { + return getOpensslError(); + } + } + + // As of OpenSSL 1.1.0, X509_V_FLAG_TRUSTED_FIRST is on by default. Turning it on for older + // versions -- as well as certain OpenSSL-compatible libraries -- fixes the problem described + // here: https://community.letsencrypt.org/t/openssl-client-compatibility-changes-for-let-s-encrypt-certificates/143816 + // + // Otherwise, certificates issued by Let's Encrypt won't work as of September 30, 2021: + // https://letsencrypt.org/docs/dst-root-ca-x3-expiration-september-2021/ + X509_VERIFY_PARAM_set_flags(verify, X509_V_FLAG_TRUSTED_FIRST); + + return sslCall([this]() { return SSL_connect(ssl); }).then([this](size_t) { + X509* cert = SSL_get_peer_certificate(ssl); + KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate") { return; } + X509_free(cert); + + auto result = SSL_get_verify_result(ssl); + if (result != X509_V_OK) { + const char* reason = X509_verify_cert_error_string(result); + KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason) { break; } + } + }); + } + + kj::Promise accept() { + // We are the server. Set SSL options to prefer server's cipher choice. + SSL_set_options(ssl, SSL_OP_CIPHER_SERVER_PREFERENCE); + + auto acceptPromise = sslCall([this]() { + return SSL_accept(ssl); + }); + return acceptPromise.then([](size_t ret) { + if (ret == 0) { + kj::throwRecoverableException( + KJ_EXCEPTION(DISCONNECTED, "Client disconnected during SSL_accept()")); + } + }); + } + + kj::Own getIdentity(kj::Own inner) { + return kj::heap(SSL_get_peer_certificate(ssl), kj::mv(inner), + kj::Badge()); + } + + ~TlsConnection() noexcept(false) { + SSL_free(ssl); + } + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return tryReadInternal(buffer, minBytes, maxBytes, 0); + } + + Promise write(const void* buffer, size_t size) override { + return writeInternal(kj::arrayPtr(reinterpret_cast(buffer), size), nullptr); + } + + Promise write(ArrayPtr> pieces) override { + auto cork = writeBuffer.cork(); + return writeInternal(pieces[0], pieces.slice(1, pieces.size())).attach(kj::mv(cork)); + } + + Promise whenWriteDisconnected() override { + return inner.whenWriteDisconnected(); + } + + void shutdownWrite() override { + KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()"); + + // TODO(2.0): shutdownWrite() is problematic because it doesn't return a promise. It was + // designed to assume that it would only be called after all writes are finished and that + // there was no reason to block at that point, but SSL sessions don't fit this since they + // actually have to send a shutdown message. + shutdownTask = sslCall([this]() { + // The first SSL_shutdown() call is expected to return 0 and may flag a misleading error. + int result = SSL_shutdown(ssl); + return result == 0 ? 1 : result; + }).ignoreResult().eagerlyEvaluate([](kj::Exception&& e) { + KJ_LOG(ERROR, e); + }); + } + + void abortRead() override { + inner.abortRead(); + } + + void getsockopt(int level, int option, void* value, uint* length) override { + inner.getsockopt(level, option, value, length); + } + void setsockopt(int level, int option, const void* value, uint length) override { + inner.setsockopt(level, option, value, length); + } + + void getsockname(struct sockaddr* addr, uint* length) override { + inner.getsockname(addr, length); + } + void getpeername(struct sockaddr* addr, uint* length) override { + inner.getpeername(addr, length); + } + + kj::Maybe getFd() const override { + return inner.getFd(); + } + +private: + SSL* ssl; + kj::AsyncIoStream& inner; + kj::Own ownInner; + + kj::Maybe> shutdownTask; + + ReadyInputStreamWrapper readBuffer; + ReadyOutputStreamWrapper writeBuffer; + + kj::Promise tryReadInternal( + void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) { + return sslCall([this,buffer,maxBytes]() { return SSL_read(ssl, buffer, maxBytes); }) + .then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise { + if (n >= minBytes || n == 0) { + return alreadyDone + n; + } else { + return tryReadInternal(reinterpret_cast(buffer) + n, + minBytes - n, maxBytes - n, alreadyDone + n); + } + }); + } + + Promise writeInternal(kj::ArrayPtr first, + kj::ArrayPtr> rest) { + KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()"); + + // SSL_write() with a zero-sized input returns 0, but a 0 return is documented as indicating + // an error. So, we need to avoid zero-sized writes entirely. + while (first.size() == 0) { + if (rest.size() == 0) { + return kj::READY_NOW; + } + first = rest.front(); + rest = rest.slice(1, rest.size()); + } + + return sslCall([this,first]() { return SSL_write(ssl, first.begin(), first.size()); }) + .then([this,first,rest](size_t n) -> kj::Promise { + if (n == 0) { + return KJ_EXCEPTION(DISCONNECTED, "ssl connection ended during write"); + } else if (n < first.size()) { + return writeInternal(first.slice(n, first.size()), rest); + } else if (rest.size() > 0) { + return writeInternal(rest[0], rest.slice(1, rest.size())); + } else { + return kj::READY_NOW; + } + }); + } + + template + kj::Promise sslCall(Func&& func) { + auto result = func(); + + if (result > 0) { + return result; + } else { + int error = SSL_get_error(ssl, result); + switch (error) { + case SSL_ERROR_ZERO_RETURN: + return constPromise(); + case SSL_ERROR_WANT_READ: + return readBuffer.whenReady().then( + [this,func=kj::mv(func)]() mutable { return sslCall(kj::fwd(func)); }); + case SSL_ERROR_WANT_WRITE: + return writeBuffer.whenReady().then( + [this,func=kj::mv(func)]() mutable { return sslCall(kj::fwd(func)); }); + case SSL_ERROR_SSL: + return getOpensslError(); + case SSL_ERROR_SYSCALL: + if (result == 0) { + // OpenSSL pre-3.0 reports unexpected disconnects this way. Note that 3.0+ report it + // as SSL_ERROR_SSL with the reason SSL_R_UNEXPECTED_EOF_WHILE_READING, which is + // handled in throwOpensslError(). + return KJ_EXCEPTION(DISCONNECTED, + "peer disconnected without gracefully ending TLS session"); + } else { + // According to documentation we shouldn't get here, because our BIO never returns an + // "error". But in practice we do get here sometimes when the peer disconnects + // prematurely. + return KJ_EXCEPTION(DISCONNECTED, "SSL unable to continue I/O"); + } + default: + KJ_FAIL_ASSERT("unexpected SSL error code", error); + } + } + } + + static int bioRead(BIO* b, char* out, int outl) { + BIO_clear_retry_flags(b); + KJ_IF_MAYBE(n, reinterpret_cast(BIO_get_data(b))->readBuffer + .read(kj::arrayPtr(out, outl).asBytes())) { + return *n; + } else { + BIO_set_retry_read(b); + return -1; + } + } + + static int bioWrite(BIO* b, const char* in, int inl) { + BIO_clear_retry_flags(b); + KJ_IF_MAYBE(n, reinterpret_cast(BIO_get_data(b))->writeBuffer + .write(kj::arrayPtr(in, inl).asBytes())) { + return *n; + } else { + BIO_set_retry_write(b); + return -1; + } + } + + static long bioCtrl(BIO* b, int cmd, long num, void* ptr) { + switch (cmd) { + case BIO_CTRL_EOF: + return reinterpret_cast(BIO_get_data(b))->readBuffer.isAtEnd(); + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_PUSH: + case BIO_CTRL_POP: + // Informational? + return 0; +#ifdef BIO_CTRL_GET_KTLS_SEND + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: + // TODO(someday): Support kTLS if the underlying stream is a raw socket. + return 0; +#endif + default: + KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd); + return 0; + } + } + + static int bioCreate(BIO* b) { + BIO_set_data(b, nullptr); + return 1; + } + + static int bioDestroy(BIO* b) { + // The BIO does NOT own the TlsConnection. + return 1; + } + +#if OPENSSL_VERSION_NUMBER < 0x10100000L + static const BIO_METHOD* getBioVtable() { + static const BIO_METHOD VTABLE { + BIO_TYPE_SOURCE_SINK, + "KJ stream", + TlsConnection::bioWrite, + TlsConnection::bioRead, + nullptr, // puts + nullptr, // gets + TlsConnection::bioCtrl, + TlsConnection::bioCreate, + TlsConnection::bioDestroy, + nullptr + }; + return &VTABLE; + } +#else + static const BIO_METHOD* getBioVtable() { + static const BIO_METHOD* const vtable = makeBioVtable(); + return vtable; + } + static const BIO_METHOD* makeBioVtable() { + BIO_METHOD* vtable = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "KJ stream"); + BIO_meth_set_write(vtable, TlsConnection::bioWrite); + BIO_meth_set_read(vtable, TlsConnection::bioRead); + BIO_meth_set_ctrl(vtable, TlsConnection::bioCtrl); + BIO_meth_set_create(vtable, TlsConnection::bioCreate); + BIO_meth_set_destroy(vtable, TlsConnection::bioDestroy); + return vtable; + } +#endif +}; + +// ======================================================================================= +// Implementations of ConnectionReceiver, NetworkAddress, and Network as wrappers adding TLS. + +class TlsConnectionReceiver final: public ConnectionReceiver, public TaskSet::ErrorHandler { +public: + TlsConnectionReceiver( + TlsContext &tls, Own inner, + kj::Maybe acceptErrorHandler) + : tls(tls), inner(kj::mv(inner)), + acceptLoopTask(acceptLoop().eagerlyEvaluate([this](Exception &&e) { + onAcceptFailure(kj::mv(e)); + })), + acceptErrorHandler(kj::mv(acceptErrorHandler)), + tasks(*this) {} + + void taskFailed(Exception&& e) override { + KJ_IF_MAYBE(handler, acceptErrorHandler){ + handler->operator()(kj::mv(e)); + } else if (e.getType() != Exception::Type::DISCONNECTED) { + KJ_LOG(ERROR, "error accepting tls connection", kj::mv(e)); + } + }; + + Promise> accept() override { + return acceptAuthenticated().then([](AuthenticatedStream&& stream) { + return kj::mv(stream.stream); + }); + } + + Promise acceptAuthenticated() override { + KJ_IF_MAYBE(e, maybeInnerException) { + // We've experienced an exception from the inner receiver, we consider this unrecoverable. + return Exception(*e); + } + + return queue.pop(); + } + + uint getPort() override { + return inner->getPort(); + } + + void getsockopt(int level, int option, void* value, uint* length) override { + return inner->getsockopt(level, option, value, length); + } + + void setsockopt(int level, int option, const void* value, uint length) override { + return inner->setsockopt(level, option, value, length); + } + +private: + void onAcceptSuccess(AuthenticatedStream&& stream) { + // Queue this stream to go through SSL_accept. + + auto acceptPromise = kj::evalNow([&] { + // Do the SSL acceptance procedure. + return tls.wrapServer(kj::mv(stream)); + }); + + auto sslPromise = acceptPromise.then([this](auto&& stream) -> Promise { + // This is only attached to the success path, thus the error handler will catch if our + // promise fails. + queue.push(kj::mv(stream)); + return kj::READY_NOW; + }); + tasks.add(kj::mv(sslPromise)); + } + + void onAcceptFailure(Exception&& e) { + // Store this exception to reject all future calls to accept() and reject any unfulfilled + // promises from the queue. + maybeInnerException = kj::mv(e); + queue.rejectAll(Exception(KJ_REQUIRE_NONNULL(maybeInnerException))); + } + + Promise acceptLoop() { + // Accept one connection and queue up the next accept on our TaskSet. + + return inner->acceptAuthenticated().then( + [this](AuthenticatedStream&& stream) { + onAcceptSuccess(kj::mv(stream)); + + // Queue up the next accept loop immediately without waiting for SSL_accept()/wrapServer(). + return acceptLoop(); + }); + } + + TlsContext& tls; + Own inner; + + Promise acceptLoopTask; + ProducerConsumerQueue queue; + kj::Maybe acceptErrorHandler; + TaskSet tasks; + + Maybe maybeInnerException; +}; + +class TlsNetworkAddress final: public kj::NetworkAddress { +public: + TlsNetworkAddress(TlsContext& tls, kj::String hostname, kj::Own&& inner) + : tls(tls), hostname(kj::mv(hostname)), inner(kj::mv(inner)) {} + + Promise> connect() override { + // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress + // as soon as connect() returns, and this works with the native network implementation. + // So, we make some copies here. + auto& tlsRef = tls; + auto hostnameCopy = kj::str(hostname); + return inner->connect().then( + [&tlsRef,hostname=kj::mv(hostnameCopy)](Own&& stream) { + return tlsRef.wrapClient(kj::mv(stream), hostname); + }); + } + + Promise connectAuthenticated() override { + // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress + // as soon as connect() returns, and this works with the native network implementation. + // So, we make some copies here. + auto& tlsRef = tls; + auto hostnameCopy = kj::str(hostname); + return inner->connectAuthenticated().then( + [&tlsRef, hostname = kj::mv(hostnameCopy)](kj::AuthenticatedStream stream) { + return tlsRef.wrapClient(kj::mv(stream), hostname); + }); + } + + Own listen() override { + return tls.wrapPort(inner->listen()); + } + + Own clone() override { + return kj::heap(tls, kj::str(hostname), inner->clone()); + } + + String toString() override { + return kj::str("tls:", inner->toString()); + } + +private: + TlsContext& tls; + kj::String hostname; + kj::Own inner; +}; + +class TlsNetwork final: public kj::Network { +public: + TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {} + TlsNetwork(TlsContext& tls, kj::Own inner) + : tls(tls), inner(*inner), ownInner(kj::mv(inner)) {} + + Promise> parseAddress(StringPtr addr, uint portHint) override { + // We want to parse the hostname or IP address out of `addr`. This is a bit complicated as + // KJ's default network implementation has a fairly featureful grammar for these things. + // In particular, we cannot just split on ':' because the address might be IPv6. + + kj::String hostname; + + if (addr.startsWith("[")) { + // IPv6, like "[1234:5678::abcd]:123". Take the part between the brackets. + KJ_IF_MAYBE(pos, addr.findFirst(']')) { + hostname = kj::str(addr.slice(1, *pos)); + } else { + // Uhh??? Just take the whole thing, cert will fail later. + hostname = kj::heapString(addr); + } + } else if (addr.startsWith("unix:") || addr.startsWith("unix-abstract:")) { + // Unfortunately, `unix:123` is ambiguous (maybe there is a host named "unix"?), but the + // default KJ network implementation will interpret it as a Unix domain socket address. + // We don't want TLS to then try to authenticate that as a host named "unix". + KJ_FAIL_REQUIRE("can't authenticate Unix domain socket with TLS", addr); + } else { + uint colons = 0; + for (auto c: addr) { + if (c == ':') { + ++colons; + } + } + + if (colons >= 2) { + // Must be an IPv6 address. If it had a port, it would have been wrapped in []. So don't + // strip the port. + hostname = kj::heapString(addr); + } else { + // Assume host:port or ipv4:port. This is a shaky assumption, as the above hacks + // demonstrate. + // + // In theory it might make sense to extend the NetworkAddress interface so that it can tell + // us what the actual parser decided the hostname is. However, when I tried this it proved + // rather cumbersome and actually broke code in the Workers Runtime that does complicated + // stacking of kj::Network implementations. + KJ_IF_MAYBE(pos, addr.findFirst(':')) { + hostname = kj::heapString(addr.slice(0, *pos)); + } else { + hostname = kj::heapString(addr); + } + } + } + + return inner.parseAddress(addr, portHint) + .then([this, hostname=kj::mv(hostname)](kj::Own&& addr) mutable + -> kj::Own { + return kj::heap(tls, kj::mv(hostname), kj::mv(addr)); + }); + } + + Own getSockaddr(const void* sockaddr, uint len) override { + KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames"); + } + + Own restrictPeers( + kj::ArrayPtr allow, + kj::ArrayPtr deny = nullptr) override { + // TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions? + // Or is it better to let people do that via the TlsContext? A neat thing about + // restrictPeers() is that it's easy to make user-configurable. + return kj::heap(tls, inner.restrictPeers(allow, deny)); + } + +private: + TlsContext& tls; + kj::Network& inner; + kj::Own ownInner; +}; + +// ======================================================================================= +// class TlsContext + +TlsContext::Options::Options() + : useSystemTrustStore(true), + verifyClients(false), + minVersion(TlsVersion::TLS_1_2), + cipherList("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305") {} +// Cipher list is Mozilla's "intermediate" list, except with classic DH removed since we don't +// currently support setting dhparams. See: +// https://mozilla.github.io/server-side-tls/ssl-config-generator/ +// +// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll +// never bother. + +struct TlsContext::SniCallback { + // struct SniCallback exists only so that callback() can be declared in the .c++ file, since it + // references OpenSSL types. + + static int callback(SSL* ssl, int* ad, void* arg); +}; + +TlsContext::TlsContext(Options options) { + ensureOpenSslInitialized(); + +#if OPENSSL_VERSION_NUMBER >= 0x10100000L || defined(OPENSSL_IS_BORINGSSL) + SSL_CTX* ctx = SSL_CTX_new(TLS_method()); +#else + SSL_CTX* ctx = SSL_CTX_new(SSLv23_method()); +#endif + + if (ctx == nullptr) { + throwOpensslError(); + } + KJ_ON_SCOPE_FAILURE(SSL_CTX_free(ctx)); + + // honor options.useSystemTrustStore + if (options.useSystemTrustStore) { + if (!SSL_CTX_set_default_verify_paths(ctx)) { + throwOpensslError(); + } + } + + // honor options.trustedCertificates + if (options.trustedCertificates.size() > 0) { + X509_STORE* store = SSL_CTX_get_cert_store(ctx); + if (store == nullptr) { + throwOpensslError(); + } + for (auto& cert: options.trustedCertificates) { + if (!X509_STORE_add_cert(store, reinterpret_cast(cert.chain[0]))) { + throwOpensslError(); + } + } + } + + if (options.verifyClients) { + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); + } + + // honor options.minVersion + long optionFlags = 0; + if (options.minVersion > TlsVersion::SSL_3) { + optionFlags |= SSL_OP_NO_SSLv3; + } + if (options.minVersion > TlsVersion::TLS_1_0) { + optionFlags |= SSL_OP_NO_TLSv1; + } + if (options.minVersion > TlsVersion::TLS_1_1) { + optionFlags |= SSL_OP_NO_TLSv1_1; + } + if (options.minVersion > TlsVersion::TLS_1_2) { + optionFlags |= SSL_OP_NO_TLSv1_2; + } + if (options.minVersion > TlsVersion::TLS_1_3) { +#ifdef SSL_OP_NO_TLSv1_3 + optionFlags |= SSL_OP_NO_TLSv1_3; +#else + KJ_FAIL_REQUIRE("OpenSSL headers don't support TLS 1.3"); +#endif + } + SSL_CTX_set_options(ctx, optionFlags); // note: never fails; returns new options bitmask + + // honor options.cipherList + if (!SSL_CTX_set_cipher_list(ctx, options.cipherList.cStr())) { + throwOpensslError(); + } + + // honor options.defaultKeypair + KJ_IF_MAYBE(kp, options.defaultKeypair) { + if (!SSL_CTX_use_PrivateKey(ctx, reinterpret_cast(kp->privateKey.pkey))) { + throwOpensslError(); + } + + if (!SSL_CTX_use_certificate(ctx, reinterpret_cast(kp->certificate.chain[0]))) { + throwOpensslError(); + } + + for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) { + X509* x509 = reinterpret_cast(kp->certificate.chain[i]); + if (x509 == nullptr) break; // end of chain + + if (!SSL_CTX_add_extra_chain_cert(ctx, x509)) { + throwOpensslError(); + } + + // SSL_CTX_add_extra_chain_cert() does NOT up the refcount itself. + X509_up_ref(x509); + } + } + + // honor options.sniCallback + KJ_IF_MAYBE(sni, options.sniCallback) { + SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback); + SSL_CTX_set_tlsext_servername_arg(ctx, sni); + } + + KJ_IF_MAYBE(timeout, options.acceptTimeout) { + this->timer = KJ_REQUIRE_NONNULL(options.timer, + "acceptTimeout option requires that a timer is also provided"); + this->acceptTimeout = *timeout; + } + + this->acceptErrorHandler = kj::mv(options.acceptErrorHandler); + + this->ctx = ctx; +} + +int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) { + // The third parameter is actually type TlsSniCallback*. + + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + TlsSniCallback& sni = *reinterpret_cast(arg); + + const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (name != nullptr) { + KJ_IF_MAYBE(kp, sni.getKey(name)) { + if (!SSL_use_PrivateKey(ssl, reinterpret_cast(kp->privateKey.pkey))) { + throwOpensslError(); + } + + if (!SSL_use_certificate(ssl, reinterpret_cast(kp->certificate.chain[0]))) { + throwOpensslError(); + } + + if (!SSL_clear_chain_certs(ssl)) { + throwOpensslError(); + } + + for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) { + X509* x509 = reinterpret_cast(kp->certificate.chain[i]); + if (x509 == nullptr) break; // end of chain + + if (!SSL_add0_chain_cert(ssl, x509)) { + throwOpensslError(); + } + + // SSL_add0_chain_cert() does NOT up the refcount itself. + X509_up_ref(x509); + } + } + } + })) { + KJ_LOG(ERROR, "exception when invoking SNI callback", *exception); + *ad = SSL_AD_INTERNAL_ERROR; + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + + return SSL_TLSEXT_ERR_OK; +} + +TlsContext::~TlsContext() noexcept(false) { + SSL_CTX_free(reinterpret_cast(ctx)); +} + +kj::Promise> TlsContext::wrapClient( + kj::Own stream, kj::StringPtr expectedServerHostname) { + auto conn = kj::heap(kj::mv(stream), reinterpret_cast(ctx)); + auto promise = conn->connect(expectedServerHostname); + return promise.then([conn=kj::mv(conn)]() mutable + -> kj::Own { + return kj::mv(conn); + }); +} + +kj::Promise> TlsContext::wrapServer(kj::Own stream) { + auto conn = kj::heap(kj::mv(stream), reinterpret_cast(ctx)); + auto promise = conn->accept(); + KJ_IF_MAYBE(timeout, acceptTimeout) { + promise = KJ_REQUIRE_NONNULL(timer).afterDelay(*timeout).then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "timed out waiting for client during TLS handshake"); + }).exclusiveJoin(kj::mv(promise)); + } + return promise.then([conn=kj::mv(conn)]() mutable + -> kj::Own { + return kj::mv(conn); + }); +} + +kj::Promise TlsContext::wrapClient( + kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname) { + auto conn = kj::heap(kj::mv(stream.stream), reinterpret_cast(ctx)); + auto promise = conn->connect(expectedServerHostname); + return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable { + auto id = conn->getIdentity(kj::mv(innerId)); + return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) }; + }); +} + +kj::Promise TlsContext::wrapServer(kj::AuthenticatedStream stream) { + auto conn = kj::heap(kj::mv(stream.stream), reinterpret_cast(ctx)); + auto promise = conn->accept(); + KJ_IF_MAYBE(timeout, acceptTimeout) { + promise = KJ_REQUIRE_NONNULL(timer).afterDelay(*timeout).then([]() -> kj::Promise { + return KJ_EXCEPTION(DISCONNECTED, "timed out waiting for client during TLS handshake"); + }).exclusiveJoin(kj::mv(promise)); + } + return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable { + auto id = conn->getIdentity(kj::mv(innerId)); + return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) }; + }); +} + +kj::Own TlsContext::wrapPort(kj::Own port) { + auto handler = acceptErrorHandler.map([](TlsErrorHandler& handler) { + return handler.reference(); + }); + return kj::heap(*this, kj::mv(port), kj::mv(handler)); +} + +kj::Own TlsContext::wrapAddress( + kj::Own address, kj::StringPtr expectedServerHostname) { + return kj::heap(*this, kj::str(expectedServerHostname), kj::mv(address)); +} + +kj::Own TlsContext::wrapNetwork(kj::Network& network) { + return kj::heap(*this, network); +} + +// ======================================================================================= +// class TlsPrivateKey + +TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr asn1) { + ensureOpenSslInitialized(); + + const byte* ptr = asn1.begin(); + pkey = d2i_AutoPrivateKey(nullptr, &ptr, asn1.size()); + if (pkey == nullptr) { + throwOpensslError(); + } +} + +TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe password) { + ensureOpenSslInitialized(); + + // const_cast apparently needed for older versions of OpenSSL. + BIO* bio = BIO_new_mem_buf(const_cast(pem.begin()), pem.size()); + KJ_DEFER(BIO_free(bio)); + + pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password); + if (pkey == nullptr) { + throwOpensslError(); + } +} + +TlsPrivateKey::TlsPrivateKey(const TlsPrivateKey& other) + : pkey(other.pkey) { + if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast(pkey)); +} + +TlsPrivateKey& TlsPrivateKey::operator=(const TlsPrivateKey& other) { + if (pkey != other.pkey) { + EVP_PKEY_free(reinterpret_cast(pkey)); + pkey = other.pkey; + if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast(pkey)); + } + return *this; +} + +TlsPrivateKey::~TlsPrivateKey() noexcept(false) { + EVP_PKEY_free(reinterpret_cast(pkey)); +} + +int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) { + auto& password = *reinterpret_cast*>(u); + + KJ_IF_MAYBE(p, password) { + int result = kj::min(p->size(), size); + memcpy(buf, p->begin(), result); + return result; + } else { + return 0; + } +} + +// ======================================================================================= +// class TlsCertificate + +TlsCertificate::TlsCertificate(kj::ArrayPtr> asn1) { + ensureOpenSslInitialized(); + + KJ_REQUIRE(asn1.size() > 0, "must provide at least one certificate in chain"); + KJ_REQUIRE(asn1.size() <= kj::size(chain), + "exceeded maximum certificate chain length of 10"); + + memset(chain, 0, sizeof(chain)); + + for (auto i: kj::indices(asn1)) { + auto p = asn1[i].begin(); + + // "_AUX" apparently refers to some auxiliary information that can be appended to the + // certificate, but should only be trusted for your own certificate, not the whole chain?? + // I don't really know, I'm just cargo-culting. + chain[i] = i == 0 ? d2i_X509_AUX(nullptr, &p, asn1[i].size()) + : d2i_X509(nullptr, &p, asn1[i].size()); + + if (chain[i] == nullptr) { + for (size_t j = 0; j < i; j++) { + X509_free(reinterpret_cast(chain[j])); + } + throwOpensslError(); + } + } +} + +TlsCertificate::TlsCertificate(kj::ArrayPtr asn1) + : TlsCertificate(kj::arrayPtr(&asn1, 1)) {} + +TlsCertificate::TlsCertificate(kj::StringPtr pem) { + ensureOpenSslInitialized(); + + memset(chain, 0, sizeof(chain)); + + // const_cast apparently needed for older versions of OpenSSL. + BIO* bio = BIO_new_mem_buf(const_cast(pem.begin()), pem.size()); + KJ_DEFER(BIO_free(bio)); + + for (auto i: kj::indices(chain)) { + // "_AUX" apparently refers to some auxiliary information that can be appended to the + // certificate, but should only be trusted for your own certificate, not the whole chain?? + // I don't really know, I'm just cargo-culting. + chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr) + : PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + + if (chain[i] == nullptr) { + auto error = ERR_peek_last_error(); + if (i > 0 && ERR_GET_LIB(error) == ERR_LIB_PEM && + ERR_GET_REASON(error) == PEM_R_NO_START_LINE) { + // EOF; we're done. + ERR_clear_error(); + return; + } else { + for (size_t j = 0; j < i; j++) { + X509_free(reinterpret_cast(chain[j])); + } + throwOpensslError(); + } + } + } + + // We reached the chain length limit. Try to read one more to verify that the chain ends here. + X509* dummy = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + if (dummy != nullptr) { + X509_free(dummy); + for (auto i: kj::indices(chain)) { + X509_free(reinterpret_cast(chain[i])); + } + KJ_FAIL_REQUIRE("exceeded maximum certificate chain length of 10"); + } +} + +TlsCertificate::TlsCertificate(const TlsCertificate& other) { + memcpy(chain, other.chain, sizeof(chain)); + for (void* p: chain) { + if (p == nullptr) break; // end of chain; quit early + X509_up_ref(reinterpret_cast(p)); + } +} + +TlsCertificate& TlsCertificate::operator=(const TlsCertificate& other) { + for (auto i: kj::indices(chain)) { + if (chain[i] != other.chain[i]) { + EVP_PKEY_free(reinterpret_cast(chain[i])); + chain[i] = other.chain[i]; + if (chain[i] != nullptr) X509_up_ref(reinterpret_cast(chain[i])); + } else if (chain[i] == nullptr) { + // end of both chains; quit early + break; + } + } + return *this; +} + +TlsCertificate::~TlsCertificate() noexcept(false) { + for (void* p: chain) { + if (p == nullptr) break; // end of chain; quit early + X509_free(reinterpret_cast(p)); + } +} + +// ======================================================================================= +// class TlsPeerIdentity + +TlsPeerIdentity::~TlsPeerIdentity() noexcept(false) { + if (cert != nullptr) { + X509_free(reinterpret_cast(cert)); + } +} + +kj::String TlsPeerIdentity::toString() { + if (hasCertificate()) { + return getCommonName(); + } else { + return kj::str("(anonymous client)"); + } +} + +kj::String TlsPeerIdentity::getCommonName() { + if (cert == nullptr) { + KJ_FAIL_REQUIRE("client did not provide a certificate") { return nullptr; } + } + + X509_NAME* subj = X509_get_subject_name(reinterpret_cast(cert)); + + int index = X509_NAME_get_index_by_NID(subj, NID_commonName, -1); + KJ_ASSERT(index != -1, "certificate has no common name?"); + X509_NAME_ENTRY* entry = X509_NAME_get_entry(subj, index); + KJ_ASSERT(entry != nullptr); + ASN1_STRING* data = X509_NAME_ENTRY_get_data(entry); + KJ_ASSERT(data != nullptr); + + unsigned char* out = nullptr; + int len = ASN1_STRING_to_UTF8(&out, data); + KJ_ASSERT(len >= 0); + KJ_DEFER(OPENSSL_free(out)); + + return kj::heapString(reinterpret_cast(out), len); +} + +} // namespace kj + +#endif // KJ_HAS_OPENSSL diff --git a/c++/src/kj/compat/tls.h b/c++/src/kj/compat/tls.h new file mode 100644 index 0000000000..f78a23c999 --- /dev/null +++ b/c++/src/kj/compat/tls.h @@ -0,0 +1,309 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once +// This file implements TLS (aka SSL) encrypted networking. It is actually a wrapper, currently +// around OpenSSL / BoringSSL / LibreSSL, but the interface is intended to remain +// implementation-agnostic. +// +// Unlike OpenSSL's API, the API defined in this file is intended to be hard to use wrong. Good +// ciphers and settings are used by default. Certificates validation is performed automatically +// and cannot be bypassed. + +#include + +KJ_BEGIN_HEADER + +namespace kj { + +class TlsPrivateKey; +class TlsCertificate; +struct TlsKeypair; +class TlsSniCallback; +class TlsConnection; + +enum class TlsVersion { + SSL_3, // avoid; cryptographically broken + TLS_1_0, // avoid; cryptographically weak + TLS_1_1, // avoid; cryptographically weak + TLS_1_2, + TLS_1_3 +}; + +using TlsErrorHandler = kj::Function; +// Use a simple kj::Function for handling errors during parallel accept(). + +class TlsContext: public kj::SecureNetworkWrapper { + // TLS system. Allocate one of these, configure it with the proper keys and certificates (or + // use the defaults), and then use it to wrap the standard KJ network interfaces in + // implementations that transparently use TLS. + +public: + + struct Options { + Options(); + // Initializes all values to reasonable defaults. + + KJ_DISALLOW_COPY(Options); + Options(Options&&) = default; + Options& operator=(Options&&) = default; + // Options is a move-only value type. + + bool useSystemTrustStore; + // Whether or not to trust the system's default trust store. Default: true. + + bool verifyClients; + // If true, when acting as a server, require the client to present a certificate. The + // certificate must be signed by one of the trusted CAs, otherwise the client will be rejected. + // (Typically you should set `useSystemTrustStore` false when using this flag, and specify + // your specific trusted CAs in `trustedCertificates`.) + // Default: false + + kj::ArrayPtr trustedCertificates; + // Additional certificates which should be trusted. Default: none. + + TlsVersion minVersion; + // Minimum version. Defaults to minimum version that hasn't been cryptographically broken. + // If you override this, consider doing: + // + // options.minVersion = kj::max(myVersion, options.minVersion); + + kj::StringPtr cipherList; + // OpenSSL cipher list string. The default is a curated list designed to be compatible with + // almost all software in current use (specifically, based on Mozilla's "intermediate" + // recommendations). The defaults will change in future versions of this library to account + // for the latest cryptanalysis. + // + // Generally you should only specify your own `cipherList` if: + // - You have extreme backwards-compatibility needs and wish to enable obsolete and/or broken + // algorithms. + // - You need quickly to disable an algorithm recently discovered to be broken. + + kj::Maybe defaultKeypair; + // Default keypair to use for all connections. Required for servers; optional for clients. + + kj::Maybe sniCallback; + // Callback that can be used to choose a different key/certificate based on the specific + // hostname requested by the client. + + kj::Maybe timer; + // The timer used for `acceptTimeout` below. + + kj::Maybe acceptTimeout; + // Timeout applied to accepting a new TLS connection. `timer` is required if this is set. + + kj::Maybe acceptErrorHandler; + // Error handler used for TLS accept errors. + }; + + TlsContext(Options options = Options()); + ~TlsContext() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(TlsContext); + + kj::Promise> wrapServer(kj::Own stream); + // Upgrade a regular network stream to TLS and begin the initial handshake as the server. The + // returned promise resolves when the handshake has completed successfully. + + kj::Promise> wrapClient( + kj::Own stream, kj::StringPtr expectedServerHostname); + // Upgrade a regular network stream to TLS and begin the initial handshake as a client. The + // returned promise resolves when the handshake has completed successfully, including validating + // the server's certificate. + // + // You must specify the server's hostname. This is used for two purposes: + // 1. It is sent to the server in the initial handshake via the TLS SNI extension, so that a + // server serving multiple hosts knows which certificate to use. + // 2. The server's certificate is validated against this hostname. If validation fails, the + // promise returned by wrapClient() will be broken; you'll never get a stream. + + kj::Promise wrapServer(kj::AuthenticatedStream stream); + kj::Promise wrapClient( + kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname); + // Like wrapServer() and wrapClient(), but also produces information about the peer's + // certificate (if any). The returned `peerIdentity` will be a `TlsPeerIdentity`. + + kj::Own wrapPort(kj::Own port); + // Upgrade a ConnectionReceiver to one that automatically upgrades all accepted connections to + // TLS (acting as the server). + + kj::Own wrapAddress( + kj::Own address, kj::StringPtr expectedServerHostname); + // Upgrade a NetworkAddress to one that automatically upgrades all connections to TLS, acting + // as the client when `connect()` is called or the server if `listen()` is called. + // `connect()` will athenticate the server as `expectedServerHostname`. + + kj::Own wrapNetwork(kj::Network& network); + // Upgrade a Network to one that automatically upgrades all connections to TLS. The network will + // only accept addresses of the form "hostname" and "hostname:port" (it does not accept raw IP + // addresses). It will automatically use SNI and verify certificates based on these hostnames. + +private: + void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here + kj::Maybe timer; + kj::Maybe acceptTimeout; + kj::Maybe acceptErrorHandler; + + struct SniCallback; +}; + +class TlsPrivateKey { + // A private key suitable for use in a TLS server. + +public: + TlsPrivateKey(kj::ArrayPtr asn1); + // Parse a single binary (ASN1) private key. Supports PKCS8 keys as well as "traditional format" + // RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to + // decrypt. + + TlsPrivateKey(kj::StringPtr pem, kj::Maybe password = nullptr); + // Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format" + // RSA and DSA keys. A password may optionally be provided and will be used if the key is + // encrypted. + + ~TlsPrivateKey() noexcept(false); + + TlsPrivateKey(const TlsPrivateKey& other); + TlsPrivateKey& operator=(const TlsPrivateKey& other); + // Copy-by-refcount. + + inline TlsPrivateKey(TlsPrivateKey&& other): pkey(other.pkey) { other.pkey = nullptr; } + inline TlsPrivateKey& operator=(TlsPrivateKey&& other) { + pkey = other.pkey; other.pkey = nullptr; + return *this; + } + +private: + void* pkey; // actually type EVP_PKEY* + + friend class TlsContext; + + static int passwordCallback(char* buf, int size, int rwflag, void* u); +}; + +class TlsCertificate { + // A TLS certificate, possibly with chained intermediate certificates. + +public: + TlsCertificate(kj::ArrayPtr asn1); + // Parse a single binary (ASN1) X509 certificate. + + TlsCertificate(kj::ArrayPtr> asn1); + // Parse a chain of binary (ASN1) X509 certificates. + + TlsCertificate(kj::StringPtr pem); + // Parse a PEM-encode X509 certificate or certificate chain. A chain can be constructed by + // concatenating multiple PEM-encoded certificates, starting with the leaf certificate. + + ~TlsCertificate() noexcept(false); + + TlsCertificate(const TlsCertificate& other); + TlsCertificate& operator=(const TlsCertificate& other); + // Copy-by-refcount. + + inline TlsCertificate(TlsCertificate&& other) { + memcpy(chain, other.chain, sizeof(chain)); + memset(other.chain, 0, sizeof(chain)); + } + inline TlsCertificate& operator=(TlsCertificate&& other) { + memcpy(chain, other.chain, sizeof(chain)); + memset(other.chain, 0, sizeof(chain)); + return *this; + } + +private: + void* chain[10]; + // Actually type X509*[10]. + // + // Note that OpenSSL has a default maximum cert chain length of 10. Although configurable at + // runtime, you'd actually have to convince the _peer_ to reconfigure, which is unlikely except + // in specific use cases. So to avoid excess allocations we just assume a max of 10 certs. + // + // If this proves to be a problem, we should maybe use STACK_OF(X509) here, but stacks are not + // refcounted -- the X509_chain_up_ref() function actually allocates a new stack and uprefs all + // the certs. + + friend class TlsContext; +}; + +struct TlsKeypair { + // A pair of a private key and a certificate, for use by a server. + + TlsPrivateKey privateKey; + TlsCertificate certificate; +}; + +class TlsSniCallback { + // Callback object to implement Server Name Indication, in which the server is able to decide + // what key and certificate to use based on the hostname that the client is requesting. + // + // TODO(someday): Currently this callback is synchronous, because the OpenSSL API seems to be + // synchronous. Other people (e.g. Node) have figured out how to do it asynchronously, but + // it's unclear to me if and how this is possible while using the OpenSSL APIs. It looks like + // Node may be manually parsing the ClientHello message rather than relying on OpenSSL. We + // could do that but it's too much work for today. + +public: + virtual kj::Maybe getKey(kj::StringPtr hostname) = 0; + // Get the key to use for `hostname`. Null return means use the default from + // TlsContext::Options::defaultKeypair. +}; + +class TlsPeerIdentity final: public kj::PeerIdentity { +public: + KJ_DISALLOW_COPY_AND_MOVE(TlsPeerIdentity); + ~TlsPeerIdentity() noexcept(false); + + kj::String toString() override; + + kj::PeerIdentity& getNetworkIdentity() { return *inner; } + // Gets the PeerIdentity of the underlying network connection. + + bool hasCertificate() { return cert != nullptr; } + // Did the peer even present a (trusted) certificate? Servers must always present certificates. + // Clients need only present certificates when the `verifyClients` option is enabled. + // + // Methods of this class that read details of the certificate will throw exceptions when no + // certificate was presented. We don't have them return `Maybe`s because most applications know + // in advance whether or not a certificate should be present, so it would lead to lots of + // `KJ_ASSERT_NONNULL`... + + kj::String getCommonName(); + // Get the authenticated common name from the certificate. + + bool matchesHostname(kj::StringPtr hostname); + // Check if the certificate authenticates the given hostname, considering wildcards and SAN + // extensions. If no certificate was provided, always returns false. + + // TODO(someday): Methods for other things. Match hostnames (i.e. evaluate wildcards and SAN)? + // Key fingerprint? Other certificate fields? + +private: + void* cert; // actually type X509*, but we don't want to #include the OpenSSL headers here. + kj::Own inner; + +public: // (not really public, only TlsConnection can call this) + TlsPeerIdentity(void* cert, kj::Own inner, kj::Badge) + : cert(cert), inner(kj::mv(inner)) {} +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/compat/url-test.c++ b/c++/src/kj/compat/url-test.c++ new file mode 100644 index 0000000000..d5a437a9c4 --- /dev/null +++ b/c++/src/kj/compat/url-test.c++ @@ -0,0 +1,560 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "url.h" +#include +#include + +namespace kj { +namespace { + +Url parseAndCheck(kj::StringPtr originalText, kj::StringPtr expectedRestringified = nullptr, + Url::Options options = {}) { + if (expectedRestringified == nullptr) expectedRestringified = originalText; + auto url = Url::parse(originalText, Url::REMOTE_HREF, options); + KJ_EXPECT(kj::str(url) == expectedRestringified, url, originalText, expectedRestringified); + // Make sure clones also restringify to the expected string. + auto clone = url.clone(); + KJ_EXPECT(kj::str(clone) == expectedRestringified, clone, originalText, expectedRestringified); + return url; +} + +static constexpr Url::Options NO_DECODE { + false, // percentDecode + false, // allowEmpty +}; + +static constexpr Url::Options ALLOW_EMPTY { + true, // percentDecode + true, // allowEmpty +}; + +KJ_TEST("parse / stringify URL") { + { + auto url = parseAndCheck("https://capnproto.org"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://capnproto.org:80"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org:80"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://capnproto.org/"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar/"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar?baz=qux&corge#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 2); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == "qux"); + KJ_EXPECT(url.query[1].name == "corge"); + KJ_EXPECT(url.query[1].value == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar?baz=qux&corge=#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 2); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == "qux"); + KJ_EXPECT(url.query[1].name == "corge"); + KJ_EXPECT(url.query[1].value == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar?baz=&corge=grault#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 2); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == ""); + KJ_EXPECT(url.query[1].name == "corge"); + KJ_EXPECT(url.query[1].value == "grault"); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar/?baz=qux&corge=grault#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 2); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == "qux"); + KJ_EXPECT(url.query[1].name == "corge"); + KJ_EXPECT(url.query[1].value == "grault"); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar?baz=qux#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 1); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == "qux"); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo?bar%20baz=qux+quux", + "https://capnproto.org/foo?bar+baz=qux+quux"); + KJ_ASSERT(url.query.size() == 1); + KJ_EXPECT(url.query[0].name == "bar baz"); + KJ_EXPECT(url.query[0].value == "qux quux"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo/bar/#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://capnproto.org/#garply"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(url.fragment) == "garply"); + } + + { + auto url = parseAndCheck("https://foo@capnproto.org"); + KJ_EXPECT(url.scheme == "https"); + auto& user = KJ_ASSERT_NONNULL(url.userInfo); + KJ_EXPECT(user.username == "foo"); + KJ_EXPECT(user.password == nullptr); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://$foo&:12+,34@capnproto.org"); + KJ_EXPECT(url.scheme == "https"); + auto& user = KJ_ASSERT_NONNULL(url.userInfo); + KJ_EXPECT(user.username == "$foo&"); + KJ_EXPECT(KJ_ASSERT_NONNULL(user.password) == "12+,34"); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://[2001:db8::1234]:80/foo"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.userInfo == nullptr); + KJ_EXPECT(url.host == "[2001:db8::1234]:80"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + auto url = parseAndCheck("https://capnproto.org/foo%2Fbar/baz"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo/bar", "baz"})); + } + + parseAndCheck("https://capnproto.org/foo/bar?", "https://capnproto.org/foo/bar"); + parseAndCheck("https://capnproto.org/foo/bar?#", "https://capnproto.org/foo/bar#"); + parseAndCheck("https://capnproto.org/foo/bar#"); + + // Scheme and host are forced to lower-case. + parseAndCheck("hTtP://capNprotO.org/fOo/bAr", "http://capnproto.org/fOo/bAr"); + + // URLs with underscores in their hostnames are allowed, but you probably shouldn't use them. They + // are not valid domain names. + parseAndCheck("https://bad_domain.capnproto.org/"); + + // Make sure URLs with %-encoded '%' signs in their userinfo, path, query, and fragment components + // get correctly re-encoded. + parseAndCheck("https://foo%25bar:baz%25qux@capnproto.org/"); + parseAndCheck("https://capnproto.org/foo%25bar"); + parseAndCheck("https://capnproto.org/?foo%25bar=baz%25qux"); + parseAndCheck("https://capnproto.org/#foo%25bar"); + + // Make sure redundant /'s and &'s aren't collapsed when options.removeEmpty is false. + parseAndCheck("https://capnproto.org/foo//bar///test//?foo=bar&&baz=qux&", nullptr, ALLOW_EMPTY); + + // "." and ".." are still processed, though. + parseAndCheck("https://capnproto.org/foo//../bar/.", + "https://capnproto.org/foo/bar/", ALLOW_EMPTY); + + { + auto url = parseAndCheck("https://foo/", nullptr, ALLOW_EMPTY); + KJ_EXPECT(url.path.size() == 0); + KJ_EXPECT(url.hasTrailingSlash); + } + + { + auto url = parseAndCheck("https://foo/bar/", nullptr, ALLOW_EMPTY); + KJ_EXPECT(url.path.size() == 1); + KJ_EXPECT(url.hasTrailingSlash); + } +} + +KJ_TEST("URL percent encoding") { + parseAndCheck( + "https://b%6fb:%61bcd@capnpr%6fto.org/f%6fo?b%61r=b%61z#q%75x", + "https://bob:abcd@capnproto.org/foo?bar=baz#qux"); + + parseAndCheck( + "https://b\001b:\001bcd@capnproto.org/f\001o?b\001r=b\001z#q\001x", + "https://b%01b:%01bcd@capnproto.org/f%01o?b%01r=b%01z#q%01x"); + + parseAndCheck( + "https://b b: bcd@capnproto.org/f o?b r=b z#q x", + "https://b%20b:%20bcd@capnproto.org/f%20o?b+r=b+z#q%20x"); + + parseAndCheck( + "https://capnproto.org/foo?bar=baz#@?#^[\\]{|}", + "https://capnproto.org/foo?bar=baz#@?#^[\\]{|}"); + + // All permissible non-alphanumeric, non-separator path characters. + parseAndCheck( + "https://capnproto.org/!$&'()*+,-.:;=@[]^_|~", + "https://capnproto.org/!$&'()*+,-.:;=@[]^_|~"); +} + +KJ_TEST("parse / stringify URL w/o decoding") { + { + auto url = parseAndCheck("https://capnproto.org/foo%2Fbar/baz", nullptr, NO_DECODE); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo%2Fbar", "baz"})); + } + + { + // This case would throw an exception without NO_DECODE. + Url url = parseAndCheck("https://capnproto.org/R%20%26%20S?%foo=%QQ", nullptr, NO_DECODE); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"R%20%26%20S"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 1); + KJ_EXPECT(url.query[0].name == "%foo"); + KJ_EXPECT(url.query[0].value == "%QQ"); + } +} + +KJ_TEST("URL relative paths") { + parseAndCheck( + "https://capnproto.org/foo//bar", + "https://capnproto.org/foo/bar"); + + parseAndCheck( + "https://capnproto.org/foo/./bar", + "https://capnproto.org/foo/bar"); + + parseAndCheck( + "https://capnproto.org/foo/bar//", + "https://capnproto.org/foo/bar/"); + + parseAndCheck( + "https://capnproto.org/foo/bar/.", + "https://capnproto.org/foo/bar/"); + + parseAndCheck( + "https://capnproto.org/foo/baz/../bar", + "https://capnproto.org/foo/bar"); + + parseAndCheck( + "https://capnproto.org/foo/bar/baz/..", + "https://capnproto.org/foo/bar/"); + + parseAndCheck( + "https://capnproto.org/..", + "https://capnproto.org/"); + + parseAndCheck( + "https://capnproto.org/foo/../..", + "https://capnproto.org/"); +} + +KJ_TEST("URL for HTTP request") { + { + Url url = Url::parse("https://bob:1234@capnproto.org/foo/bar?baz=qux#corge"); + KJ_EXPECT(url.toString(Url::REMOTE_HREF) == + "https://bob:1234@capnproto.org/foo/bar?baz=qux#corge"); + KJ_EXPECT(url.toString(Url::HTTP_PROXY_REQUEST) == "https://capnproto.org/foo/bar?baz=qux"); + KJ_EXPECT(url.toString(Url::HTTP_REQUEST) == "/foo/bar?baz=qux"); + } + + { + Url url = Url::parse("https://capnproto.org"); + KJ_EXPECT(url.toString(Url::REMOTE_HREF) == "https://capnproto.org"); + KJ_EXPECT(url.toString(Url::HTTP_PROXY_REQUEST) == "https://capnproto.org"); + KJ_EXPECT(url.toString(Url::HTTP_REQUEST) == "/"); + } + + { + Url url = Url::parse("/foo/bar?baz=qux&corge", Url::HTTP_REQUEST); + KJ_EXPECT(url.scheme == nullptr); + KJ_EXPECT(url.host == nullptr); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 2); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == "qux"); + KJ_EXPECT(url.query[1].name == "corge"); + KJ_EXPECT(url.query[1].value == nullptr); + } + + { + Url url = Url::parse("https://capnproto.org/foo/bar?baz=qux&corge", Url::HTTP_PROXY_REQUEST); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr({"foo", "bar"})); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 2); + KJ_EXPECT(url.query[0].name == "baz"); + KJ_EXPECT(url.query[0].value == "qux"); + KJ_EXPECT(url.query[1].name == "corge"); + KJ_EXPECT(url.query[1].value == nullptr); + } + + { + // '#' is allowed in path components in the HTTP_REQUEST context. + Url url = Url::parse("/foo#bar", Url::HTTP_REQUEST); + KJ_EXPECT(url.toString(Url::HTTP_REQUEST) == "/foo%23bar"); + KJ_EXPECT(url.scheme == nullptr); + KJ_EXPECT(url.host == nullptr); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr{"foo#bar"}); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + // '#' is allowed in path components in the HTTP_PROXY_REQUEST context. + Url url = Url::parse("https://capnproto.org/foo#bar", Url::HTTP_PROXY_REQUEST); + KJ_EXPECT(url.toString(Url::HTTP_PROXY_REQUEST) == "https://capnproto.org/foo%23bar"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path.asPtr() == kj::ArrayPtr{"foo#bar"}); + KJ_EXPECT(!url.hasTrailingSlash); + KJ_EXPECT(url.query == nullptr); + KJ_EXPECT(url.fragment == nullptr); + } + + { + // '#' is allowed in query components in the HTTP_REQUEST context. + Url url = Url::parse("/?foo=bar#123", Url::HTTP_REQUEST); + KJ_EXPECT(url.toString(Url::HTTP_REQUEST) == "/?foo=bar%23123"); + KJ_EXPECT(url.scheme == nullptr); + KJ_EXPECT(url.host == nullptr); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 1); + KJ_EXPECT(url.query[0].name == "foo"); + KJ_EXPECT(url.query[0].value == "bar#123"); + KJ_EXPECT(url.fragment == nullptr); + } + + { + // '#' is allowed in query components in the HTTP_PROXY_REQUEST context. + Url url = Url::parse("https://capnproto.org/?foo=bar#123", Url::HTTP_PROXY_REQUEST); + KJ_EXPECT(url.toString(Url::HTTP_PROXY_REQUEST) == "https://capnproto.org/?foo=bar%23123"); + KJ_EXPECT(url.scheme == "https"); + KJ_EXPECT(url.host == "capnproto.org"); + KJ_EXPECT(url.path == nullptr); + KJ_EXPECT(url.hasTrailingSlash); + KJ_ASSERT(url.query.size() == 1); + KJ_EXPECT(url.query[0].name == "foo"); + KJ_EXPECT(url.query[0].value == "bar#123"); + KJ_EXPECT(url.fragment == nullptr); + } +} + +KJ_TEST("parse URL failure") { + KJ_EXPECT(Url::tryParse("ht/tps://capnproto.org") == nullptr); + KJ_EXPECT(Url::tryParse("capnproto.org") == nullptr); + KJ_EXPECT(Url::tryParse("https:foo") == nullptr); + + // percent-decode errors + KJ_EXPECT(Url::tryParse("https://capnproto.org/f%nno") == nullptr); + KJ_EXPECT(Url::tryParse("https://capnproto.org/foo?b%nnr=baz") == nullptr); + + // components not valid in context + KJ_EXPECT(Url::tryParse("https://capnproto.org/foo", Url::HTTP_REQUEST) == nullptr); + KJ_EXPECT(Url::tryParse("https://bob:123@capnproto.org/foo", Url::HTTP_PROXY_REQUEST) == nullptr); + KJ_EXPECT(Url::tryParse("https://capnproto.org#/foo", Url::HTTP_PROXY_REQUEST) == nullptr); +} + +void parseAndCheckRelative(kj::StringPtr base, kj::StringPtr relative, kj::StringPtr expected, + Url::Options options = {}) { + auto parsed = Url::parse(base, Url::REMOTE_HREF, options).parseRelative(relative); + KJ_EXPECT(kj::str(parsed) == expected, parsed, expected); + auto clone = parsed.clone(); + KJ_EXPECT(kj::str(clone) == expected, clone, expected); +} + +KJ_TEST("parse relative URL") { + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "#grault", + "https://capnproto.org/foo/bar?baz=qux#grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz#corge", + "#grault", + "https://capnproto.org/foo/bar?baz#grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=#corge", + "#grault", + "https://capnproto.org/foo/bar?baz=#grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "?grault", + "https://capnproto.org/foo/bar?grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "?grault=", + "https://capnproto.org/foo/bar?grault="); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "?grault+garply=waldo", + "https://capnproto.org/foo/bar?grault+garply=waldo"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "grault", + "https://capnproto.org/foo/grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "/grault", + "https://capnproto.org/grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "//grault", + "https://grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "//grault/garply", + "https://grault/garply"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "http:/grault", + "http://capnproto.org/grault"); + parseAndCheckRelative("https://capnproto.org/foo/bar?baz=qux#corge", + "/http:/grault", + "https://capnproto.org/http:/grault"); + parseAndCheckRelative("https://capnproto.org/", + "/foo/../bar", + "https://capnproto.org/bar"); +} + +KJ_TEST("parse relative URL w/o decoding") { + // This case would throw an exception without NO_DECODE. + parseAndCheckRelative("https://capnproto.org/R%20%26%20S?%foo=%QQ", + "%ANOTH%ERBAD%URL", + "https://capnproto.org/%ANOTH%ERBAD%URL", NO_DECODE); +} + +KJ_TEST("parse relative URL failure") { + auto base = Url::parse("https://example.com/"); + KJ_EXPECT(base.tryParseRelative("https://[not a host]") == nullptr); +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/compat/url.c++ b/c++/src/kj/compat/url.c++ new file mode 100644 index 0000000000..e7626761fc --- /dev/null +++ b/c++/src/kj/compat/url.c++ @@ -0,0 +1,495 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "url.h" +#include +#include +#include +#include + +namespace kj { + +namespace { + +constexpr auto ALPHAS = parse::charRange('a', 'z').orRange('A', 'Z'); +constexpr auto DIGITS = parse::charRange('0', '9'); + +constexpr auto END_AUTHORITY = parse::anyOfChars("/?#"); + +// Authority, path, and query components can typically be terminated by the start of a fragment. +// However, fragments are disallowed in HTTP_REQUEST and HTTP_PROXY_REQUEST contexts. As a quirk, we +// allow the fragment start character ('#') to live unescaped in path and query components. We do +// not currently allow it in the authority component, because our parser would reject it as a host +// character anyway. + +const parse::CharGroup_& getEndPathPart(Url::Context context) { + static constexpr auto END_PATH_PART_HREF = parse::anyOfChars("/?#"); + static constexpr auto END_PATH_PART_REQUEST = parse::anyOfChars("/?"); + + switch (context) { + case Url::REMOTE_HREF: return END_PATH_PART_HREF; + case Url::HTTP_PROXY_REQUEST: return END_PATH_PART_REQUEST; + case Url::HTTP_REQUEST: return END_PATH_PART_REQUEST; + } + + KJ_UNREACHABLE; +} + +const parse::CharGroup_& getEndQueryPart(Url::Context context) { + static constexpr auto END_QUERY_PART_HREF = parse::anyOfChars("&#"); + static constexpr auto END_QUERY_PART_REQUEST = parse::anyOfChars("&"); + + switch (context) { + case Url::REMOTE_HREF: return END_QUERY_PART_HREF; + case Url::HTTP_PROXY_REQUEST: return END_QUERY_PART_REQUEST; + case Url::HTTP_REQUEST: return END_QUERY_PART_REQUEST; + } + + KJ_UNREACHABLE; +} + +constexpr auto SCHEME_CHARS = ALPHAS.orGroup(DIGITS).orAny("+-."); +constexpr auto NOT_SCHEME_CHARS = SCHEME_CHARS.invert(); + +constexpr auto HOST_CHARS = ALPHAS.orGroup(DIGITS).orAny(".-:[]_"); +// [] is for ipv6 literals. +// _ is not allowed in domain names, but the WHATWG URL spec allows it in hostnames, so we do, too. +// TODO(someday): The URL spec actually allows a lot more than just '_', and requires nameprepping +// to Punycode. We'll have to decide how we want to deal with all that. + +void toLower(String& text) { + for (char& c: text) { + if ('A' <= c && c <= 'Z') { + c += 'a' - 'A'; + } + } +} + +Maybe> trySplit(StringPtr& text, char c) { + KJ_IF_MAYBE(pos, text.findFirst(c)) { + ArrayPtr result = text.slice(0, *pos); + text = text.slice(*pos + 1); + return result; + } else { + return nullptr; + } +} + +Maybe> trySplit(ArrayPtr& text, char c) { + for (auto i: kj::indices(text)) { + if (text[i] == c) { + ArrayPtr result = text.slice(0, i); + text = text.slice(i + 1, text.size()); + return result; + } + } + return nullptr; +} + +ArrayPtr split(StringPtr& text, const parse::CharGroup_& chars) { + for (auto i: kj::indices(text)) { + if (chars.contains(text[i])) { + ArrayPtr result = text.slice(0, i); + text = text.slice(i); + return result; + } + } + auto result = text.asArray(); + text = ""; + return result; +} + +String percentDecode(ArrayPtr text, bool& hadErrors, const Url::Options& options) { + if (options.percentDecode) { + auto result = decodeUriComponent(text); + if (result.hadErrors) hadErrors = true; + return kj::mv(result); + } + return kj::str(text); +} + +String percentDecodeQuery(ArrayPtr text, bool& hadErrors, const Url::Options& options) { + if (options.percentDecode) { + auto result = decodeWwwForm(text); + if (result.hadErrors) hadErrors = true; + return kj::mv(result); + } + return kj::str(text); +} + +} // namespace + +Url::~Url() noexcept(false) {} + +Url Url::clone() const { + return { + kj::str(scheme), + userInfo.map([](const UserInfo& ui) -> UserInfo { + return { + kj::str(ui.username), + ui.password.map([](const String& s) { return kj::str(s); }) + }; + }), + kj::str(host), + KJ_MAP(part, path) { return kj::str(part); }, + hasTrailingSlash, + KJ_MAP(param, query) -> QueryParam { + // Preserve the "allocated-ness" of `param.value` with this careful copy. + return { kj::str(param.name), param.value.begin() == nullptr ? kj::String() + : kj::str(param.value) }; + }, + fragment.map([](const String& s) { return kj::str(s); }), + options + }; +} + +Url Url::parse(StringPtr url, Context context, Options options) { + return KJ_REQUIRE_NONNULL(tryParse(url, context, options), "invalid URL", url); +} + +Maybe Url::tryParse(StringPtr text, Context context, Options options) { + Url result; + result.options = options; + bool err = false; // tracks percent-decoding errors + + auto& END_PATH_PART = getEndPathPart(context); + auto& END_QUERY_PART = getEndQueryPart(context); + + if (context == HTTP_REQUEST) { + if (!text.startsWith("/")) { + return nullptr; + } + } else { + KJ_IF_MAYBE(scheme, trySplit(text, ':')) { + result.scheme = kj::str(*scheme); + } else { + // missing scheme + return nullptr; + } + toLower(result.scheme); + if (result.scheme.size() == 0 || + !ALPHAS.contains(result.scheme[0]) || + !SCHEME_CHARS.containsAll(result.scheme.slice(1))) { + // bad scheme + return nullptr; + } + + if (!text.startsWith("//")) { + // We require an authority (hostname) part. + return nullptr; + } + text = text.slice(2); + + { + auto authority = split(text, END_AUTHORITY); + + KJ_IF_MAYBE(userpass, trySplit(authority, '@')) { + if (context != REMOTE_HREF) { + // No user/pass allowed here. + return nullptr; + } + KJ_IF_MAYBE(username, trySplit(*userpass, ':')) { + result.userInfo = UserInfo { + percentDecode(*username, err, options), + percentDecode(*userpass, err, options) + }; + } else { + result.userInfo = UserInfo { + percentDecode(*userpass, err, options), + nullptr + }; + } + } + + result.host = percentDecode(authority, err, options); + if (!HOST_CHARS.containsAll(result.host)) return nullptr; + toLower(result.host); + } + } + + while (text.startsWith("/")) { + text = text.slice(1); + auto part = split(text, END_PATH_PART); + if (part.size() == 2 && part[0] == '.' && part[1] == '.') { + if (result.path.size() != 0) { + result.path.removeLast(); + } + result.hasTrailingSlash = true; + } else if ((part.size() == 0 && (!options.allowEmpty || text.size() == 0)) || + (part.size() == 1 && part[0] == '.')) { + // Collapse consecutive slashes and "/./". + result.hasTrailingSlash = true; + } else { + result.path.add(percentDecode(part, err, options)); + result.hasTrailingSlash = false; + } + } + + if (text.startsWith("?")) { + do { + text = text.slice(1); + auto part = split(text, END_QUERY_PART); + + if (part.size() > 0 || options.allowEmpty) { + KJ_IF_MAYBE(key, trySplit(part, '=')) { + result.query.add(QueryParam { percentDecodeQuery(*key, err, options), + percentDecodeQuery(part, err, options) }); + } else { + result.query.add(QueryParam { percentDecodeQuery(part, err, options), nullptr }); + } + } + } while (text.startsWith("&")); + } + + if (text.startsWith("#")) { + if (context != REMOTE_HREF) { + // No fragment allowed here. + return nullptr; + } + result.fragment = percentDecode(text.slice(1), err, options); + } else { + // We should have consumed everything. + KJ_ASSERT(text.size() == 0); + } + + if (err) return nullptr; + + return kj::mv(result); +} + +Url Url::parseRelative(StringPtr url) const { + return KJ_REQUIRE_NONNULL(tryParseRelative(url), "invalid relative URL", url); +} + +Maybe Url::tryParseRelative(StringPtr text) const { + if (text.size() == 0) return clone(); + + Url result; + result.options = options; + bool err = false; // tracks percent-decoding errors + + auto& END_PATH_PART = getEndPathPart(Url::REMOTE_HREF); + auto& END_QUERY_PART = getEndQueryPart(Url::REMOTE_HREF); + + // scheme + { + bool gotScheme = false; + for (auto i: kj::indices(text)) { + if (text[i] == ':') { + // found valid scheme + result.scheme = kj::str(text.slice(0, i)); + text = text.slice(i + 1); + gotScheme = true; + break; + } else if (NOT_SCHEME_CHARS.contains(text[i])) { + // no scheme + break; + } + } + if (!gotScheme) { + // copy scheme + result.scheme = kj::str(this->scheme); + } + } + + // authority + bool hadNewAuthority = text.startsWith("//"); + if (hadNewAuthority) { + text = text.slice(2); + + auto authority = split(text, END_AUTHORITY); + + KJ_IF_MAYBE(userpass, trySplit(authority, '@')) { + KJ_IF_MAYBE(username, trySplit(*userpass, ':')) { + result.userInfo = UserInfo { + percentDecode(*username, err, options), + percentDecode(*userpass, err, options) + }; + } else { + result.userInfo = UserInfo { + percentDecode(*userpass, err, options), + nullptr + }; + } + } + + result.host = percentDecode(authority, err, options); + if (!HOST_CHARS.containsAll(result.host)) return nullptr; + toLower(result.host); + } else { + // copy authority + result.host = kj::str(this->host); + result.userInfo = this->userInfo.map([](const UserInfo& userInfo) { + return UserInfo { + kj::str(userInfo.username), + userInfo.password.map([](const String& password) { return kj::str(password); }), + }; + }); + } + + // path + bool hadNewPath = text.size() > 0 && text[0] != '?' && text[0] != '#'; + if (hadNewPath) { + // There's a new path. + + if (text[0] == '/') { + // New path is absolute, so don't copy the old path. + text = text.slice(1); + result.hasTrailingSlash = true; + } else if (this->path.size() > 0) { + // New path is relative, so start from the old path, dropping everything after the last + // slash. + auto slice = this->path.slice(0, this->path.size() - (this->hasTrailingSlash ? 0 : 1)); + result.path = KJ_MAP(part, slice) { return kj::str(part); }; + result.hasTrailingSlash = true; + } + + for (;;) { + auto part = split(text, END_PATH_PART); + if (part.size() == 2 && part[0] == '.' && part[1] == '.') { + if (result.path.size() != 0) { + result.path.removeLast(); + } + result.hasTrailingSlash = true; + } else if (part.size() == 0 || (part.size() == 1 && part[0] == '.')) { + // Collapse consecutive slashes and "/./". + result.hasTrailingSlash = true; + } else { + result.path.add(percentDecode(part, err, options)); + result.hasTrailingSlash = false; + } + + if (!text.startsWith("/")) break; + text = text.slice(1); + } + } else if (!hadNewAuthority) { + // copy path + result.path = KJ_MAP(part, this->path) { return kj::str(part); }; + result.hasTrailingSlash = this->hasTrailingSlash; + } + + if (text.startsWith("?")) { + do { + text = text.slice(1); + auto part = split(text, END_QUERY_PART); + + if (part.size() > 0) { + KJ_IF_MAYBE(key, trySplit(part, '=')) { + result.query.add(QueryParam { percentDecodeQuery(*key, err, options), + percentDecodeQuery(part, err, options) }); + } else { + result.query.add(QueryParam { percentDecodeQuery(part, err, options), + nullptr }); + } + } + } while (text.startsWith("&")); + } else if (!hadNewAuthority && !hadNewPath) { + // copy query + result.query = KJ_MAP(param, this->query) -> QueryParam { + // Preserve the "allocated-ness" of `param.value` with this careful copy. + return { kj::str(param.name), param.value.begin() == nullptr ? kj::String() + : kj::str(param.value) }; + }; + } + + if (text.startsWith("#")) { + result.fragment = percentDecode(text.slice(1), err, options); + } else { + // We should have consumed everything. + KJ_ASSERT(text.size() == 0); + } + + if (err) return nullptr; + + return kj::mv(result); +} + +String Url::toString(Context context) const { + Vector chars(128); + + if (context != HTTP_REQUEST) { + chars.addAll(scheme); + chars.addAll(StringPtr("://")); + + if (context == REMOTE_HREF) { + KJ_IF_MAYBE(user, userInfo) { + chars.addAll(options.percentDecode ? encodeUriUserInfo(user->username) + : kj::str(user->username)); + KJ_IF_MAYBE(pass, user->password) { + chars.add(':'); + chars.addAll(options.percentDecode ? encodeUriUserInfo(*pass) : kj::str(*pass)); + } + chars.add('@'); + } + } + + // RFC3986 specifies that hosts can contain percent-encoding escapes while suggesting that + // they should only be used for UTF-8 sequences. However, the DNS standard specifies a + // different way to encode Unicode into domain names and doesn't permit any characters which + // would need to be escaped. Meanwhile, encodeUriComponent() here would incorrectly try to + // escape colons and brackets (e.g. around ipv6 literal addresses). So, instead, we throw if + // the host is invalid. + if (HOST_CHARS.containsAll(host)) { + chars.addAll(host); + } else { + KJ_FAIL_REQUIRE("invalid hostname when stringifying URL", host) { + chars.addAll(StringPtr("invalid-host")); + break; + } + } + } + + for (auto& pathPart: path) { + // Protect against path injection. + KJ_REQUIRE((pathPart != "" || options.allowEmpty) && pathPart != "." && pathPart != "..", + "invalid name in URL path", path) { + continue; + } + chars.add('/'); + chars.addAll(options.percentDecode ? encodeUriPath(pathPart) : kj::str(pathPart)); + } + if (hasTrailingSlash || (path.size() == 0 && context == HTTP_REQUEST)) { + chars.add('/'); + } + + bool first = true; + for (auto& param: query) { + chars.add(first ? '?' : '&'); + first = false; + chars.addAll(options.percentDecode ? encodeWwwForm(param.name) : kj::str(param.name)); + if (param.value.begin() != nullptr) { + chars.add('='); + chars.addAll(options.percentDecode ? encodeWwwForm(param.value) : kj::str(param.value)); + } + } + + if (context == REMOTE_HREF) { + KJ_IF_MAYBE(f, fragment) { + chars.add('#'); + chars.addAll(options.percentDecode ? encodeUriFragment(*f) : kj::str(*f)); + } + } + + chars.add('\0'); + return String(chars.releaseAsArray()); +} + +} // namespace kj diff --git a/c++/src/kj/compat/url.h b/c++/src/kj/compat/url.h new file mode 100644 index 0000000000..6e38d23061 --- /dev/null +++ b/c++/src/kj/compat/url.h @@ -0,0 +1,151 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include +#include + +KJ_BEGIN_HEADER + +namespace kj { + +struct UrlOptions { + // A bag of options that you can pass to Url::parse()/tryParse() to customize the parser's + // behavior. + // + // A copy of this options struct will be stored in the parsed Url object, at which point it + // controls the behavior of the serializer in Url::toString(). + + bool percentDecode = true; + // True if URL components should be automatically percent-decoded during parsing, and + // percent-encoded during serialization. + + bool allowEmpty = false; + // Whether or not to allow empty path and query components when parsing; otherwise, they are + // silently removed. In other words, setting this false causes consecutive slashes in the path or + // consecutive ampersands in the query to be collapsed into one, whereas if true then they + // produce empty components. +}; + +struct Url { + // Represents a URL (or, more accurately, a URI, but whatever). + // + // Can be parsed from a string and composed back into a string. + + String scheme; + // E.g. "http", "https". + + struct UserInfo { + String username; + Maybe password; + }; + + Maybe userInfo; + // Username / password. + + String host; + // Hostname, including port if specified. We choose not to parse out the port because KJ's + // network address parsing functions already accept addresses containing port numbers, and + // because most web standards don't actually want to separate host and port. + + Vector path; + bool hasTrailingSlash = false; + // Path, split on '/' characters. Note that the individual components of `path` could contain + // '/' characters if they were percent-encoded in the original URL. + // + // No component of the path is allowed to be "", ".", nor ".."; if such components are present, + // toString() will throw. Note that parse() and parseRelative() automatically resolve such + // components. + + struct QueryParam { + String name; + String value; + }; + Vector query; + // Query, e.g. from "?key=value&key2=value2". If a component of the query contains no '=' sign, + // it will be parsed as a key with a null value, and later serialized with no '=' sign if you call + // Url::toString(). + // + // To distinguish between null-valued and empty-valued query parameters, we test whether + // QueryParam::value is an allocated or unallocated string. For example: + // + // QueryParam { kj::str("name"), nullptr } // Null-valued; will not have an '=' sign. + // QueryParam { kj::str("name"), kj::str("") } // Empty-valued; WILL have an '=' sign. + + Maybe fragment; + // The stuff after the '#' character (not including the '#' character itself), if present. + + using Options = UrlOptions; + Options options; + + // --------------------------------------------------------------------------- + + Url() = default; + Url(Url&&) = default; + ~Url() noexcept(false); + Url& operator=(Url&&) = default; + + inline Url(String&& scheme, Maybe&& userInfo, String&& host, Vector&& path, + bool hasTrailingSlash, Vector&& query, Maybe&& fragment, + UrlOptions options) + : scheme(kj::mv(scheme)), userInfo(kj::mv(userInfo)), host(kj::mv(host)), path(kj::mv(path)), + hasTrailingSlash(hasTrailingSlash), query(kj::mv(query)), fragment(kj::mv(fragment)), + options(options) {} + // This constructor makes brace initialization work in C++11 and C++20 -- but is technically not + // needed in C++14 nor C++17. Go figure. + + Url clone() const; + + enum Context { + REMOTE_HREF, + // A link to a remote resource. Requires an authority (hostname) section, hence this will + // reject things like "mailto:" and "data:". This is the default context. + + HTTP_PROXY_REQUEST, + // The URL to place in the first line of an HTTP proxy request. This includes scheme, host, + // path, and query, but omits userInfo (which should be used to construct the Authorization + // header) and fragment (which should not be transmitted). + + HTTP_REQUEST + // The path to place in the first line of a regular HTTP request. This includes only the path + // and query. Scheme, user, host, and fragment are omitted. + + // TODO(someday): Add context(s) that supports things like "mailto:", "data:", "blob:". These + // don't have an authority section. + }; + + kj::String toString(Context context = REMOTE_HREF) const; + // Convert the URL to a string. + + static Url parse(StringPtr text, Context context = REMOTE_HREF, Options options = {}); + static Maybe tryParse(StringPtr text, Context context = REMOTE_HREF, Options options = {}); + // Parse an absolute URL. + + Url parseRelative(StringPtr relative) const; + Maybe tryParseRelative(StringPtr relative) const; + // Parse a relative URL string with this URL as the base. +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/debug-test.c++ b/c++/src/kj/debug-test.c++ index 8bf32a2e2a..3c65b5218b 100644 --- a/c++/src/kj/debug-test.c++ +++ b/c++/src/kj/debug-test.c++ @@ -19,6 +19,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "debug.h" #include "exception.h" #include @@ -36,11 +40,6 @@ #include #endif -#if _MSC_VER -#pragma warning(disable: 4996) -// Warns that sprintf() is buffer-overrunny. Yeah, I know, it's cool. -#endif - namespace kj { namespace _ { // private namespace { @@ -71,6 +70,7 @@ public: // This is the child! close(pipeFds[0]); outputPipe = pipeFds[1]; + text.clear(); return true; } else { close(pipeFds[1]); @@ -170,9 +170,20 @@ public: this->text += "log message: "; text = str(file, ":", line, ":+", contextDepth, ": ", severity, ": ", mv(text)); this->text.append(text.begin(), text.end()); + this->text.append("\n"); } }; +#define EXPECT_LOG_EQ(f, expText) do { \ + std::string text; \ + { \ + MockExceptionCallback mockCallback; \ + f(); \ + text = kj::mv(mockCallback.text); \ + } \ + EXPECT_EQ(expText, text); \ +} while(0) + #if KJ_NO_EXCEPTIONS #define EXPECT_FATAL(code) if (mockCallback.forkForDeathTest()) { code; abort(); } #else @@ -187,53 +198,55 @@ std::string fileLine(std::string file, int line) { file += ':'; char buffer[32]; - sprintf(buffer, "%d", line); + snprintf(buffer, sizeof(buffer), "%d", line); file += buffer; return file; } TEST(Debug, Log) { - MockExceptionCallback mockCallback; int line; - KJ_LOG(WARNING, "Hello world!"); line = __LINE__; - EXPECT_EQ("log message: " + fileLine(__FILE__, line) + ":+0: warning: Hello world!\n", - mockCallback.text); - mockCallback.text.clear(); + EXPECT_LOG_EQ([&](){ + KJ_LOG(WARNING, "Hello world!"); line = __LINE__; + }, "log message: " + fileLine(__FILE__, line) + ":+0: warning: Hello world!\n"); int i = 123; const char* str = "foo"; - KJ_LOG(ERROR, i, str); line = __LINE__; - EXPECT_EQ("log message: " + fileLine(__FILE__, line) + ":+0: error: i = 123; str = foo\n", - mockCallback.text); - mockCallback.text.clear(); + EXPECT_LOG_EQ([&](){ + KJ_LOG(ERROR, i, str); line = __LINE__; + }, "log message: " + fileLine(__FILE__, line) + ":+0: error: i = 123; str = foo\n"); - KJ_DBG("Some debug text."); line = __LINE__; - EXPECT_EQ("log message: " + fileLine(__FILE__, line) + ":+0: debug: Some debug text.\n", - mockCallback.text); - mockCallback.text.clear(); + // kj::str() expressions are included literally. + EXPECT_LOG_EQ([&](){ + KJ_LOG(ERROR, kj::str(i, str), "x"); line = __LINE__; + }, "log message: " + fileLine(__FILE__, line) + ":+0: error: 123foo; x\n"); + + EXPECT_LOG_EQ([&](){ + KJ_DBG("Some debug text."); line = __LINE__; + }, "log message: " + fileLine(__FILE__, line) + ":+0: debug: Some debug text.\n"); // INFO logging is disabled by default. - KJ_LOG(INFO, "Info."); line = __LINE__; - EXPECT_EQ("", mockCallback.text); - mockCallback.text.clear(); + EXPECT_LOG_EQ([&](){ + KJ_LOG(INFO, "Info."); line = __LINE__; + }, ""); // Enable it. Debug::setLogLevel(Debug::Severity::INFO); - KJ_LOG(INFO, "Some text."); line = __LINE__; - EXPECT_EQ("log message: " + fileLine(__FILE__, line) + ":+0: info: Some text.\n", - mockCallback.text); - mockCallback.text.clear(); + EXPECT_LOG_EQ([&](){ + KJ_LOG(INFO, "Some text."); line = __LINE__; + }, "log message: " + fileLine(__FILE__, line) + ":+0: info: Some text.\n"); // Back to default. Debug::setLogLevel(Debug::Severity::WARNING); - KJ_ASSERT(1 == 1); - EXPECT_FATAL(KJ_ASSERT(1 == 2)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": failed: expected " - "1 == 2\n", mockCallback.text); - mockCallback.text.clear(); + EXPECT_LOG_EQ([&](){ + KJ_ASSERT(1 == 1); + }, ""); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_ASSERT(1 == 2)); line = __LINE__; + }, "fatal exception: " + fileLine(__FILE__, line) + ": failed: expected 1 == 2 [1 == 2]\n"); KJ_ASSERT(1 == 1) { ADD_FAILURE() << "Shouldn't call recovery code when check passes."; @@ -241,26 +254,31 @@ TEST(Debug, Log) { }; bool recovered = false; - KJ_ASSERT(1 == 2, "1 is not 2") { recovered = true; break; } line = __LINE__; - EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": failed: expected " - "1 == 2; 1 is not 2\n", mockCallback.text); + EXPECT_LOG_EQ([&](){ + KJ_ASSERT(1 == 2, "1 is not 2") { recovered = true; break; } line = __LINE__; + }, ( + "recoverable exception: " + fileLine(__FILE__, line) + ": " + "failed: expected 1 == 2 [1 == 2]; 1 is not 2\n" + )); EXPECT_TRUE(recovered); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_ASSERT(1 == 2, i, "hi", str)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": failed: expected " - "1 == 2; i = 123; hi; str = foo\n", mockCallback.text); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_REQUIRE(1 == 2, i, "hi", str)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": failed: expected " - "1 == 2; i = 123; hi; str = foo\n", mockCallback.text); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_FAIL_ASSERT("foo")); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": failed: foo\n", - mockCallback.text); - mockCallback.text.clear(); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_ASSERT(1 == 2, i, "hi", str)); line = __LINE__; + }, ( + "fatal exception: " + fileLine(__FILE__, line) + ": " + "failed: expected 1 == 2 [1 == 2]; i = 123; hi; str = foo\n" + )); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_REQUIRE(1 == 2, i, "hi", str)); line = __LINE__; + }, ( + "fatal exception: " + fileLine(__FILE__, line) + ": " + "failed: expected 1 == 2 [1 == 2]; i = 123; hi; str = foo\n" + )); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_FAIL_ASSERT("foo")); line = __LINE__; + }, "fatal exception: " + fileLine(__FILE__, line) + ": failed: foo\n"); } TEST(Debug, Exception) { @@ -318,7 +336,7 @@ TEST(Debug, Catch) { // Catch as std::exception. try { line = __LINE__; KJ_FAIL_ASSERT("foo"); - ADD_FAILURE() << "Expected exception."; + KJ_KNOWN_UNREACHABLE(ADD_FAILURE() << "Expected exception."); } catch (const std::exception& e) { kj::StringPtr what = e.what(); std::string text; @@ -339,90 +357,174 @@ int mockSyscall(int i, int error = 0) { } TEST(Debug, Syscall) { - MockExceptionCallback mockCallback; int line; int i = 123; const char* str = "foo"; - KJ_SYSCALL(mockSyscall(0)); - KJ_SYSCALL(mockSyscall(1)); - - EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, EBADF), i, "bar", str)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + - ": failed: mockSyscall(-1, EBADF): " + strerror(EBADF) + - "; i = 123; bar; str = foo\n", mockCallback.text); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, ECONNRESET), i, "bar", str)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + - ": disconnected: mockSyscall(-1, ECONNRESET): " + strerror(ECONNRESET) + - "; i = 123; bar; str = foo\n", mockCallback.text); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, ENOMEM), i, "bar", str)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + - ": overloaded: mockSyscall(-1, ENOMEM): " + strerror(ENOMEM) + - "; i = 123; bar; str = foo\n", mockCallback.text); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, ENOSYS), i, "bar", str)); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + - ": unimplemented: mockSyscall(-1, ENOSYS): " + strerror(ENOSYS) + - "; i = 123; bar; str = foo\n", mockCallback.text); - mockCallback.text.clear(); + EXPECT_LOG_EQ([&](){ + KJ_SYSCALL(mockSyscall(0)); + KJ_SYSCALL(mockSyscall(1)); + }, ""); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, EBADF), i, "bar", str)); line = __LINE__; + }, ( + "fatal exception: " + fileLine(__FILE__, line) + + ": failed: mockSyscall(-1, EBADF): " + strerror(EBADF) + + "; i = 123; bar; str = foo\n" + )); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, ECONNRESET), i, "bar", str)); line = __LINE__; + }, ( + "fatal exception: " + fileLine(__FILE__, line) + + ": disconnected: mockSyscall(-1, ECONNRESET): " + strerror(ECONNRESET) + + "; i = 123; bar; str = foo\n" + )); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, ENOMEM), i, "bar", str)); line = __LINE__; + }, ( + "fatal exception: " + fileLine(__FILE__, line) + + ": overloaded: mockSyscall(-1, ENOMEM): " + strerror(ENOMEM) + + "; i = 123; bar; str = foo\n" + )); + + EXPECT_LOG_EQ([&](){ + EXPECT_FATAL(KJ_SYSCALL(mockSyscall(-1, ENOSYS), i, "bar", str)); line = __LINE__; + }, ( + "fatal exception: " + fileLine(__FILE__, line) + + ": unimplemented: mockSyscall(-1, ENOSYS): " + strerror(ENOSYS) + + "; i = 123; bar; str = foo\n" + )); int result = 0; bool recovered = false; - KJ_SYSCALL(result = mockSyscall(-2, EBADF), i, "bar", str) { recovered = true; break; } line = __LINE__; - EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + - ": failed: mockSyscall(-2, EBADF): " + strerror(EBADF) + - "; i = 123; bar; str = foo\n", mockCallback.text); + EXPECT_LOG_EQ([&](){ + KJ_SYSCALL(result = mockSyscall(-2, EBADF), i, "bar", str) { recovered = true; break; } line = __LINE__; + }, ( + "recoverable exception: " + fileLine(__FILE__, line) + + ": failed: mockSyscall(-2, EBADF): " + strerror(EBADF) + + "; i = 123; bar; str = foo\n" + )); EXPECT_EQ(-2, result); EXPECT_TRUE(recovered); } TEST(Debug, Context) { - MockExceptionCallback mockCallback; - - { - KJ_CONTEXT("foo"); int cline = __LINE__; - - KJ_LOG(WARNING, "blah"); int line = __LINE__; - EXPECT_EQ("log message: " + fileLine(__FILE__, cline) + ":+0: context: foo\n" - "log message: " + fileLine(__FILE__, line) + ":+1: warning: blah\n", - mockCallback.text); - mockCallback.text.clear(); - - EXPECT_FATAL(KJ_FAIL_ASSERT("bar")); line = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, cline) + ": context: foo\n" - + fileLine(__FILE__, line) + ": failed: bar\n", - mockCallback.text); - mockCallback.text.clear(); - + int line; + int line2; + int cline; + int cline2; + + EXPECT_LOG_EQ([&](){ + KJ_CONTEXT("foo"); cline = __LINE__; + + KJ_LOG(WARNING, "blah"); line = __LINE__; + EXPECT_FATAL(KJ_FAIL_ASSERT("bar")); line2 = __LINE__; + }, ( + "log message: " + fileLine(__FILE__, cline) + ":+0: info: context: foo\n\n" + "log message: " + fileLine(__FILE__, line) + ":+1: warning: blah\n" + "fatal exception: " + fileLine(__FILE__, cline) + ": context: foo\n" + + fileLine(__FILE__, line2) + ": failed: bar\n" + )); + + EXPECT_LOG_EQ([&](){ + KJ_CONTEXT("foo"); cline = __LINE__; { int i = 123; const char* str = "qux"; - KJ_CONTEXT("baz", i, "corge", str); int cline2 = __LINE__; - EXPECT_FATAL(KJ_FAIL_ASSERT("bar")); line = __LINE__; + KJ_CONTEXT("baz", i, "corge", str); cline2 = __LINE__; - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, cline) + ": context: foo\n" - + fileLine(__FILE__, cline2) + ": context: baz; i = 123; corge; str = qux\n" - + fileLine(__FILE__, line) + ": failed: bar\n", - mockCallback.text); - mockCallback.text.clear(); + EXPECT_FATAL(KJ_FAIL_ASSERT("bar")); line = __LINE__; + } + }, ( + "fatal exception: " + fileLine(__FILE__, cline) + ": context: foo\n" + + fileLine(__FILE__, cline2) + ": context: baz; i = 123; corge; str = qux\n" + + fileLine(__FILE__, line) + ": failed: bar\n" + )); + + EXPECT_LOG_EQ([&](){ + KJ_CONTEXT("foo"); cline = __LINE__; + { + int i = 123; + const char* str = "qux"; + KJ_CONTEXT("baz", i, "corge", str); cline2 = __LINE__; } - { - KJ_CONTEXT("grault"); int cline2 = __LINE__; + KJ_CONTEXT("grault"); cline2 = __LINE__; EXPECT_FATAL(KJ_FAIL_ASSERT("bar")); line = __LINE__; - - EXPECT_EQ("fatal exception: " + fileLine(__FILE__, cline) + ": context: foo\n" - + fileLine(__FILE__, cline2) + ": context: grault\n" - + fileLine(__FILE__, line) + ": failed: bar\n", - mockCallback.text); - mockCallback.text.clear(); } + }, ( + "fatal exception: " + fileLine(__FILE__, cline) + ": context: foo\n" + + fileLine(__FILE__, cline2) + ": context: grault\n" + + fileLine(__FILE__, line) + ": failed: bar\n" + )); +} + +KJ_TEST("magic assert stringification") { + { + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + int foo = 123; + int bar = 456; + KJ_ASSERT(foo == bar) { break; } + })); + + KJ_EXPECT(exception.getDescription() == "expected foo == bar [123 == 456]"); + } + + { + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + auto foo = kj::str("hello"); + auto bar = kj::str("world!"); + KJ_ASSERT(foo == bar, foo.size(), bar.size()) { break; } + })); + + KJ_EXPECT(exception.getDescription() == + "expected foo == bar [hello == world!]; foo.size() = 5; bar.size() = 6"); + } + + { + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + KJ_ASSERT(kj::str("hello") == kj::str("world!")) { break; } + })); + + KJ_EXPECT(exception.getDescription() == + "expected kj::str(\"hello\") == kj::str(\"world!\") [hello == world!]"); + } + + { + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + int foo = 123; + int bar = 456; + KJ_ASSERT((foo == bar)) { break; } + })); + + KJ_EXPECT(exception.getDescription() == "expected (foo == bar)"); + } + + // Test use of << on left side, which could create confusion. + { + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + int foo = 123; + int bar = 456; + KJ_ASSERT(foo << 2 == bar) { break; } + })); + + KJ_EXPECT(exception.getDescription() == "expected foo << 2 == bar [492 == 456]"); + } + + // Test use of & on left side. + { + int foo = 4; + KJ_ASSERT(foo & 4); + + auto exception = KJ_ASSERT_NONNULL(kj::runCatchingExceptions([&]() { + KJ_ASSERT(foo & 2) { break; } + })); + + KJ_EXPECT(exception.getDescription() == "expected foo & 2"); } } diff --git a/c++/src/kj/debug.c++ b/c++/src/kj/debug.c++ index 3e1275fbbd..f685e3162f 100644 --- a/c++/src/kj/debug.c++ +++ b/c++/src/kj/debug.c++ @@ -19,21 +19,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if _WIN32 || __CYGWIN__ +#include "win32-api-version.h" +#endif + #include "debug.h" #include #include #include #include -#if _WIN32 +#if _WIN32 || __CYGWIN__ +#if !__CYGWIN__ #define strerror_r(errno,buf,len) strerror_s(buf,len,errno) -#define NOMINMAX 1 -#define WIN32_LEAN_AND_MEAN 1 -#define NOSERVICE 1 -#define NOMCX 1 -#define NOIME 1 +#endif #include #include "windows-sanity.h" +#include "encoding.h" +#include #endif namespace kj { @@ -132,11 +135,11 @@ Exception::Type typeOfErrno(int error) { } } -#if _WIN32 +#if _WIN32 || __CYGWIN__ Exception::Type typeOfWin32Error(DWORD error) { switch (error) { - // TODO(0.7): This needs more work. + // TODO(someday): This needs more work. case WSAETIMEDOUT: return Exception::Type::OVERLOADED; @@ -199,18 +202,20 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int quoted = true; } else if (c == ',' && depth == 0) { if (index < argValues.size()) { - argNames[index] = arrayPtr(start, pos - 1); + argNames[index++] = arrayPtr(start, pos - 1); } - ++index; while (isspace(*pos)) ++pos; start = pos; + if (*pos == '\0') { + // ignore trailing comma + break; + } } } } if (index < argValues.size()) { - argNames[index] = arrayPtr(start, pos - 1); + argNames[index++] = arrayPtr(start, pos - 1); } - ++index; if (index != argValues.size()) { getExceptionCallback().logMessage(LogSeverity::ERROR, __FILE__, __LINE__, 0, @@ -241,6 +246,8 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int StringPtr sep = " = "; StringPtr delim = "; "; StringPtr colon = ": "; + StringPtr openBracket = " ["; + StringPtr closeBracket = "]"; StringPtr sysErrorArray; // On android before marshmallow only the posix version of stderror_r was @@ -278,11 +285,26 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int break; } + auto needsLabel = [](ArrayPtr &argName) -> bool { + return (argName.size() > 0 && argName[0] != '\"' && + !(argName.size() >= 8 && memcmp(argName.begin(), "kj::str(", 8) == 0)); + }; + for (size_t i = 0; i < argValues.size(); i++) { + if (argNames[i] == "_kjCondition"_kj) { + // Special handling: don't output delimiter, we want to append this to the previous item, + // in brackets. Also, if it's just "[false]" (meaning we didn't manage to extract a + // comparison), don't add it at all. + if (argValues[i] != "false") { + totalSize += openBracket.size() + argValues[i].size() + closeBracket.size(); + } + continue; + } + if (i > 0 || style != LOG) { totalSize += delim.size(); } - if (argNames[i].size() > 0 && argNames[i][0] != '\"') { + if (needsLabel(argNames[i])) { totalSize += argNames[i].size() + sep.size(); } totalSize += argValues[i].size(); @@ -303,10 +325,20 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int } for (size_t i = 0; i < argValues.size(); i++) { + if (argNames[i] == "_kjCondition"_kj) { + // Special handling: don't output delimiter, we want to append this to the previous item, + // in brackets. Also, if it's just "[false]" (meaning we didn't manage to extract a + // comparison), don't add it at all. + if (argValues[i] != "false") { + pos = _::fill(pos, openBracket, argValues[i], closeBracket); + } + continue; + } + if (i > 0 || style != LOG) { pos = _::fill(pos, delim); } - if (argNames[i].size() > 0 && argNames[i][0] != '\"') { + if (needsLabel(argNames[i])) { pos = _::fill(pos, argNames[i], sep); } pos = _::fill(pos, argValues[i]); @@ -328,7 +360,7 @@ Debug::Fault::~Fault() noexcept(false) { if (exception != nullptr) { Exception copy = mv(*exception); delete exception; - throwRecoverableException(mv(copy), 2); + throwRecoverableException(mv(copy), 1); } } @@ -336,8 +368,8 @@ void Debug::Fault::fatal() { Exception copy = mv(*exception); delete exception; exception = nullptr; - throwFatalException(mv(copy), 2); - abort(); + throwFatalException(mv(copy), 1); + KJ_KNOWN_UNREACHABLE(abort()); } void Debug::Fault::init( @@ -354,31 +386,35 @@ void Debug::Fault::init( makeDescriptionImpl(SYSCALL, condition, osErrorNumber, nullptr, macroArgs, argValues)); } -#if _WIN32 +#if _WIN32 || __CYGWIN__ void Debug::Fault::init( - const char* file, int line, Win32Error osErrorNumber, + const char* file, int line, Win32Result osErrorNumber, const char* condition, const char* macroArgs, ArrayPtr argValues) { LPVOID ptr; - // TODO(0.7): Use FormatMessageW() instead. - // TODO(0.7): Why doesn't this work for winsock errors? - DWORD result = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | + // TODO(someday): Why doesn't this work for winsock errors? + DWORD result = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, osErrorNumber.number, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPTSTR) &ptr, 0, NULL); + (LPWSTR) &ptr, 0, NULL); + String message; if (result > 0) { KJ_DEFER(LocalFree(ptr)); - exception = new Exception(typeOfWin32Error(osErrorNumber.number), file, line, - makeDescriptionImpl(SYSCALL, condition, 0, reinterpret_cast(ptr), - macroArgs, argValues)); + const wchar_t* desc = reinterpret_cast(ptr); + size_t len = wcslen(desc); + if (len > 0 && desc[len-1] == '\n') --len; + if (len > 0 && desc[len-1] == '\r') --len; + message = kj::str('#', osErrorNumber.number, ' ', + decodeWideString(arrayPtr(desc, len))); } else { - auto message = kj::str("win32 error code: ", osErrorNumber.number); - exception = new Exception(typeOfWin32Error(osErrorNumber.number), file, line, - makeDescriptionImpl(SYSCALL, condition, 0, message.cStr(), - macroArgs, argValues)); + message = kj::str("win32 error code: ", osErrorNumber.number); } + + exception = new Exception(typeOfWin32Error(osErrorNumber.number), file, line, + makeDescriptionImpl(SYSCALL, condition, 0, message.cStr(), + macroArgs, argValues)); } #endif @@ -396,9 +432,9 @@ int Debug::getOsErrorNumber(bool nonblocking) { : result; } -#if _WIN32 -Debug::Win32Error Debug::getWin32Error() { - return Win32Error(::GetLastError()); +#if _WIN32 || __CYGWIN__ +uint Debug::getWin32ErrorCode() { + return ::GetLastError(); } #endif @@ -429,7 +465,7 @@ void Debug::Context::logMessage(LogSeverity severity, const char* file, int line String&& text) { if (!logged) { Value v = ensureInitialized(); - next.logMessage(LogSeverity::INFO, v.file, v.line, 0, + next.logMessage(LogSeverity::INFO, trimSourceFilename(v.file).cStr(), v.line, 0, str("context: ", mv(v.description), '\n')); logged = true; } diff --git a/c++/src/kj/debug.h b/c++/src/kj/debug.h index fff7f98bc0..9f8459b1ce 100644 --- a/c++/src/kj/debug.h +++ b/c++/src/kj/debug.h @@ -67,6 +67,13 @@ // * `KJ_REQUIRE(condition, ...)`: Like `KJ_ASSERT` but used to check preconditions -- e.g. to // validate parameters passed from a caller. A failure indicates that the caller is buggy. // +// * `KJ_ASSUME(condition, ...)`: Like `KJ_ASSERT`, but in release mode (if KJ_DEBUG is not +// defined; see below) instead warrants to the compiler that the condition can be assumed to +// hold, allowing it to optimize accordingly. This can result in undefined behavior, so use +// this macro *only* if you can prove to your satisfaction that the condition is guaranteed by +// surrounding code, and if the condition failing to hold would in any case result in undefined +// behavior in its dependencies. +// // * `KJ_SYSCALL(code, ...)`: Executes `code` assuming it makes a system call. A negative result // is considered an error, with error code reported via `errno`. EINTR is handled by retrying. // Other errors are handled by throwing an exception. If you need to examine the return code, @@ -98,36 +105,30 @@ // omits the first parameter and behaves like it was `false`. `FAIL_SYSCALL` and // `FAIL_RECOVERABLE_SYSCALL` take a string and an OS error number as the first two parameters. // The string should be the name of the failed system call. -// * For every macro `FOO` above, there is a `DFOO` version (or `RECOVERABLE_DFOO`) which is only -// executed in debug mode, i.e. when KJ_DEBUG is defined. KJ_DEBUG is defined automatically -// by common.h when compiling without optimization (unless NDEBUG is defined), but you can also -// define it explicitly (e.g. -DKJ_DEBUG). Generally, production builds should NOT use KJ_DEBUG -// as it may enable expensive checks that are unlikely to fail. - -#ifndef KJ_DEBUG_H_ -#define KJ_DEBUG_H_ +// * For every macro `FOO` above except `ASSUME`, there is a `DFOO` version (or +// `RECOVERABLE_DFOO`) which is only executed in debug mode, i.e. when KJ_DEBUG is defined. +// KJ_DEBUG is defined automatically by common.h when compiling without optimization (unless +// NDEBUG is defined), but you can also define it explicitly (e.g. -DKJ_DEBUG). Generally, +// production builds should NOT use KJ_DEBUG as it may enable expensive checks that are unlikely +// to fail. -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "string.h" #include "exception.h" +#include "windows-sanity.h" // work-around macro conflict with `ERROR` -#ifdef ERROR -// This is problematic because windows.h #defines ERROR, which we use in an enum here. -#error "Make sure to to undefine ERROR (or just #include ) before this file" -#endif +KJ_BEGIN_HEADER namespace kj { -#if _MSC_VER +#if KJ_MSVC_TRADITIONAL_CPP // MSVC does __VA_ARGS__ differently from GCC: // - A trailing comma before an empty __VA_ARGS__ is removed automatically, whereas GCC wants // you to request this behavior with "##__VA_ARGS__". // - If __VA_ARGS__ is passed directly as an argument to another macro, it will be treated as a // *single* argument rather than an argument list. This can be worked around by wrapping the -// outer macro call in KJ_EXPAND(), which appraently forces __VA_ARGS__ to be expanded before +// outer macro call in KJ_EXPAND(), which apparently forces __VA_ARGS__ to be expanded before // the macro is evaluated. I don't understand the C preprocessor. // - Using "#__VA_ARGS__" to stringify __VA_ARGS__ expands to zero tokens when __VA_ARGS__ is // empty, rather than expanding to an empty string literal. We can work around by concatenating @@ -136,16 +137,17 @@ namespace kj { #define KJ_EXPAND(X) X #define KJ_LOG(severity, ...) \ - if (!::kj::_::Debug::shouldLog(::kj::LogSeverity::severity)) {} else \ + for (bool _kj_shouldLog = ::kj::_::Debug::shouldLog(::kj::LogSeverity::severity); \ + _kj_shouldLog; _kj_shouldLog = false) \ ::kj::_::Debug::log(__FILE__, __LINE__, ::kj::LogSeverity::severity, \ "" #__VA_ARGS__, __VA_ARGS__) #define KJ_DBG(...) KJ_EXPAND(KJ_LOG(DBG, __VA_ARGS__)) #define KJ_REQUIRE(cond, ...) \ - if (KJ_LIKELY(cond)) {} else \ + if (auto _kjCondition = ::kj::_::MAGIC_ASSERT << cond) {} else \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \ - #cond, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) + #cond, "_kjCondition," #__VA_ARGS__, _kjCondition, __VA_ARGS__);; f.fatal()) #define KJ_FAIL_REQUIRE(...) \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \ @@ -165,21 +167,21 @@ namespace kj { for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ errorNumber, code, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) -#if _WIN32 +#if _WIN32 || __CYGWIN__ #define KJ_WIN32(call, ...) \ - if (::kj::_::Debug::isWin32Success(call)) {} else \ + if (auto _kjWin32Result = ::kj::_::Debug::win32Call(call)) {} else \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ - ::kj::_::Debug::getWin32Error(), #call, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) + _kjWin32Result, #call, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) #define KJ_WINSOCK(call, ...) \ - if ((call) != SOCKET_ERROR) {} else \ + if (auto _kjWin32Result = ::kj::_::Debug::winsockCall(call)) {} else \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ - ::kj::_::Debug::getWin32Error(), #call, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) + _kjWin32Result, #call, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) #define KJ_FAIL_WIN32(code, errorNumber, ...) \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ - ::kj::_::Debug::Win32Error(errorNumber), code, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) + ::kj::_::Debug::Win32Result(errorNumber), code, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal()) #endif @@ -215,16 +217,17 @@ namespace kj { #else #define KJ_LOG(severity, ...) \ - if (!::kj::_::Debug::shouldLog(::kj::LogSeverity::severity)) {} else \ + for (bool _kj_shouldLog = ::kj::_::Debug::shouldLog(::kj::LogSeverity::severity); \ + _kj_shouldLog; _kj_shouldLog = false) \ ::kj::_::Debug::log(__FILE__, __LINE__, ::kj::LogSeverity::severity, \ #__VA_ARGS__, ##__VA_ARGS__) #define KJ_DBG(...) KJ_LOG(DBG, ##__VA_ARGS__) #define KJ_REQUIRE(cond, ...) \ - if (KJ_LIKELY(cond)) {} else \ + if (auto _kjCondition = ::kj::_::MAGIC_ASSERT << cond) {} else \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \ - #cond, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) + #cond, "_kjCondition," #__VA_ARGS__, _kjCondition, ##__VA_ARGS__);; f.fatal()) #define KJ_FAIL_REQUIRE(...) \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \ @@ -244,21 +247,26 @@ namespace kj { for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ errorNumber, code, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) -#if _WIN32 +#if _WIN32 || __CYGWIN__ #define KJ_WIN32(call, ...) \ - if (::kj::_::Debug::isWin32Success(call)) {} else \ + if (auto _kjWin32Result = ::kj::_::Debug::win32Call(call)) {} else \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ - ::kj::_::Debug::getWin32Error(), #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) + _kjWin32Result, #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) +// Invoke a Win32 syscall that returns either BOOL or HANDLE, and throw an exception if it fails. #define KJ_WINSOCK(call, ...) \ - if ((call) != SOCKET_ERROR) {} else \ + if (auto _kjWin32Result = ::kj::_::Debug::winsockCall(call)) {} else \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ - ::kj::_::Debug::getWin32Error(), #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) + _kjWin32Result, #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) +// Like KJ_WIN32 but for winsock calls which return `int` with SOCKET_ERROR indicating failure. +// +// Unfortunately, it's impossible to distinguish these from BOOL-returning Win32 calls by type, +// since BOOL is in fact an alias for `int`. :( #define KJ_FAIL_WIN32(code, errorNumber, ...) \ for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \ - ::kj::_::Debug::Win32Error(errorNumber), code, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) + ::kj::_::Debug::Win32Result(errorNumber), code, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal()) #endif @@ -274,6 +282,20 @@ namespace kj { ::kj::_::Debug::ContextImpl \ KJ_UNIQUE_NAME(_kjContext)(KJ_UNIQUE_NAME(_kjContextFunc)) +#if _MSC_VER && !defined(__clang__) + +#define KJ_REQUIRE_NONNULL(value, ...) \ + (*([&] { \ + auto _kj_result = ::kj::_::readMaybe(value); \ + if (KJ_UNLIKELY(!_kj_result)) { \ + ::kj::_::Debug::Fault(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \ + #value " != nullptr", #__VA_ARGS__, ##__VA_ARGS__).fatal(); \ + } \ + return _kj_result; \ + }())) + +#else + #define KJ_REQUIRE_NONNULL(value, ...) \ (*({ \ auto _kj_result = ::kj::_::readMaybe(value); \ @@ -284,6 +306,8 @@ namespace kj { kj::mv(_kj_result); \ })) +#endif + #define KJ_EXCEPTION(type, ...) \ ::kj::Exception(::kj::Exception::Type::type, __FILE__, __LINE__, \ ::kj::_::Debug::makeDescription(#__VA_ARGS__, ##__VA_ARGS__)) @@ -292,7 +316,7 @@ namespace kj { #define KJ_SYSCALL_HANDLE_ERRORS(call) \ if (int _kjSyscallError = ::kj::_::Debug::syscallError([&](){return (call);}, false)) \ - switch (int error = _kjSyscallError) + switch (int error KJ_UNUSED = _kjSyscallError) // Like KJ_SYSCALL, but doesn't throw. Instead, the block after the macro is a switch block on the // error. Additionally, the int value `error` is defined within the block. So you can do: // @@ -309,6 +333,29 @@ namespace kj { // handleSuccessCase(); // } +#if _WIN32 || __CYGWIN__ + +#define KJ_WIN32_HANDLE_ERRORS(call) \ + if (uint _kjWin32Error = ::kj::_::Debug::win32Call(call).number) \ + switch (uint error KJ_UNUSED = _kjWin32Error) +// Like KJ_WIN32, but doesn't throw. Instead, the block after the macro is a switch block on the +// error. Additionally, the int value `error` is defined within the block. So you can do: +// +// KJ_SYSCALL_HANDLE_ERRORS(foo()) { +// case ERROR_FILE_NOT_FOUND: +// handleNoSuchFile(); +// break; +// case ERROR_FILE_EXISTS: +// handleExists(); +// break; +// default: +// KJ_FAIL_WIN32("foo()", error); +// } else { +// handleSuccessCase(); +// } + +#endif + #define KJ_ASSERT KJ_REQUIRE #define KJ_FAIL_ASSERT KJ_FAIL_REQUIRE #define KJ_ASSERT_NONNULL KJ_REQUIRE_NONNULL @@ -319,10 +366,21 @@ namespace kj { #define KJ_DLOG KJ_LOG #define KJ_DASSERT KJ_ASSERT #define KJ_DREQUIRE KJ_REQUIRE +#define KJ_ASSUME KJ_ASSERT #else #define KJ_DLOG(...) do {} while (false) #define KJ_DASSERT(...) do {} while (false) #define KJ_DREQUIRE(...) do {} while (false) +#if defined(__GNUC__) +#define KJ_ASSUME(cond, ...) do { if (cond) {} else __builtin_unreachable(); } while (false) +#elif defined(__clang__) +#define KJ_ASSUME(cond, ...) __builtin_assume(cond) +#elif defined(_MSC_VER) +#define KJ_ASSUME(cond, ...) __assume(cond) +#else +#define KJ_ASSUME(...) do {} while (false) +#endif + #endif namespace _ { // private @@ -333,11 +391,11 @@ class Debug { typedef LogSeverity Severity; // backwards-compatibility -#if _WIN32 - struct Win32Error { - // Hack for overloading purposes. +#if _WIN32 || __CYGWIN__ + struct Win32Result { uint number; - inline explicit Win32Error(uint number): number(number) {} + inline explicit Win32Result(uint number): number(number) {} + operator bool() const { return number == 0; } }; #endif @@ -362,8 +420,8 @@ class Debug { const char* condition, const char* macroArgs); Fault(const char* file, int line, int osErrorNumber, const char* condition, const char* macroArgs); -#if _WIN32 - Fault(const char* file, int line, Win32Error osErrorNumber, +#if _WIN32 || __CYGWIN__ + Fault(const char* file, int line, Win32Result osErrorNumber, const char* condition, const char* macroArgs); #endif ~Fault() noexcept(false); @@ -376,8 +434,8 @@ class Debug { const char* condition, const char* macroArgs, ArrayPtr argValues); void init(const char* file, int line, int osErrorNumber, const char* condition, const char* macroArgs, ArrayPtr argValues); -#if _WIN32 - void init(const char* file, int line, Win32Error osErrorNumber, +#if _WIN32 || __CYGWIN__ + void init(const char* file, int line, Win32Result osErrorNumber, const char* condition, const char* macroArgs, ArrayPtr argValues); #endif @@ -399,16 +457,17 @@ class Debug { template static int syscallError(Call&& call, bool nonblocking); -#if _WIN32 - static bool isWin32Success(int boolean); - static bool isWin32Success(void* handle); - static Win32Error getWin32Error(); +#if _WIN32 || __CYGWIN__ + static Win32Result win32Call(int boolean); + static Win32Result win32Call(void* handle); + static Win32Result winsockCall(int result); + static uint getWin32ErrorCode(); #endif class Context: public ExceptionCallback { public: Context(); - KJ_DISALLOW_COPY(Context); + KJ_DISALLOW_COPY_AND_MOVE(Context); virtual ~Context() noexcept(false); struct Value { @@ -438,7 +497,7 @@ class Debug { class ContextImpl: public Context { public: inline ContextImpl(Func& func): func(func) {} - KJ_DISALLOW_COPY(ContextImpl); + KJ_DISALLOW_COPY_AND_MOVE(ContextImpl); Value evaluate() override { return func(); @@ -494,19 +553,23 @@ inline Debug::Fault::Fault(const char* file, int line, kj::Exception::Type type, init(file, line, type, condition, macroArgs, nullptr); } -#if _WIN32 -inline Debug::Fault::Fault(const char* file, int line, Win32Error osErrorNumber, +#if _WIN32 || __CYGWIN__ +inline Debug::Fault::Fault(const char* file, int line, Win32Result osErrorNumber, const char* condition, const char* macroArgs) : exception(nullptr) { init(file, line, osErrorNumber, condition, macroArgs, nullptr); } -inline bool Debug::isWin32Success(int boolean) { - return boolean; +inline Debug::Win32Result Debug::win32Call(int boolean) { + return boolean ? Win32Result(0) : Win32Result(getWin32ErrorCode()); } -inline bool Debug::isWin32Success(void* handle) { +inline Debug::Win32Result Debug::win32Call(void* handle) { // Assume null and INVALID_HANDLE_VALUE mean failure. - return handle != nullptr && handle != (void*)-1; + return win32Call(handle != nullptr && handle != (void*)-1); +} +inline Debug::Win32Result Debug::winsockCall(int result) { + // Expect a return value of SOCKET_ERROR means failure. + return win32Call(result != -1); } #endif @@ -549,7 +612,126 @@ inline String Debug::makeDescription<>(const char* macroArgs) { return makeDescriptionInternal(macroArgs, nullptr); } +// ======================================================================================= +// Magic Asserts! +// +// When KJ_ASSERT(foo == bar) fails, `foo` and `bar`'s actual values will be stringified in the +// error message. How does it work? We use template magic and operator precedence. The assertion +// actually evaluates something like this: +// +// if (auto _kjCondition = kj::_::MAGIC_ASSERT << foo == bar) +// +// `<<` has operator precedence slightly above `==`, so `kj::_::MAGIC_ASSERT << foo` gets evaluated +// first. This wraps `foo` in a little wrapper that captures the comparison operators and keeps +// enough information around to be able to stringify the left and right sides of the comparison +// independently. As always, the stringification only actually occurs if the assert fails. +// +// You might ask why we use operator `<<` and not e.g. operator `<=`, since operators of the same +// precedence are evaluated left-to-right. The answer is that some compilers trigger all sorts of +// warnings when you seem to be using a comparison as the input to another comparison. The +// particular warning GCC produces is its general "-Wparentheses" warning which is broadly useful, +// so we don't want to disable it. `<<` also produces some warnings, but only on Clang and the +// specific warning is one we're comfortable disabling (see below). This does mean that we have to +// explicitly overload `operator<<` ourselves to make sure using it in an assert still works. +// +// You might also ask, if we're using operator `<<` anyway, why not start it from the right, in +// which case it would bind after computing any `<<` operators that were actually in the user's +// code? I tried this, but it resulted in a somewhat broader warning from clang that I felt worse +// about disabling (a warning about `<<` precedence not applying specifically to overloads) and +// also created ambiguous overload errors in the KJ units code. + +#if __clang__ +// We intentionally overload operator << for the specific purpose of evaluating it before +// evaluating comparison expressions, so stop Clang from warning about it. Unfortunately this means +// eliminating a warning that would otherwise be useful for people using iostreams... sorry. +#pragma GCC diagnostic ignored "-Woverloaded-shift-op-parentheses" +#endif + +template +struct DebugExpression; + +template ()))> +inline auto tryToCharSequence(T* value) { return kj::toCharSequence(*value); } +inline StringPtr tryToCharSequence(...) { return "(can't stringify)"_kj; } +// SFINAE to stringify a value if and only if it can be stringified. + +template +struct DebugComparison { + Left left; + Right right; + StringPtr op; + bool result; + + inline operator bool() const { return KJ_LIKELY(result); } + + template inline void operator&(T&& other) = delete; + template inline void operator^(T&& other) = delete; + template inline void operator|(T&& other) = delete; +}; + +template +String KJ_STRINGIFY(DebugComparison& cmp) { + return _::concat(tryToCharSequence(&cmp.left), cmp.op, tryToCharSequence(&cmp.right)); +} + +template +struct DebugExpression { + DebugExpression(T&& value): value(kj::fwd(value)) {} + T value; + + // Handle comparison operations by constructing a DebugComparison value. +#define DEFINE_OPERATOR(OP) \ + template \ + DebugComparison operator OP(U&& other) { \ + bool result = value OP other; \ + return { kj::fwd(value), kj::fwd(other), " " #OP " "_kj, result }; \ + } + DEFINE_OPERATOR(==); + DEFINE_OPERATOR(!=); + DEFINE_OPERATOR(<=); + DEFINE_OPERATOR(>=); + DEFINE_OPERATOR(< ); + DEFINE_OPERATOR(> ); +#undef DEFINE_OPERATOR + + // Handle binary operators that have equal or lower precedence than comparisons by performing + // the operation and wrapping the result. +#define DEFINE_OPERATOR(OP) \ + template inline auto operator OP(U&& other) { \ + return DebugExpression(value) OP kj::fwd(other))>(\ + kj::fwd(value) OP kj::fwd(other)); \ + } + DEFINE_OPERATOR(<<); + DEFINE_OPERATOR(>>); + DEFINE_OPERATOR(&); + DEFINE_OPERATOR(^); + DEFINE_OPERATOR(|); +#undef DEFINE_OPERATOR + + inline operator bool() { + // No comparison performed, we're just asserting the expression is truthy. This also covers + // the case of the logic operators && and || -- we cannot overload those because doing so would + // break short-circuiting behavior. + return value; + } +}; + +template +StringPtr KJ_STRINGIFY(const DebugExpression& exp) { + // Hack: This will only ever be called in cases where the expression's truthiness was asserted + // directly, and was determined to be falsy. + return "false"_kj; +} + +struct DebugExpressionStart { + template + DebugExpression operator<<(T&& value) const { + return DebugExpression(kj::fwd(value)); + } +}; +static constexpr DebugExpressionStart MAGIC_ASSERT; + } // namespace _ (private) } // namespace kj -#endif // KJ_DEBUG_H_ +KJ_END_HEADER diff --git a/c++/src/kj/encoding-test.c++ b/c++/src/kj/encoding-test.c++ new file mode 100644 index 0000000000..50b1223dda --- /dev/null +++ b/c++/src/kj/encoding-test.c++ @@ -0,0 +1,535 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "encoding.h" +#include +#include + +namespace kj { +namespace { + +CappedArray hex(byte i) { return kj::hex((uint8_t )i); } +CappedArray hex(char i) { return kj::hex((uint8_t )i); } +CappedArray hex(char16_t i) { return kj::hex((uint16_t)i); } +CappedArray hex(char32_t i) { return kj::hex((uint32_t)i); } +CappedArray hex(wchar_t i) { return kj::hex((uint32_t)i); } +// Hexify chars correctly. +// +// TODO(cleanup): Should this go into string.h with the other definitions of hex()? + +template +void expectResImpl(EncodingResult result, + ArrayPtr expected, + bool errors = false) { + if (errors) { + KJ_EXPECT(result.hadErrors); + } else { + KJ_EXPECT(!result.hadErrors); + } + + KJ_EXPECT(result.size() == expected.size(), result.size(), expected.size()); + for (auto i: kj::zeroTo(kj::min(result.size(), expected.size()))) { + KJ_EXPECT(result[i] == expected[i], i, hex(result[i]), hex(expected[i])); + } +} + +template +void expectRes(EncodingResult result, + const U (&expected)[s], + bool errors = false) { + expectResImpl(kj::mv(result), arrayPtr(expected, s - 1), errors); +} + +#if __cplusplus >= 202000L +template +void expectRes(EncodingResult result, + const char8_t (&expected)[s], + bool errors = false) { + expectResImpl(kj::mv(result), arrayPtr(reinterpret_cast(expected), s - 1), errors); +} +#endif + +template +void expectRes(EncodingResult result, + byte (&expected)[s], + bool errors = false) { + expectResImpl(kj::mv(result), arrayPtr(expected, s), errors); +} + +// Handy reference for surrogate pair edge cases: +// +// \ud800 -> \xed\xa0\x80 +// \udc00 -> \xed\xb0\x80 +// \udbff -> \xed\xaf\xbf +// \udfff -> \xed\xbf\xbf + +KJ_TEST("encode UTF-8 to UTF-16") { + expectRes(encodeUtf16(u8"foo"), u"foo"); + expectRes(encodeUtf16(u8"Здравствуйте"), u"Здравствуйте"); + expectRes(encodeUtf16(u8"中国网络"), u"中国网络"); + expectRes(encodeUtf16(u8"😺☁☄🐵"), u"😺☁☄🐵"); +} + +KJ_TEST("invalid UTF-8 to UTF-16") { + // Disembodied continuation byte. + expectRes(encodeUtf16("\x80"), u"\ufffd", true); + expectRes(encodeUtf16("f\xbfo"), u"f\ufffdo", true); + expectRes(encodeUtf16("f\xbf\x80\xb0o"), u"f\ufffdo", true); + + // Missing continuation bytes. + expectRes(encodeUtf16("\xc2x"), u"\ufffdx", true); + expectRes(encodeUtf16("\xe0x"), u"\ufffdx", true); + expectRes(encodeUtf16("\xe0\xa0x"), u"\ufffdx", true); + expectRes(encodeUtf16("\xf0x"), u"\ufffdx", true); + expectRes(encodeUtf16("\xf0\x90x"), u"\ufffdx", true); + expectRes(encodeUtf16("\xf0\x90\x80x"), u"\ufffdx", true); + + // Overlong sequences. + expectRes(encodeUtf16("\xc0\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xc1\xbf"), u"\ufffd", true); + expectRes(encodeUtf16("\xc2\x80"), u"\u0080", false); + expectRes(encodeUtf16("\xdf\xbf"), u"\u07ff", false); + + expectRes(encodeUtf16("\xe0\x80\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xe0\x9f\xbf"), u"\ufffd", true); + expectRes(encodeUtf16("\xe0\xa0\x80"), u"\u0800", false); + expectRes(encodeUtf16("\xef\xbf\xbe"), u"\ufffe", false); + + // Due to a classic off-by-one error, GCC 4.x rather hilariously encodes '\uffff' as the + // "surrogate pair" 0xd7ff, 0xdfff: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=41698 + if (kj::size(u"\uffff") == 2) { + expectRes(encodeUtf16("\xef\xbf\xbf"), u"\uffff", false); + } + + expectRes(encodeUtf16("\xf0\x80\x80\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xf0\x8f\xbf\xbf"), u"\ufffd", true); + expectRes(encodeUtf16("\xf0\x90\x80\x80"), u"\U00010000", false); + expectRes(encodeUtf16("\xf4\x8f\xbf\xbf"), u"\U0010ffff", false); + + // Out of Unicode range. + expectRes(encodeUtf16("\xf5\x80\x80\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xf8\xbf\x80\x80\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xfc\xbf\x80\x80\x80\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xfe\xbf\x80\x80\x80\x80\x80"), u"\ufffd", true); + expectRes(encodeUtf16("\xff\xbf\x80\x80\x80\x80\x80\x80"), u"\ufffd", true); + + // Surrogates encoded as separate UTF-8 code points are flagged as errors but allowed to decode + // to UTF-16 surrogate values. + expectRes(encodeUtf16("\xed\xb0\x80\xed\xaf\xbf"), u"\xdc00\xdbff", true); + expectRes(encodeUtf16("\xed\xbf\xbf\xed\xa0\x80"), u"\xdfff\xd800", true); + + expectRes(encodeUtf16("\xed\xb0\x80\xed\xbf\xbf"), u"\xdc00\xdfff", true); + expectRes(encodeUtf16("f\xed\xa0\x80"), u"f\xd800", true); + expectRes(encodeUtf16("f\xed\xa0\x80x"), u"f\xd800x", true); + expectRes(encodeUtf16("f\xed\xa0\x80\xed\xa0\x80x"), u"f\xd800\xd800x", true); + + // However, if successive UTF-8 codepoints decode to a proper surrogate pair, the second + // surrogate is replaced with the Unicode replacement character to avoid creating valid UTF-16. + expectRes(encodeUtf16("\xed\xa0\x80\xed\xbf\xbf"), u"\xd800\xfffd", true); + expectRes(encodeUtf16("\xed\xaf\xbf\xed\xb0\x80"), u"\xdbff\xfffd", true); +} + +KJ_TEST("encode UTF-8 to UTF-32") { + expectRes(encodeUtf32(u8"foo"), U"foo"); + expectRes(encodeUtf32(u8"Здравствуйте"), U"Здравствуйте"); + expectRes(encodeUtf32(u8"中国网络"), U"中国网络"); + expectRes(encodeUtf32(u8"😺☁☄🐵"), U"😺☁☄🐵"); +} + +KJ_TEST("invalid UTF-8 to UTF-32") { + // Disembodied continuation byte. + expectRes(encodeUtf32("\x80"), U"\ufffd", true); + expectRes(encodeUtf32("f\xbfo"), U"f\ufffdo", true); + expectRes(encodeUtf32("f\xbf\x80\xb0o"), U"f\ufffdo", true); + + // Missing continuation bytes. + expectRes(encodeUtf32("\xc2x"), U"\ufffdx", true); + expectRes(encodeUtf32("\xe0x"), U"\ufffdx", true); + expectRes(encodeUtf32("\xe0\xa0x"), U"\ufffdx", true); + expectRes(encodeUtf32("\xf0x"), U"\ufffdx", true); + expectRes(encodeUtf32("\xf0\x90x"), U"\ufffdx", true); + expectRes(encodeUtf32("\xf0\x90\x80x"), U"\ufffdx", true); + + // Overlong sequences. + expectRes(encodeUtf32("\xc0\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xc1\xbf"), U"\ufffd", true); + expectRes(encodeUtf32("\xc2\x80"), U"\u0080", false); + expectRes(encodeUtf32("\xdf\xbf"), U"\u07ff", false); + + expectRes(encodeUtf32("\xe0\x80\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xe0\x9f\xbf"), U"\ufffd", true); + expectRes(encodeUtf32("\xe0\xa0\x80"), U"\u0800", false); + expectRes(encodeUtf32("\xef\xbf\xbf"), U"\uffff", false); + + expectRes(encodeUtf32("\xf0\x80\x80\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xf0\x8f\xbf\xbf"), U"\ufffd", true); + expectRes(encodeUtf32("\xf0\x90\x80\x80"), U"\U00010000", false); + expectRes(encodeUtf32("\xf4\x8f\xbf\xbf"), U"\U0010ffff", false); + + // Out of Unicode range. + expectRes(encodeUtf32("\xf5\x80\x80\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xf8\xbf\x80\x80\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xfc\xbf\x80\x80\x80\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xfe\xbf\x80\x80\x80\x80\x80"), U"\ufffd", true); + expectRes(encodeUtf32("\xff\xbf\x80\x80\x80\x80\x80\x80"), U"\ufffd", true); +} + +KJ_TEST("decode UTF-16 to UTF-8") { + expectRes(decodeUtf16(u"foo"), u8"foo"); + expectRes(decodeUtf16(u"Здравствуйте"), u8"Здравствуйте"); + expectRes(decodeUtf16(u"中国网络"), u8"中国网络"); + expectRes(decodeUtf16(u"😺☁☄🐵"), u8"😺☁☄🐵"); +} + +KJ_TEST("invalid UTF-16 to UTF-8") { + // Surrogates in wrong order. + expectRes(decodeUtf16(u"\xdc00\xdbff"), + "\xed\xb0\x80\xed\xaf\xbf", true); + expectRes(decodeUtf16(u"\xdfff\xd800"), + "\xed\xbf\xbf\xed\xa0\x80", true); + + // Missing second surrogate. + expectRes(decodeUtf16(u"f\xd800"), "f\xed\xa0\x80", true); + expectRes(decodeUtf16(u"f\xd800x"), "f\xed\xa0\x80x", true); + expectRes(decodeUtf16(u"f\xd800\xd800x"), "f\xed\xa0\x80\xed\xa0\x80x", true); +} + +KJ_TEST("decode UTF-32 to UTF-8") { + expectRes(decodeUtf32(U"foo"), u8"foo"); + expectRes(decodeUtf32(U"Здравствуйте"), u8"Здравствуйте"); + expectRes(decodeUtf32(U"中国网络"), u8"中国网络"); + expectRes(decodeUtf32(U"😺☁☄🐵"), u8"😺☁☄🐵"); +} + +KJ_TEST("invalid UTF-32 to UTF-8") { + // Surrogates rejected. + expectRes(decodeUtf32(U"\xdfff\xd800"), + "\xed\xbf\xbf\xed\xa0\x80", true); + + // Even if it would be a valid surrogate pair in UTF-16. + expectRes(decodeUtf32(U"\xd800\xdfff"), + "\xed\xa0\x80\xed\xbf\xbf", true); +} + +KJ_TEST("round-trip invalid UTF-16") { + const char16_t INVALID[] = u"\xdfff foo \xd800\xdc00 bar \xdc00\xd800 baz \xdbff qux \xd800"; + + expectRes(encodeUtf16(decodeUtf16(INVALID)), INVALID, true); + expectRes(encodeUtf16(decodeUtf32(encodeUtf32(decodeUtf16(INVALID)))), INVALID, true); +} + +KJ_TEST("EncodingResult as a Maybe") { + KJ_IF_MAYBE(result, encodeUtf16("\x80")) { + KJ_FAIL_EXPECT("expected failure"); + } + + KJ_IF_MAYBE(result, encodeUtf16("foo")) { + // good + } else { + KJ_FAIL_EXPECT("expected success"); + } + + KJ_EXPECT(KJ_ASSERT_NONNULL(decodeUtf16(u"foo")) == "foo"); +} + +KJ_TEST("encode to wchar_t") { + expectRes(encodeWideString(u8"foo"), L"foo"); + expectRes(encodeWideString(u8"Здравствуйте"), L"Здравствуйте"); + expectRes(encodeWideString(u8"中国网络"), L"中国网络"); + expectRes(encodeWideString(u8"😺☁☄🐵"), L"😺☁☄🐵"); +} + +KJ_TEST("decode from wchar_t") { + expectRes(decodeWideString(L"foo"), u8"foo"); + expectRes(decodeWideString(L"Здравствуйте"), u8"Здравствуйте"); + expectRes(decodeWideString(L"中国网络"), u8"中国网络"); + expectRes(decodeWideString(L"😺☁☄🐵"), u8"😺☁☄🐵"); +} + +// ======================================================================================= + +KJ_TEST("hex encoding/decoding") { + byte bytes[] = {0x12, 0x34, 0xab, 0xf2}; + + KJ_EXPECT(encodeHex(bytes) == "1234abf2"); + + expectRes(decodeHex("1234abf2"), bytes); + + expectRes(decodeHex("1234abf21"), bytes, true); + + bytes[2] = 0xa0; + expectRes(decodeHex("1234axf2"), bytes, true); + + bytes[2] = 0x0b; + expectRes(decodeHex("1234xbf2"), bytes, true); +} + +constexpr char RFC2396_FRAGMENT_SET_DIFF[] = "#$&+,/:;=?@[\\]^{|}"; +// These are the characters reserved in RFC 2396, but not in the fragment percent encode set. + +KJ_TEST("URI encoding/decoding") { + KJ_EXPECT(encodeUriComponent("foo") == "foo"); + KJ_EXPECT(encodeUriComponent("foo bar") == "foo%20bar"); + KJ_EXPECT(encodeUriComponent("\xab\xba") == "%AB%BA"); + KJ_EXPECT(encodeUriComponent(StringPtr("foo\0bar", 7)) == "foo%00bar"); + + KJ_EXPECT(encodeUriComponent(RFC2396_FRAGMENT_SET_DIFF) == + "%23%24%26%2B%2C%2F%3A%3B%3D%3F%40%5B%5C%5D%5E%7B%7C%7D"); + + // Encode characters reserved by application/x-www-form-urlencoded, but not by RFC 2396. + KJ_EXPECT(encodeUriComponent("'foo'! (~)") == "'foo'!%20(~)"); + + expectRes(decodeUriComponent("foo%20bar"), "foo bar"); + expectRes(decodeUriComponent("%ab%BA"), "\xab\xba"); + + expectRes(decodeUriComponent("foo%1xxx"), "foo\1xxx", true); + expectRes(decodeUriComponent("foo%1"), "foo\1", true); + expectRes(decodeUriComponent("foo%xxx"), "fooxxx", true); + expectRes(decodeUriComponent("foo%"), "foo", true); + + { + byte bytes[] = {12, 34, 56}; + KJ_EXPECT(decodeBinaryUriComponent(encodeUriComponent(bytes)).asPtr() == bytes); + + // decodeBinaryUriComponent() takes a DecodeUriOptions struct as its second parameter, but it + // once took a single `bool nulTerminate`. Verify that the old behavior still compiles and + // works. + auto bytesWithNul = decodeBinaryUriComponent(encodeUriComponent(bytes), true); + KJ_ASSERT(bytesWithNul.size() == 4); + KJ_EXPECT(bytesWithNul[3] == '\0'); + KJ_EXPECT(bytesWithNul.slice(0, 3) == bytes); + } +} + +KJ_TEST("URL component encoding") { + KJ_EXPECT(encodeUriFragment("foo") == "foo"); + KJ_EXPECT(encodeUriFragment("foo bar") == "foo%20bar"); + KJ_EXPECT(encodeUriFragment("\xab\xba") == "%AB%BA"); + KJ_EXPECT(encodeUriFragment(StringPtr("foo\0bar", 7)) == "foo%00bar"); + + KJ_EXPECT(encodeUriFragment(RFC2396_FRAGMENT_SET_DIFF) == RFC2396_FRAGMENT_SET_DIFF); + + KJ_EXPECT(encodeUriPath("foo") == "foo"); + KJ_EXPECT(encodeUriPath("foo bar") == "foo%20bar"); + KJ_EXPECT(encodeUriPath("\xab\xba") == "%AB%BA"); + KJ_EXPECT(encodeUriPath(StringPtr("foo\0bar", 7)) == "foo%00bar"); + + KJ_EXPECT(encodeUriPath(RFC2396_FRAGMENT_SET_DIFF) == "%23$&+,%2F:;=%3F@[%5C]^%7B|%7D"); + + KJ_EXPECT(encodeUriUserInfo("foo") == "foo"); + KJ_EXPECT(encodeUriUserInfo("foo bar") == "foo%20bar"); + KJ_EXPECT(encodeUriUserInfo("\xab\xba") == "%AB%BA"); + KJ_EXPECT(encodeUriUserInfo(StringPtr("foo\0bar", 7)) == "foo%00bar"); + + KJ_EXPECT(encodeUriUserInfo(RFC2396_FRAGMENT_SET_DIFF) == + "%23$&+,%2F%3A%3B%3D%3F%40%5B%5C%5D%5E%7B%7C%7D"); + + // NOTE: None of these functions have explicit decode equivalents. +} + +KJ_TEST("application/x-www-form-urlencoded encoding/decoding") { + KJ_EXPECT(encodeWwwForm("foo") == "foo"); + KJ_EXPECT(encodeWwwForm("foo bar") == "foo+bar"); + KJ_EXPECT(encodeWwwForm("\xab\xba") == "%AB%BA"); + KJ_EXPECT(encodeWwwForm(StringPtr("foo\0bar", 7)) == "foo%00bar"); + + // Encode characters reserved by application/x-www-form-urlencoded, but not by RFC 2396. + KJ_EXPECT(encodeWwwForm("'foo'! (~)") == "%27foo%27%21+%28%7E%29"); + + expectRes(decodeWwwForm("foo%20bar"), "foo bar"); + expectRes(decodeWwwForm("foo+bar"), "foo bar"); + expectRes(decodeWwwForm("%ab%BA"), "\xab\xba"); + + expectRes(decodeWwwForm("foo%1xxx"), "foo\1xxx", true); + expectRes(decodeWwwForm("foo%1"), "foo\1", true); + expectRes(decodeWwwForm("foo%xxx"), "fooxxx", true); + expectRes(decodeWwwForm("foo%"), "foo", true); + + { + byte bytes[] = {12, 34, 56}; + DecodeUriOptions options { /*.nulTerminate=*/false, /*.plusToSpace=*/true }; + KJ_EXPECT(decodeBinaryUriComponent(encodeWwwForm(bytes), options) == bytes); + } +} + +KJ_TEST("C escape encoding/decoding") { + KJ_EXPECT(encodeCEscape("fooo\a\b\f\n\r\t\v\'\"\\barПривет, Мир! Ж=О") == + "fooo\\a\\b\\f\\n\\r\\t\\v\\\'\\\"\\\\bar\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82\x2c\x20\xd0\x9c\xd0\xb8\xd1\x80\x21\x20\xd0\x96\x3d\xd0\x9e"); + KJ_EXPECT(encodeCEscape("foo\x01\x7fxxx") == + "foo\\001\\177xxx"); + byte bytes[] = {'f', 'o', 'o', 0, '\x01', '\x7f', 'x', 'x', 'x', 128, 254, 255}; + KJ_EXPECT(encodeCEscape(bytes) == "foo\\000\\001\\177xxx\\200\\376\\377"); + + expectRes(decodeCEscape("fooo\\a\\b\\f\\n\\r\\t\\v\\\'\\\"\\\\bar"), + "fooo\a\b\f\n\r\t\v\'\"\\bar"); + expectRes(decodeCEscape("foo\\x01\\x7fxxx"), "foo\x01\x7fxxx"); + expectRes(decodeCEscape("foo\\001\\177234"), "foo\001\177234"); + expectRes(decodeCEscape("foo\\x1"), "foo\x1"); + expectRes(decodeCEscape("foo\\1"), "foo\1"); + + expectRes(decodeCEscape("foo\\u1234bar"), u8"foo\u1234bar"); + expectRes(decodeCEscape("foo\\U00045678bar"), u8"foo\U00045678bar"); + + // Error cases. + expectRes(decodeCEscape("foo\\"), "foo", true); + expectRes(decodeCEscape("foo\\x123x"), u8"foo\x23x", true); + expectRes(decodeCEscape("foo\\u12"), u8"foo\u0012", true); + expectRes(decodeCEscape("foo\\u12xxx"), u8"foo\u0012xxx", true); + expectRes(decodeCEscape("foo\\U12"), u8"foo\u0012", true); + expectRes(decodeCEscape("foo\\U12xxxxxxxx"), u8"foo\u0012xxxxxxxx", true); +} + +KJ_TEST("base64 encoding/decoding") { + { + auto encoded = encodeBase64(StringPtr("").asBytes(), false); + KJ_EXPECT(encoded == "", encoded, encoded.size()); + KJ_EXPECT(heapString(decodeBase64(encoded.asArray()).asChars()) == ""); + } + + { + auto encoded = encodeBase64(StringPtr("foo").asBytes(), false); + KJ_EXPECT(encoded == "Zm9v", encoded, encoded.size()); + auto decoded = decodeBase64(encoded.asArray()); + KJ_EXPECT(!decoded.hadErrors); + KJ_EXPECT(heapString(decoded.asChars()) == "foo"); + } + + { + auto encoded = encodeBase64(StringPtr("quux").asBytes(), false); + KJ_EXPECT(encoded == "cXV1eA==", encoded, encoded.size()); + KJ_EXPECT(heapString(decodeBase64(encoded.asArray()).asChars()) == "quux"); + } + + { + auto encoded = encodeBase64(StringPtr("corge").asBytes(), false); + KJ_EXPECT(encoded == "Y29yZ2U=", encoded); + auto decoded = decodeBase64(encoded.asArray()); + KJ_EXPECT(!decoded.hadErrors); + KJ_EXPECT(heapString(decoded.asChars()) == "corge"); + } + + { + auto decoded = decodeBase64("Y29yZ2U"); + KJ_EXPECT(!decoded.hadErrors); + KJ_EXPECT(heapString(decoded.asChars()) == "corge"); + } + + { + auto decoded = decodeBase64("Y\n29y Z@2U=\n"); + KJ_EXPECT(decoded.hadErrors); // @-sign is invalid base64 input. + KJ_EXPECT(heapString(decoded.asChars()) == "corge"); + } + + { + auto decoded = decodeBase64("Y\n29y Z2U=\n"); + KJ_EXPECT(!decoded.hadErrors); + KJ_EXPECT(heapString(decoded.asChars()) == "corge"); + } + + // Too much padding. + KJ_EXPECT(decodeBase64("Y29yZ2U==").hadErrors); + KJ_EXPECT(decodeBase64("Y29yZ===").hadErrors); + + // Non-terminal padding. + KJ_EXPECT(decodeBase64("ab=c").hadErrors); + + { + auto encoded = encodeBase64(StringPtr("corge").asBytes(), true); + KJ_EXPECT(encoded == "Y29yZ2U=\n", encoded); + } + + StringPtr fullLine = "012345678901234567890123456789012345678901234567890123"; + { + auto encoded = encodeBase64(fullLine.asBytes(), false); + KJ_EXPECT( + encoded == "MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIz", + encoded); + } + { + auto encoded = encodeBase64(fullLine.asBytes(), true); + KJ_EXPECT( + encoded == "MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIz\n", + encoded); + } + + String multiLine = str(fullLine, "456"); + { + auto encoded = encodeBase64(multiLine.asBytes(), false); + KJ_EXPECT( + encoded == "MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2", + encoded); + } + { + auto encoded = encodeBase64(multiLine.asBytes(), true); + KJ_EXPECT( + encoded == "MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIz\n" + "NDU2\n", + encoded); + } +} + +KJ_TEST("base64 url encoding") { + { + // Handles empty. + auto encoded = encodeBase64Url(StringPtr("").asBytes()); + KJ_EXPECT(encoded == "", encoded, encoded.size()); + } + + { + // Handles paddingless encoding. + auto encoded = encodeBase64Url(StringPtr("foo").asBytes()); + KJ_EXPECT(encoded == "Zm9v", encoded, encoded.size()); + } + + { + // Handles padded encoding. + auto encoded1 = encodeBase64Url(StringPtr("quux").asBytes()); + KJ_EXPECT(encoded1 == "cXV1eA", encoded1, encoded1.size()); + auto encoded2 = encodeBase64Url(StringPtr("corge").asBytes()); + KJ_EXPECT(encoded2 == "Y29yZ2U", encoded2, encoded2.size()); + } + + { + // No line breaks. + StringPtr fullLine = "012345678901234567890123456789012345678901234567890123"; + auto encoded = encodeBase64Url(StringPtr(fullLine).asBytes()); + KJ_EXPECT( + encoded == "MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIz", + encoded); + } + + { + // Replaces plusses. + const byte data[] = { 0b11111011, 0b11101111, 0b10111110 }; + auto encoded = encodeBase64Url(data); + KJ_EXPECT(encoded == "----", encoded, encoded.size(), data); + } + + { + // Replaces slashes. + const byte data[] = { 0b11111111, 0b11111111, 0b11111111 }; + auto encoded = encodeBase64Url(data); + KJ_EXPECT(encoded == "____", encoded, encoded.size(), data); + } +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/encoding.c++ b/c++/src/kj/encoding.c++ new file mode 100644 index 0000000000..06ef3ab78e --- /dev/null +++ b/c++/src/kj/encoding.c++ @@ -0,0 +1,1030 @@ +// Copyright (c) 2017 Cloudflare, Inc.; Sandstorm Development Group, Inc.; and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "encoding.h" +#include "vector.h" +#include "debug.h" + +namespace kj { + +namespace { + +#define GOTO_ERROR_IF(cond) if (KJ_UNLIKELY(cond)) goto error + +inline void addChar32(Vector& vec, char32_t u) { + // Encode as surrogate pair. + u -= 0x10000; + vec.add(0xd800 | (u >> 10)); + vec.add(0xdc00 | (u & 0x03ff)); +} + +inline void addChar32(Vector& vec, char32_t u) { + vec.add(u); +} + +template +EncodingResult> encodeUtf(ArrayPtr text, bool nulTerminate) { + Vector result(text.size() + nulTerminate); + bool hadErrors = false; + + size_t i = 0; + while (i < text.size()) { + byte c = text[i++]; + if (c < 0x80) { + // 0xxxxxxx -- ASCII + result.add(c); + continue; + } else if (KJ_UNLIKELY(c < 0xc0)) { + // 10xxxxxx -- malformed continuation byte + goto error; + } else if (c < 0xe0) { + // 110xxxxx -- 2-byte + byte c2; + GOTO_ERROR_IF(i == text.size() || ((c2 = text[i]) & 0xc0) != 0x80); ++i; + char16_t u = (static_cast(c & 0x1f) << 6) + | (static_cast(c2 & 0x3f) ); + + // Disallow overlong sequence. + GOTO_ERROR_IF(u < 0x80); + + result.add(u); + continue; + } else if (c < 0xf0) { + // 1110xxxx -- 3-byte + byte c2, c3; + GOTO_ERROR_IF(i == text.size() || ((c2 = text[i]) & 0xc0) != 0x80); ++i; + GOTO_ERROR_IF(i == text.size() || ((c3 = text[i]) & 0xc0) != 0x80); ++i; + char16_t u = (static_cast(c & 0x0f) << 12) + | (static_cast(c2 & 0x3f) << 6) + | (static_cast(c3 & 0x3f) ); + + // Disallow overlong sequence. + GOTO_ERROR_IF(u < 0x0800); + + // Flag surrogate pair code points as errors, but allow them through. + if (KJ_UNLIKELY((u & 0xf800) == 0xd800)) { + if (result.size() > 0 && + (u & 0xfc00) == 0xdc00 && + (result.back() & 0xfc00) == 0xd800) { + // Whoops, the *previous* character was also an invalid surrogate, and if we add this + // one too, they'll form a valid surrogate pair. If we allowed this, then it would mean + // invalid UTF-8 round-tripped to UTF-16 and back could actually change meaning entirely. + // OTOH, the reason we allow dangling surrogates is to allow invalid UTF-16 to round-trip + // to UTF-8 without loss, but if the original UTF-16 had a valid surrogate pair, it would + // have been encoded as a valid single UTF-8 codepoint, not as separate UTF-8 codepoints + // for each surrogate. + goto error; + } + + hadErrors = true; + } + + result.add(u); + continue; + } else if (c < 0xf8) { + // 11110xxx -- 4-byte + byte c2, c3, c4; + GOTO_ERROR_IF(i == text.size() || ((c2 = text[i]) & 0xc0) != 0x80); ++i; + GOTO_ERROR_IF(i == text.size() || ((c3 = text[i]) & 0xc0) != 0x80); ++i; + GOTO_ERROR_IF(i == text.size() || ((c4 = text[i]) & 0xc0) != 0x80); ++i; + char32_t u = (static_cast(c & 0x07) << 18) + | (static_cast(c2 & 0x3f) << 12) + | (static_cast(c3 & 0x3f) << 6) + | (static_cast(c4 & 0x3f) ); + + // Disallow overlong sequence. + GOTO_ERROR_IF(u < 0x10000); + + // Unicode ends at U+10FFFF + GOTO_ERROR_IF(u >= 0x110000); + + addChar32(result, u); + continue; + } else { + // 5-byte and 6-byte sequences are not legal as they'd result in codepoints outside the + // range of Unicode. + goto error; + } + + error: + result.add(0xfffd); + hadErrors = true; + // Ignore all continuation bytes. + while (i < text.size() && (text[i] & 0xc0) == 0x80) { + ++i; + } + } + + if (nulTerminate) result.add(0); + + return { result.releaseAsArray(), hadErrors }; +} + +} // namespace + +EncodingResult> encodeUtf16(ArrayPtr text, bool nulTerminate) { + return encodeUtf(text, nulTerminate); +} + +EncodingResult> encodeUtf32(ArrayPtr text, bool nulTerminate) { + return encodeUtf(text, nulTerminate); +} + +EncodingResult decodeUtf16(ArrayPtr utf16) { + Vector result(utf16.size() + 1); + bool hadErrors = false; + + size_t i = 0; + while (i < utf16.size()) { + char16_t u = utf16[i++]; + + if (u < 0x80) { + result.add(u); + continue; + } else if (u < 0x0800) { + result.addAll>({ + static_cast(((u >> 6) ) | 0xc0), + static_cast(((u ) & 0x3f) | 0x80) + }); + continue; + } else if ((u & 0xf800) == 0xd800) { + // surrogate pair + char16_t u2; + if (KJ_UNLIKELY(i == utf16.size() // missing second half + || (u & 0x0400) != 0 // first half in wrong range + || ((u2 = utf16[i]) & 0xfc00) != 0xdc00)) { // second half in wrong range + hadErrors = true; + goto threeByte; + } + ++i; + + char32_t u32 = (((u & 0x03ff) << 10) | (u2 & 0x03ff)) + 0x10000; + result.addAll>({ + static_cast(((u32 >> 18) ) | 0xf0), + static_cast(((u32 >> 12) & 0x3f) | 0x80), + static_cast(((u32 >> 6) & 0x3f) | 0x80), + static_cast(((u32 ) & 0x3f) | 0x80) + }); + continue; + } else { + threeByte: + result.addAll>({ + static_cast(((u >> 12) ) | 0xe0), + static_cast(((u >> 6) & 0x3f) | 0x80), + static_cast(((u ) & 0x3f) | 0x80) + }); + continue; + } + } + + result.add(0); + return { String(result.releaseAsArray()), hadErrors }; +} + +EncodingResult decodeUtf32(ArrayPtr utf16) { + Vector result(utf16.size() + 1); + bool hadErrors = false; + + size_t i = 0; + while (i < utf16.size()) { + char32_t u = utf16[i++]; + + if (u < 0x80) { + result.add(u); + continue; + } else if (u < 0x0800) { + result.addAll>({ + static_cast(((u >> 6) ) | 0xc0), + static_cast(((u ) & 0x3f) | 0x80) + }); + continue; + } else if (u < 0x10000) { + if (KJ_UNLIKELY((u & 0xfffff800) == 0xd800)) { + // no surrogates allowed in utf-32 + hadErrors = true; + } + result.addAll>({ + static_cast(((u >> 12) ) | 0xe0), + static_cast(((u >> 6) & 0x3f) | 0x80), + static_cast(((u ) & 0x3f) | 0x80) + }); + continue; + } else { + GOTO_ERROR_IF(u >= 0x110000); // outside Unicode range + result.addAll>({ + static_cast(((u >> 18) ) | 0xf0), + static_cast(((u >> 12) & 0x3f) | 0x80), + static_cast(((u >> 6) & 0x3f) | 0x80), + static_cast(((u ) & 0x3f) | 0x80) + }); + continue; + } + + error: + result.addAll(StringPtr(u8"\ufffd")); + hadErrors = true; + } + + result.add(0); + return { String(result.releaseAsArray()), hadErrors }; +} + +namespace { + +#if __GNUC__ >= 8 && !__clang__ +// GCC 8's new class-memaccess warning rightly dislikes the following hacks, but we're really sure +// we want to allow them so disable the warning. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#endif + +template +Array coerceTo(Array&& array) { + static_assert(sizeof(To) == sizeof(From), "incompatible coercion"); + Array result; + memcpy(&result, &array, sizeof(array)); + memset(&array, 0, sizeof(array)); + return result; +} + +template +ArrayPtr coerceTo(ArrayPtr array) { + static_assert(sizeof(To) == sizeof(From), "incompatible coercion"); + return arrayPtr(reinterpret_cast(array.begin()), array.size()); +} + +template +EncodingResult> coerceTo(EncodingResult>&& result) { + return { coerceTo(Array(kj::mv(result))), result.hadErrors }; +} + +#if __GNUC__ >= 8 && !__clang__ +#pragma GCC diagnostic pop +#endif + +template +struct WideConverter; + +template <> +struct WideConverter { + typedef char Type; + + static EncodingResult> encode(ArrayPtr text, bool nulTerminate) { + auto result = heapArray(text.size() + nulTerminate); + memcpy(result.begin(), text.begin(), text.size()); + if (nulTerminate) result.back() = 0; + return { kj::mv(result), false }; + } + + static EncodingResult decode(ArrayPtr text) { + return { kj::heapString(text), false }; + } +}; + +template <> +struct WideConverter { + typedef char16_t Type; + + static inline EncodingResult> encode( + ArrayPtr text, bool nulTerminate) { + return encodeUtf16(text, nulTerminate); + } + + static inline EncodingResult decode(ArrayPtr text) { + return decodeUtf16(text); + } +}; + +template <> +struct WideConverter { + typedef char32_t Type; + + static inline EncodingResult> encode( + ArrayPtr text, bool nulTerminate) { + return encodeUtf32(text, nulTerminate); + } + + static inline EncodingResult decode(ArrayPtr text) { + return decodeUtf32(text); + } +}; + +} // namespace + +EncodingResult> encodeWideString(ArrayPtr text, bool nulTerminate) { + return coerceTo(WideConverter::encode(text, nulTerminate)); +} +EncodingResult decodeWideString(ArrayPtr wide) { + using Converter = WideConverter; + return Converter::decode(coerceTo(wide)); +} + +// ======================================================================================= + +namespace { + +const char HEX_DIGITS[] = "0123456789abcdef"; +// Maps integer in the range [0,16) to a hex digit. + +const char HEX_DIGITS_URI[] = "0123456789ABCDEF"; +// RFC 3986 section 2.1 says "For consistency, URI producers and normalizers should use uppercase +// hexadecimal digits for all percent-encodings. + +static Maybe tryFromHexDigit(char c) { + if ('0' <= c && c <= '9') { + return c - '0'; + } else if ('a' <= c && c <= 'f') { + return c - ('a' - 10); + } else if ('A' <= c && c <= 'F') { + return c - ('A' - 10); + } else { + return nullptr; + } +} + +static Maybe tryFromOctDigit(char c) { + if ('0' <= c && c <= '7') { + return c - '0'; + } else { + return nullptr; + } +} + +} // namespace + +String encodeHex(ArrayPtr input) { + return strArray(KJ_MAP(b, input) { + return heapArray({HEX_DIGITS[b/16], HEX_DIGITS[b%16]}); + }, ""); +} + +EncodingResult> decodeHex(ArrayPtr text) { + auto result = heapArray(text.size() / 2); + bool hadErrors = text.size() % 2; + + for (auto i: kj::indices(result)) { + byte b = 0; + KJ_IF_MAYBE(d1, tryFromHexDigit(text[i*2])) { + b = *d1 << 4; + } else { + hadErrors = true; + } + KJ_IF_MAYBE(d2, tryFromHexDigit(text[i*2+1])) { + b |= *d2; + } else { + hadErrors = true; + } + result[i] = b; + } + + return { kj::mv(result), hadErrors }; +} + +String encodeUriComponent(ArrayPtr bytes) { + Vector result(bytes.size() + 1); + for (byte b: bytes) { + if (('A' <= b && b <= 'Z') || + ('a' <= b && b <= 'z') || + ('0' <= b && b <= '9') || + b == '-' || b == '_' || b == '.' || b == '!' || b == '~' || b == '*' || b == '\'' || + b == '(' || b == ')') { + result.add(b); + } else { + result.add('%'); + result.add(HEX_DIGITS_URI[b/16]); + result.add(HEX_DIGITS_URI[b%16]); + } + } + result.add('\0'); + return String(result.releaseAsArray()); +} + +String encodeUriFragment(ArrayPtr bytes) { + Vector result(bytes.size() + 1); + for (byte b: bytes) { + if (('?' <= b && b <= '_') || // covers A-Z + ('a' <= b && b <= '~') || // covers a-z + ('&' <= b && b <= ';') || // covers 0-9 + b == '!' || b == '=' || b == '#' || b == '$') { + result.add(b); + } else { + result.add('%'); + result.add(HEX_DIGITS_URI[b/16]); + result.add(HEX_DIGITS_URI[b%16]); + } + } + result.add('\0'); + return String(result.releaseAsArray()); +} + +String encodeUriPath(ArrayPtr bytes) { + Vector result(bytes.size() + 1); + for (byte b: bytes) { + if (('@' <= b && b <= '[') || // covers A-Z + ('a' <= b && b <= 'z') || + ('0' <= b && b <= ';') || // covers 0-9 + ('&' <= b && b <= '.') || + b == '_' || b == '!' || b == '=' || b == ']' || + b == '^' || b == '|' || b == '~' || b == '$') { + result.add(b); + } else { + result.add('%'); + result.add(HEX_DIGITS_URI[b/16]); + result.add(HEX_DIGITS_URI[b%16]); + } + } + result.add('\0'); + return String(result.releaseAsArray()); +} + +String encodeUriUserInfo(ArrayPtr bytes) { + Vector result(bytes.size() + 1); + for (byte b: bytes) { + if (('A' <= b && b <= 'Z') || + ('a' <= b && b <= 'z') || + ('0' <= b && b <= '9') || + ('&' <= b && b <= '.') || + b == '_' || b == '!' || b == '~' || b == '$') { + result.add(b); + } else { + result.add('%'); + result.add(HEX_DIGITS_URI[b/16]); + result.add(HEX_DIGITS_URI[b%16]); + } + } + result.add('\0'); + return String(result.releaseAsArray()); +} + +String encodeWwwForm(ArrayPtr bytes) { + Vector result(bytes.size() + 1); + for (byte b: bytes) { + if (('A' <= b && b <= 'Z') || + ('a' <= b && b <= 'z') || + ('0' <= b && b <= '9') || + b == '-' || b == '_' || b == '.' || b == '*') { + result.add(b); + } else if (b == ' ') { + result.add('+'); + } else { + result.add('%'); + result.add(HEX_DIGITS_URI[b/16]); + result.add(HEX_DIGITS_URI[b%16]); + } + } + result.add('\0'); + return String(result.releaseAsArray()); +} + +EncodingResult> decodeBinaryUriComponent( + ArrayPtr text, DecodeUriOptions options) { + Vector result(text.size() + options.nulTerminate); + bool hadErrors = false; + + const char* ptr = text.begin(); + const char* end = text.end(); + while (ptr < end) { + if (*ptr == '%') { + ++ptr; + + if (ptr == end) { + hadErrors = true; + } else KJ_IF_MAYBE(d1, tryFromHexDigit(*ptr)) { + byte b = *d1; + ++ptr; + if (ptr == end) { + hadErrors = true; + } else KJ_IF_MAYBE(d2, tryFromHexDigit(*ptr)) { + b = (b << 4) | *d2; + ++ptr; + } else { + hadErrors = true; + } + result.add(b); + } else { + hadErrors = true; + } + } else if (options.plusToSpace && *ptr == '+') { + ++ptr; + result.add(' '); + } else { + result.add(*ptr++); + } + } + + if (options.nulTerminate) result.add(0); + return { result.releaseAsArray(), hadErrors }; +} + +// ======================================================================================= + +namespace _ { // private + +String encodeCEscapeImpl(ArrayPtr bytes, bool isBinary) { + Vector escaped(bytes.size()); + + for (byte b: bytes) { + switch (b) { + case '\a': escaped.addAll(StringPtr("\\a")); break; + case '\b': escaped.addAll(StringPtr("\\b")); break; + case '\f': escaped.addAll(StringPtr("\\f")); break; + case '\n': escaped.addAll(StringPtr("\\n")); break; + case '\r': escaped.addAll(StringPtr("\\r")); break; + case '\t': escaped.addAll(StringPtr("\\t")); break; + case '\v': escaped.addAll(StringPtr("\\v")); break; + case '\'': escaped.addAll(StringPtr("\\\'")); break; + case '\"': escaped.addAll(StringPtr("\\\"")); break; + case '\\': escaped.addAll(StringPtr("\\\\")); break; + default: + if (b < 0x20 || b == 0x7f || (isBinary && b > 0x7f)) { + // Use octal escape, not hex, because hex escapes technically have no length limit and + // so can create ambiguity with subsequent characters. + escaped.add('\\'); + escaped.add(HEX_DIGITS[b / 64]); + escaped.add(HEX_DIGITS[(b / 8) % 8]); + escaped.add(HEX_DIGITS[b % 8]); + } else { + escaped.add(b); + } + break; + } + } + + escaped.add(0); + return String(escaped.releaseAsArray()); +} + +} // namespace + +EncodingResult> decodeBinaryCEscape(ArrayPtr text, bool nulTerminate) { + Vector result(text.size() + nulTerminate); + bool hadErrors = false; + + size_t i = 0; + while (i < text.size()) { + char c = text[i++]; + if (c == '\\') { + if (i == text.size()) { + hadErrors = true; + continue; + } + char c2 = text[i++]; + switch (c2) { + case 'a' : result.add('\a'); break; + case 'b' : result.add('\b'); break; + case 'f' : result.add('\f'); break; + case 'n' : result.add('\n'); break; + case 'r' : result.add('\r'); break; + case 't' : result.add('\t'); break; + case 'v' : result.add('\v'); break; + case '\'': result.add('\''); break; + case '\"': result.add('\"'); break; + case '\\': result.add('\\'); break; + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': { + uint value = c2 - '0'; + for (uint j = 0; j < 2 && i < text.size(); j++) { + KJ_IF_MAYBE(d, tryFromOctDigit(text[i])) { + ++i; + value = (value << 3) | *d; + } else { + break; + } + } + if (value >= 0x100) hadErrors = true; + result.add(value); + break; + } + + case 'x': { + uint value = 0; + while (i < text.size()) { + KJ_IF_MAYBE(d, tryFromHexDigit(text[i])) { + ++i; + value = (value << 4) | *d; + } else { + break; + } + } + if (value >= 0x100) hadErrors = true; + result.add(value); + break; + } + + case 'u': { + char16_t value = 0; + for (uint j = 0; j < 4; j++) { + if (i == text.size()) { + hadErrors = true; + break; + } else KJ_IF_MAYBE(d, tryFromHexDigit(text[i])) { + ++i; + value = (value << 4) | *d; + } else { + hadErrors = true; + break; + } + } + auto utf = decodeUtf16(arrayPtr(&value, 1)); + if (utf.hadErrors) hadErrors = true; + result.addAll(utf.asBytes()); + break; + } + + case 'U': { + char32_t value = 0; + for (uint j = 0; j < 8; j++) { + if (i == text.size()) { + hadErrors = true; + break; + } else KJ_IF_MAYBE(d, tryFromHexDigit(text[i])) { + ++i; + value = (value << 4) | *d; + } else { + hadErrors = true; + break; + } + } + auto utf = decodeUtf32(arrayPtr(&value, 1)); + if (utf.hadErrors) hadErrors = true; + result.addAll(utf.asBytes()); + break; + } + + default: + result.add(c2); + } + } else { + result.add(c); + } + } + + if (nulTerminate) result.add(0); + return { result.releaseAsArray(), hadErrors }; +} + +// ======================================================================================= +// This code is derived from libb64 which has been placed in the public domain. +// For details, see http://sourceforge.net/projects/libb64 + +// ------------------------------------------------------------------- +// Encoder + +namespace { + +typedef enum { + step_A, step_B, step_C +} base64_encodestep; + +typedef struct { + base64_encodestep step; + char result; + int stepcount; +} base64_encodestate; + +const int CHARS_PER_LINE = 72; + +void base64_init_encodestate(base64_encodestate* state_in) { + state_in->step = step_A; + state_in->result = 0; + state_in->stepcount = 0; +} + +char base64_encode_value(char value_in) { + static const char* encoding = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + if (value_in > 63) return '='; + return encoding[(int)value_in]; +} + +int base64_encode_block(const char* plaintext_in, int length_in, + char* code_out, base64_encodestate* state_in, bool breakLines) { + const char* plainchar = plaintext_in; + const char* const plaintextend = plaintext_in + length_in; + char* codechar = code_out; + char result; + char fragment; + + result = state_in->result; + + switch (state_in->step) { + while (1) { + KJ_FALLTHROUGH; + case step_A: + if (plainchar == plaintextend) { + state_in->result = result; + state_in->step = step_A; + return codechar - code_out; + } + fragment = *plainchar++; + result = (fragment & 0x0fc) >> 2; + *codechar++ = base64_encode_value(result); + result = (fragment & 0x003) << 4; + KJ_FALLTHROUGH; + case step_B: + if (plainchar == plaintextend) { + state_in->result = result; + state_in->step = step_B; + return codechar - code_out; + } + fragment = *plainchar++; + result |= (fragment & 0x0f0) >> 4; + *codechar++ = base64_encode_value(result); + result = (fragment & 0x00f) << 2; + KJ_FALLTHROUGH; + case step_C: + if (plainchar == plaintextend) { + state_in->result = result; + state_in->step = step_C; + return codechar - code_out; + } + fragment = *plainchar++; + result |= (fragment & 0x0c0) >> 6; + *codechar++ = base64_encode_value(result); + result = (fragment & 0x03f) >> 0; + *codechar++ = base64_encode_value(result); + + ++(state_in->stepcount); + if (breakLines && state_in->stepcount == CHARS_PER_LINE/4) { + *codechar++ = '\n'; + state_in->stepcount = 0; + } + } + } + /* control should not reach here */ + return codechar - code_out; +} + +int base64_encode_blockend(char* code_out, base64_encodestate* state_in, bool breakLines) { + char* codechar = code_out; + + switch (state_in->step) { + case step_B: + *codechar++ = base64_encode_value(state_in->result); + *codechar++ = '='; + *codechar++ = '='; + ++state_in->stepcount; + break; + case step_C: + *codechar++ = base64_encode_value(state_in->result); + *codechar++ = '='; + ++state_in->stepcount; + break; + case step_A: + break; + } + if (breakLines && state_in->stepcount > 0) { + *codechar++ = '\n'; + } + + return codechar - code_out; +} + +} // namespace + +String encodeBase64(ArrayPtr input, bool breakLines) { + /* set up a destination buffer large enough to hold the encoded data */ + // equivalent to ceil(input.size() / 3) * 4 + auto numChars = (input.size() + 2) / 3 * 4; + if (breakLines) { + // Add space for newline characters. + uint lineCount = numChars / CHARS_PER_LINE; + if (numChars % CHARS_PER_LINE > 0) { + // Partial line. + ++lineCount; + } + numChars = numChars + lineCount; + } + auto output = heapString(numChars); + /* keep track of our encoded position */ + char* c = output.begin(); + /* store the number of bytes encoded by a single call */ + int cnt = 0; + size_t total = 0; + /* we need an encoder state */ + base64_encodestate s; + + /*---------- START ENCODING ----------*/ + /* initialise the encoder state */ + base64_init_encodestate(&s); + /* gather data from the input and send it to the output */ + cnt = base64_encode_block((const char *)input.begin(), input.size(), c, &s, breakLines); + c += cnt; + total += cnt; + + /* since we have encoded the entire input string, we know that + there is no more input data; finalise the encoding */ + cnt = base64_encode_blockend(c, &s, breakLines); + c += cnt; + total += cnt; + /*---------- STOP ENCODING ----------*/ + + KJ_ASSERT(total == output.size(), total, output.size()); + + return output; +} + +// ------------------------------------------------------------------- +// Decoder + +namespace { + +typedef enum { + step_a, step_b, step_c, step_d +} base64_decodestep; + +struct base64_decodestate { + bool hadErrors = false; + size_t nPaddingBytesSeen = 0; + // Output state. `nPaddingBytesSeen` is not guaranteed to be correct if `hadErrors` is true. It is + // included in the state purely to preserve the streaming capability of the algorithm while still + // checking for errors correctly (consider chunk 1 = "abc=", chunk 2 = "d"). + + base64_decodestep step = step_a; + char plainchar = 0; +}; + +int base64_decode_value(char value_in) { + // Returns either the fragment value or: -1 on whitespace, -2 on padding, -3 on invalid input. + // + // Note that the original libb64 implementation used -1 for invalid input, -2 on padding -- this + // new scheme allows for some simpler error checks in steps A and B. + + static const signed char decoding[] = { + -3,-3,-3,-3,-3,-3,-3,-3, -3,-1,-1,-3,-1,-1,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -1,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,62,-3,-3,-3,63, + 52,53,54,55,56,57,58,59, 60,61,-3,-3,-3,-2,-3,-3, + -3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14, + 15,16,17,18,19,20,21,22, 23,24,25,-3,-3,-3,-3,-3, + -3,26,27,28,29,30,31,32, 33,34,35,36,37,38,39,40, + 41,42,43,44,45,46,47,48, 49,50,51,-3,-3,-3,-3,-3, + + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + -3,-3,-3,-3,-3,-3,-3,-3, -3,-3,-3,-3,-3,-3,-3,-3, + }; + static_assert(sizeof(decoding) == 256, "base64 decoding table size error"); + return decoding[(unsigned char)value_in]; +} + +int base64_decode_block(const char* code_in, const int length_in, + char* plaintext_out, base64_decodestate* state_in) { + const char* codechar = code_in; + char* plainchar = plaintext_out; + signed char fragment; + + if (state_in->step != step_a) { + *plainchar = state_in->plainchar; + } + +#define ERROR_IF(predicate) state_in->hadErrors = state_in->hadErrors || (predicate) + + switch (state_in->step) + { + while (1) + { + KJ_FALLTHROUGH; + case step_a: + do { + if (codechar == code_in+length_in) { + state_in->step = step_a; + state_in->plainchar = '\0'; + return plainchar - plaintext_out; + } + fragment = (signed char)base64_decode_value(*codechar++); + // It is an error to see invalid or padding bytes in step A. + ERROR_IF(fragment < -1); + } while (fragment < 0); + *plainchar = (fragment & 0x03f) << 2; + KJ_FALLTHROUGH; + case step_b: + do { + if (codechar == code_in+length_in) { + state_in->step = step_b; + state_in->plainchar = *plainchar; + // It is always an error to suspend from step B, because we don't have enough bits yet. + // TODO(someday): This actually breaks the streaming use case, if base64_decode_block() is + // to be called multiple times. We'll fix it if we ever care to support streaming. + state_in->hadErrors = true; + return plainchar - plaintext_out; + } + fragment = (signed char)base64_decode_value(*codechar++); + // It is an error to see invalid or padding bytes in step B. + ERROR_IF(fragment < -1); + } while (fragment < 0); + *plainchar++ |= (fragment & 0x030) >> 4; + *plainchar = (fragment & 0x00f) << 4; + KJ_FALLTHROUGH; + case step_c: + do { + if (codechar == code_in+length_in) { + state_in->step = step_c; + state_in->plainchar = *plainchar; + // It is an error to complete from step C if we have seen incomplete padding. + // TODO(someday): This actually breaks the streaming use case, if base64_decode_block() is + // to be called multiple times. We'll fix it if we ever care to support streaming. + ERROR_IF(state_in->nPaddingBytesSeen == 1); + return plainchar - plaintext_out; + } + fragment = (signed char)base64_decode_value(*codechar++); + // It is an error to see invalid bytes or more than two padding bytes in step C. + ERROR_IF(fragment < -2 || (fragment == -2 && ++state_in->nPaddingBytesSeen > 2)); + } while (fragment < 0); + // It is an error to continue from step C after having seen any padding. + ERROR_IF(state_in->nPaddingBytesSeen > 0); + *plainchar++ |= (fragment & 0x03c) >> 2; + *plainchar = (fragment & 0x003) << 6; + KJ_FALLTHROUGH; + case step_d: + do { + if (codechar == code_in+length_in) { + state_in->step = step_d; + state_in->plainchar = *plainchar; + return plainchar - plaintext_out; + } + fragment = (signed char)base64_decode_value(*codechar++); + // It is an error to see invalid bytes or more than one padding byte in step D. + ERROR_IF(fragment < -2 || (fragment == -2 && ++state_in->nPaddingBytesSeen > 1)); + } while (fragment < 0); + // It is an error to continue from step D after having seen padding bytes. + ERROR_IF(state_in->nPaddingBytesSeen > 0); + *plainchar++ |= (fragment & 0x03f); + } + } + +#undef ERROR_IF + + /* control should not reach here */ + return plainchar - plaintext_out; +} + +} // namespace + +EncodingResult> decodeBase64(ArrayPtr input) { + base64_decodestate state; + + auto output = heapArray((input.size() * 6 + 7) / 8); + + size_t n = base64_decode_block(input.begin(), input.size(), + reinterpret_cast(output.begin()), &state); + + if (n < output.size()) { + auto copy = heapArray(n); + memcpy(copy.begin(), output.begin(), n); + output = kj::mv(copy); + } + + return EncodingResult>(kj::mv(output), state.hadErrors); +} + +String encodeBase64Url(ArrayPtr bytes) { + // TODO(perf): Rewrite as single pass? + // TODO(someday): Write decoder? + + auto base64 = kj::encodeBase64(bytes); + + for (char& c: base64) { + if (c == '+') c = '-'; + if (c == '/') c = '_'; + } + + // Remove trailing '='s. + kj::ArrayPtr slice = base64; + while (slice.size() > 0 && slice.back() == '=') { + slice = slice.slice(0, slice.size() - 1); + } + + return kj::str(slice); +} + +} // namespace kj diff --git a/c++/src/kj/encoding.h b/c++/src/kj/encoding.h new file mode 100644 index 0000000000..d61ee473b5 --- /dev/null +++ b/c++/src/kj/encoding.h @@ -0,0 +1,445 @@ +// Copyright (c) 2017 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once +// Functions for encoding/decoding bytes and text in common formats, including: +// - UTF-{8,16,32} +// - Hex +// - URI encoding +// - Base64 + +#include "string.h" + +KJ_BEGIN_HEADER + +namespace kj { + +template +struct EncodingResult: public ResultType { + // Equivalent to ResultType (a String or wide-char array) for all intents and purposes, except + // that the bool `hadErrors` can be inspected to see if any errors were encountered in the input. + // Each encoding/decoding function that returns this type will "work around" errors in some way, + // so an application doesn't strictly have to check for errors. E.g. the Unicode functions + // replace errors with U+FFFD in the output. + // + // Through magic, KJ_IF_MAYBE() and KJ_{REQUIRE,ASSERT}_NONNULL() work on EncodingResult + // exactly if it were a Maybe that is null in case of errors. + + inline EncodingResult(ResultType&& result, bool hadErrors) + : ResultType(kj::mv(result)), hadErrors(hadErrors) {} + + const bool hadErrors; +}; + +template +inline auto KJ_STRINGIFY(const EncodingResult& value) + -> decltype(toCharSequence(implicitCast(value))) { + return toCharSequence(implicitCast(value)); +} + +EncodingResult> encodeUtf16(ArrayPtr text, bool nulTerminate = false); +EncodingResult> encodeUtf32(ArrayPtr text, bool nulTerminate = false); +// Convert UTF-8 text (which KJ strings use) to UTF-16 or UTF-32. +// +// If `nulTerminate` is true, an extra NUL character will be added to the end of the output. +// +// The returned arrays are in platform-native endianness (otherwise they wouldn't really be +// char16_t / char32_t). +// +// Note that the KJ Unicode encoding and decoding functions actually implement +// [WTF-8 encoding](http://simonsapin.github.io/wtf-8/), which affects how invalid input is +// handled. See comments on decodeUtf16() for more info. + +EncodingResult decodeUtf16(ArrayPtr utf16); +EncodingResult decodeUtf32(ArrayPtr utf32); +// Convert UTF-16 or UTF-32 to UTF-8 (which KJ strings use). +// +// The input should NOT include a NUL terminator; any NUL characters in the input array will be +// preserved in the output. +// +// The input must be in platform-native endianness. BOMs are NOT recognized by these functions. +// +// Note that the KJ Unicode encoding and decoding functions actually implement +// [WTF-8 encoding](http://simonsapin.github.io/wtf-8/). This means that if you start with an array +// of char16_t and you pass it through any number of conversions to other Unicode encodings, +// eventually returning it to UTF-16, all the while ignoring `hadErrors`, you will end up with +// exactly the same char16_t array you started with, *even if* the array is not valid UTF-16. This +// is useful because many real-world systems that were designed for UCS-2 (plain 16-bit Unicode) +// and later "upgraded" to UTF-16 do not enforce that their UTF-16 is well-formed. For example, +// file names on Windows NT are encoded using 16-bit characters, without enforcing that the +// character sequence is valid UTF-16. It is important that programs on Windows be able to handle +// such filenames, even if they choose to convert the name to UTF-8 for internal processing. +// +// Specifically, KJ's Unicode handling allows unpaired surrogate code points to round-trip through +// UTF-8 and UTF-32. Unpaired surrogates will be flagged as an error (setting `hadErrors` in the +// result), but will NOT be replaced with the Unicode replacement character as other erroneous +// sequences would be, but rather encoded as an invalid surrogate codepoint in the target encoding. +// +// KJ makes the following guarantees about invalid input: +// - A round trip from UTF-16 to other encodings and back will produce exactly the original input, +// with every leg of the trip raising the `hadErrors` flag if the original input was not valid. +// - A round trip from UTF-8 or UTF-32 to other encodings and back will either produce exactly +// the original input, or will have replaced some invalid sequences with the Unicode replacement +// character, U+FFFD. No code units will ever be removed unless they are replaced with U+FFFD, +// and no code units will ever be added except to encode U+FFFD. If the original input was not +// valid, the `hadErrors` flag will be raised on the first leg of the trip, and will also be +// raised on subsequent legs unless all invalid sequences were replaced with U+FFFD (which, after +// all, is a valid code point). + +EncodingResult> encodeWideString( + ArrayPtr text, bool nulTerminate = false); +EncodingResult decodeWideString(ArrayPtr wide); +// Encode / decode strings of wchar_t, aka "wide strings". Unfortunately, different platforms have +// different definitions for wchar_t. For example, on Windows they are 16-bit and encode UTF-16, +// but on Linux they are 32-bit and encode UTF-32. Some platforms even define wchar_t as 8-bit, +// encoding UTF-8 (e.g. BeOS did this). +// +// KJ assumes that wide strings use the UTF encoding that corresponds to the size of wchar_t on +// the target platform. So, these functions are simple aliases for encodeUtf*/decodeUtf*, above +// (or simply make a copy if wchar_t is 8 bits). + +String encodeHex(ArrayPtr bytes); +EncodingResult> decodeHex(ArrayPtr text); +// Encode/decode bytes as hex strings. + +String encodeUriComponent(ArrayPtr bytes); +String encodeUriComponent(ArrayPtr bytes); +EncodingResult decodeUriComponent(ArrayPtr text); +// Encode/decode URI components using % escapes for characters listed as "reserved" in RFC 2396. +// This is the same behavior as JavaScript's `encodeURIComponent()`. +// +// See https://tools.ietf.org/html/rfc2396#section-2.3 + +String encodeUriFragment(ArrayPtr bytes); +String encodeUriFragment(ArrayPtr bytes); +// Encode URL fragment components using the fragment percent encode set defined by the WHATWG URL +// specification. Use decodeUriComponent() to decode. +// +// Quirk: We also percent-encode the '%' sign itself, because we expect to be called on percent- +// decoded data. In other words, this function is not idempotent, in contrast to the URL spec. +// +// See https://url.spec.whatwg.org/#fragment-percent-encode-set + +String encodeUriPath(ArrayPtr bytes); +String encodeUriPath(ArrayPtr bytes); +// Encode URL path components (not entire paths!) using the path percent encode set defined by the +// WHATWG URL specification. Use decodeUriComponent() to decode. +// +// Quirk: We also percent-encode the '%' sign itself, because we expect to be called on percent- +// decoded data. In other words, this function is not idempotent, in contrast to the URL spec. +// +// Quirk: This percent-encodes '/' and '\' characters as well, which are not actually in the set +// defined by the WHATWG URL spec. Since a conforming URL implementation will only ever call this +// function on individual path components, and never entire paths, augmenting the character set to +// include these separators allows this function to be used to implement a URL class that stores +// its path components in percent-decoded form. +// +// See https://url.spec.whatwg.org/#path-percent-encode-set + +String encodeUriUserInfo(ArrayPtr bytes); +String encodeUriUserInfo(ArrayPtr bytes); +// Encode URL userinfo components using the userinfo percent encode set defined by the WHATWG URL +// specification. Use decodeUriComponent() to decode. +// +// Quirk: We also percent-encode the '%' sign itself, because we expect to be called on percent- +// decoded data. In other words, this function is not idempotent, in contrast to the URL spec. +// +// See https://url.spec.whatwg.org/#userinfo-percent-encode-set + +String encodeWwwForm(ArrayPtr bytes); +String encodeWwwForm(ArrayPtr bytes); +EncodingResult decodeWwwForm(ArrayPtr text); +// Encode/decode URI components using % escapes and '+' (for spaces) according to the +// application/x-www-form-urlencoded format defined by the WHATWG URL specification. +// +// Note: Like the fragment, path, and userinfo percent-encoding functions above, this function is +// not idempotent: we percent-encode '%' signs. However, in this particular case the spec happens +// to agree with us! +// +// See https://url.spec.whatwg.org/#concept-urlencoded-byte-serializer + +struct DecodeUriOptions { + // Parameter to `decodeBinaryUriComponent()`. + + // This struct is intentionally convertible from bool, in order to maintain backwards + // compatibility with code written when `decodeBinaryUriComponent()` took a boolean second + // parameter. + DecodeUriOptions(bool nulTerminate = false, bool plusToSpace = false) + : nulTerminate(nulTerminate), plusToSpace(plusToSpace) {} + + bool nulTerminate; + // Append a terminal NUL byte. + + bool plusToSpace; + // Convert '+' to ' ' characters before percent decoding. Used to decode + // application/x-www-form-urlencoded text, such as query strings. +}; +EncodingResult> decodeBinaryUriComponent( + ArrayPtr text, DecodeUriOptions options = DecodeUriOptions()); +// Decode URI components using % escapes. This is a lower-level interface used to implement both +// `decodeUriComponent()` and `decodeWwwForm()` + +String encodeCEscape(ArrayPtr bytes); +String encodeCEscape(ArrayPtr bytes); +EncodingResult> decodeBinaryCEscape( + ArrayPtr text, bool nulTerminate = false); +EncodingResult decodeCEscape(ArrayPtr text); + +String encodeBase64(ArrayPtr bytes, bool breakLines = false); +// Encode the given bytes as base64 text. If `breakLines` is true, line breaks will be inserted +// into the output every 72 characters (e.g. for encoding e-mail bodies). + +EncodingResult> decodeBase64(ArrayPtr text); +// Decode base64 text. This function reports errors required by the WHATWG HTML/Infra specs: see +// https://html.spec.whatwg.org/multipage/webappapis.html#atob for details. + +String encodeBase64Url(ArrayPtr bytes); +// Encode the given bytes as URL-safe base64 text. (RFC 4648, section 5) + +// ======================================================================================= +// inline implementation details + +namespace _ { // private + +template +NullableValue readMaybe(EncodingResult&& value) { + if (value.hadErrors) { + return nullptr; + } else { + return kj::mv(value); + } +} + +template +T* readMaybe(EncodingResult& value) { + if (value.hadErrors) { + return nullptr; + } else { + return &value; + } +} + +template +const T* readMaybe(const EncodingResult& value) { + if (value.hadErrors) { + return nullptr; + } else { + return &value; + } +} + +String encodeCEscapeImpl(ArrayPtr bytes, bool isBinary); + +} // namespace _ (private) + +inline String encodeUriComponent(ArrayPtr text) { + return encodeUriComponent(text.asBytes()); +} +inline EncodingResult decodeUriComponent(ArrayPtr text) { + auto result = decodeBinaryUriComponent(text, DecodeUriOptions { /*.nulTerminate=*/true }); + return { String(result.releaseAsChars()), result.hadErrors }; +} + +inline String encodeUriFragment(ArrayPtr text) { + return encodeUriFragment(text.asBytes()); +} +inline String encodeUriPath(ArrayPtr text) { + return encodeUriPath(text.asBytes()); +} +inline String encodeUriUserInfo(ArrayPtr text) { + return encodeUriUserInfo(text.asBytes()); +} + +inline String encodeWwwForm(ArrayPtr text) { + return encodeWwwForm(text.asBytes()); +} +inline EncodingResult decodeWwwForm(ArrayPtr text) { + auto result = decodeBinaryUriComponent(text, DecodeUriOptions { /*.nulTerminate=*/true, + /*.plusToSpace=*/true }); + return { String(result.releaseAsChars()), result.hadErrors }; +} + +inline String encodeCEscape(ArrayPtr text) { + return _::encodeCEscapeImpl(text.asBytes(), false); +} + +inline String encodeCEscape(ArrayPtr bytes) { + return _::encodeCEscapeImpl(bytes, true); +} + +inline EncodingResult decodeCEscape(ArrayPtr text) { + auto result = decodeBinaryCEscape(text, true); + return { String(result.releaseAsChars()), result.hadErrors }; +} + +// If you pass a string literal to a function taking ArrayPtr, it'll include the NUL +// termintator, which is surprising. Let's add overloads that avoid that. In practice this probably +// only even matters for encoding-test.c++. + +template +inline EncodingResult> encodeUtf16(const char (&text)[s], bool nulTerminate=false) { + return encodeUtf16(arrayPtr(text, s - 1), nulTerminate); +} +template +inline EncodingResult> encodeUtf32(const char (&text)[s], bool nulTerminate=false) { + return encodeUtf32(arrayPtr(text, s - 1), nulTerminate); +} +template +inline EncodingResult> encodeWideString( + const char (&text)[s], bool nulTerminate=false) { + return encodeWideString(arrayPtr(text, s - 1), nulTerminate); +} +template +inline EncodingResult decodeUtf16(const char16_t (&utf16)[s]) { + return decodeUtf16(arrayPtr(utf16, s - 1)); +} +template +inline EncodingResult decodeUtf32(const char32_t (&utf32)[s]) { + return decodeUtf32(arrayPtr(utf32, s - 1)); +} +template +inline EncodingResult decodeWideString(const wchar_t (&utf32)[s]) { + return decodeWideString(arrayPtr(utf32, s - 1)); +} +template +inline EncodingResult> decodeHex(const char (&text)[s]) { + return decodeHex(arrayPtr(text, s - 1)); +} +template +inline String encodeUriComponent(const char (&text)[s]) { + return encodeUriComponent(arrayPtr(text, s - 1)); +} +template +inline Array decodeBinaryUriComponent(const char (&text)[s]) { + return decodeBinaryUriComponent(arrayPtr(text, s - 1)); +} +template +inline EncodingResult decodeUriComponent(const char (&text)[s]) { + return decodeUriComponent(arrayPtr(text, s-1)); +} +template +inline String encodeUriFragment(const char (&text)[s]) { + return encodeUriFragment(arrayPtr(text, s - 1)); +} +template +inline String encodeUriPath(const char (&text)[s]) { + return encodeUriPath(arrayPtr(text, s - 1)); +} +template +inline String encodeUriUserInfo(const char (&text)[s]) { + return encodeUriUserInfo(arrayPtr(text, s - 1)); +} +template +inline String encodeWwwForm(const char (&text)[s]) { + return encodeWwwForm(arrayPtr(text, s - 1)); +} +template +inline EncodingResult decodeWwwForm(const char (&text)[s]) { + return decodeWwwForm(arrayPtr(text, s-1)); +} +template +inline String encodeCEscape(const char (&text)[s]) { + return encodeCEscape(arrayPtr(text, s - 1)); +} +template +inline EncodingResult> decodeBinaryCEscape(const char (&text)[s]) { + return decodeBinaryCEscape(arrayPtr(text, s - 1)); +} +template +inline EncodingResult decodeCEscape(const char (&text)[s]) { + return decodeCEscape(arrayPtr(text, s-1)); +} +template +EncodingResult> decodeBase64(const char (&text)[s]) { + return decodeBase64(arrayPtr(text, s - 1)); +} + +#if __cplusplus >= 202000L +template +inline EncodingResult> encodeUtf16(const char8_t (&text)[s], bool nulTerminate=false) { + return encodeUtf16(arrayPtr(reinterpret_cast(text), s - 1), nulTerminate); +} +template +inline EncodingResult> encodeUtf32(const char8_t (&text)[s], bool nulTerminate=false) { + return encodeUtf32(arrayPtr(reinterpret_cast(text), s - 1), nulTerminate); +} +template +inline EncodingResult> encodeWideString( + const char8_t (&text)[s], bool nulTerminate=false) { + return encodeWideString(arrayPtr(reinterpret_cast(text), s - 1), nulTerminate); +} +template +inline EncodingResult> decodeHex(const char8_t (&text)[s]) { + return decodeHex(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline String encodeUriComponent(const char8_t (&text)[s]) { + return encodeUriComponent(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline Array decodeBinaryUriComponent(const char8_t (&text)[s]) { + return decodeBinaryUriComponent(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline EncodingResult decodeUriComponent(const char8_t (&text)[s]) { + return decodeUriComponent(arrayPtr(reinterpret_cast(text), s-1)); +} +template +inline String encodeUriFragment(const char8_t (&text)[s]) { + return encodeUriFragment(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline String encodeUriPath(const char8_t (&text)[s]) { + return encodeUriPath(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline String encodeUriUserInfo(const char8_t (&text)[s]) { + return encodeUriUserInfo(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline String encodeWwwForm(const char8_t (&text)[s]) { + return encodeWwwForm(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline EncodingResult decodeWwwForm(const char8_t (&text)[s]) { + return decodeWwwForm(arrayPtr(reinterpret_cast(text), s-1)); +} +template +inline String encodeCEscape(const char8_t (&text)[s]) { + return encodeCEscape(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline EncodingResult> decodeBinaryCEscape(const char8_t (&text)[s]) { + return decodeBinaryCEscape(arrayPtr(reinterpret_cast(text), s - 1)); +} +template +inline EncodingResult decodeCEscape(const char8_t (&text)[s]) { + return decodeCEscape(arrayPtr(reinterpret_cast(text), s-1)); +} +template +EncodingResult> decodeBase64(const char8_t (&text)[s]) { + return decodeBase64(arrayPtr(reinterpret_cast(text), s - 1)); +} +#endif + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/exception-override-symbolizer-test.c++ b/c++/src/kj/exception-override-symbolizer-test.c++ new file mode 100644 index 0000000000..bb9de02496 --- /dev/null +++ b/c++/src/kj/exception-override-symbolizer-test.c++ @@ -0,0 +1,49 @@ +// Copyright (c) 2022 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if __GNUC__ && !_WIN32 + +#include "debug.h" +#include +#include "kj/common.h" +#include "kj/array.h" +#include +#include +#include + +namespace kj { + +// override weak symbol +String stringifyStackTrace(ArrayPtr trace) { + return kj::str("\n\nTEST_SYMBOLIZER\n\n"); +} + +namespace { + +KJ_TEST("getStackTrace() uses symbolizer override") { + auto trace = getStackTrace(); + KJ_ASSERT(strstr(trace.cStr(), "TEST_SYMBOLIZER") != nullptr, trace); +} + +} // namespace +} // namespace kj + +#endif diff --git a/c++/src/kj/exception-test.c++ b/c++/src/kj/exception-test.c++ index 0b46c935d8..50054ab24c 100644 --- a/c++/src/kj/exception-test.c++ +++ b/c++/src/kj/exception-test.c++ @@ -22,6 +22,8 @@ #include "exception.h" #include "debug.h" #include +#include +#include namespace kj { namespace _ { // private @@ -29,9 +31,11 @@ namespace { TEST(Exception, TrimSourceFilename) { #if _WIN32 - if (trimSourceFilename(__FILE__) != "kj\\exception-test.c++") -#endif + EXPECT_TRUE(trimSourceFilename(__FILE__) == "kj/exception-test.c++" || + trimSourceFilename(__FILE__) == "kj\\exception-test.c++"); +#else EXPECT_EQ(trimSourceFilename(__FILE__), "kj/exception-test.c++"); +#endif } TEST(Exception, RunCatchingExceptions) { @@ -56,6 +60,36 @@ TEST(Exception, RunCatchingExceptions) { } } +#if !KJ_NO_EXCEPTIONS +TEST(Exception, RunCatchingExceptionsStdException) { + Maybe e = kj::runCatchingExceptions([&]() { + throw std::logic_error("foo"); + }); + + KJ_IF_MAYBE(ex, e) { + EXPECT_EQ("std::exception: foo", ex->getDescription()); + } else { + ADD_FAILURE() << "Expected exception"; + } +} + +TEST(Exception, RunCatchingExceptionsOtherException) { + Maybe e = kj::runCatchingExceptions([&]() { + throw 123; + }); + + KJ_IF_MAYBE(ex, e) { +#if __GNUC__ && !KJ_NO_RTTI + EXPECT_EQ("unknown non-KJ exception of type: int", ex->getDescription()); +#else + EXPECT_EQ("unknown non-KJ exception", ex->getDescription()); +#endif + } else { + ADD_FAILURE() << "Expected exception"; + } +} +#endif + #if !KJ_NO_EXCEPTIONS // We skip this test when exceptions are disabled because making it no-exceptions-safe defeats // the purpose of the test: recoverable exceptions won't throw inside a destructor in the first @@ -98,10 +132,16 @@ TEST(Exception, UnwindDetector) { } #endif +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) || \ + KJ_HAS_COMPILER_FEATURE(address_sanitizer) || \ + defined(__SANITIZE_ADDRESS__) +// The implementation skips this check in these cases. +#else #if !__MINGW32__ // Inexplicably crashes when exception is thrown from constructor. TEST(Exception, ExceptionCallbackMustBeOnStack) { KJ_EXPECT_THROW_MESSAGE("must be allocated on the stack", new ExceptionCallback); } +#endif #endif // !__MINGW32__ #if !KJ_NO_EXCEPTIONS @@ -138,6 +178,122 @@ TEST(Exception, ScopeSuccessFail) { } #endif +#if __GNUG__ || defined(__clang__) +kj::String testStackTrace() __attribute__((noinline)); +#elif _MSC_VER +__declspec(noinline) kj::String testStackTrace(); +#endif + +kj::String testStackTrace() { + // getStackTrace() normally skips its immediate caller, so we wrap it in another layer. + return getStackTrace(); +} + +KJ_TEST("getStackTrace() returns correct line number, not line + 1") { + // Backtraces normally produce the return address of each stack frame, but that's usually the + // address immediately after the one that made the call. As a result, it used to be that stack + // traces often pointed to the line after the one that made a call, which was confusing. This + // checks that this bug is fixed. + // + // This is not a very robust test, because: + // 1) Since symbolic stack traces are not available in many situations (e.g. release builds + // lacking debug symbols, systems where addr2line isn't present, etc.), we only check that + // the stack trace does *not* contain the *wrong* value, rather than checking that it does + // contain the right one. + // 2) This test only detects the problem if the call instruction to testStackTrace() is the + // *last* instruction attributed to its line of code. Whether or not this is true seems to be + // dependent on obscure compiler behavior. For example, below, it could only be the case if + // RVO is applied -- but in my testing, RVO does seem to be applied here. I tried several + // variations involving passing via an output parameter or a global variable rather than + // returning, but found some variations detected the problem and others didn't, essentially + // at random. + + auto trace = testStackTrace(); + auto wrong = kj::str("exception-test.c++:", __LINE__); + + KJ_ASSERT(strstr(trace.cStr(), wrong.cStr()) == nullptr, trace, wrong); +} + +#if !KJ_NO_EXCEPTIONS +KJ_TEST("InFlightExceptionIterator works") { + bool caught = false; + try { + KJ_DEFER({ + try { + KJ_FAIL_ASSERT("bar"); + } catch (const kj::Exception& e) { + InFlightExceptionIterator iter; + KJ_IF_MAYBE(e2, iter.next()) { + KJ_EXPECT(e2 == &e, e2->getDescription()); + } else { + KJ_FAIL_EXPECT("missing first exception"); + } + + KJ_IF_MAYBE(e2, iter.next()) { + KJ_EXPECT(e2->getDescription() == "foo", e2->getDescription()); + } else { + KJ_FAIL_EXPECT("missing second exception"); + } + + KJ_EXPECT(iter.next() == nullptr, "more than two exceptions"); + + caught = true; + } + }); + KJ_FAIL_ASSERT("foo"); + } catch (const kj::Exception& e) { + // expected + } + + KJ_EXPECT(caught); +} +#endif + +KJ_TEST("computeRelativeTrace") { + auto testCase = [](uint expectedPrefix, + ArrayPtr trace, ArrayPtr relativeTo) { + auto tracePtr = KJ_MAP(x, trace) { return (void*)x; }; + auto relativeToPtr = KJ_MAP(x, relativeTo) { return (void*)x; }; + + auto result = computeRelativeTrace(tracePtr, relativeToPtr); + KJ_EXPECT(result.begin() == tracePtr.begin()); + + KJ_EXPECT(result.size() == expectedPrefix, trace, relativeTo, result); + }; + + testCase(8, + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 4, 3, 2, 1}); + + testCase(5, + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 5, 6, 7, 8}); + + testCase(5, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {8, 7, 6, 5, 5, 6, 7, 8}); + + testCase(5, + {1, 2, 3, 4, 5, 6, 7, 8, 6, 7, 8}, + {8, 7, 6, 5, 5, 6, 7, 8}); + + testCase(9, + {1, 2, 3, 4, 5, 6, 7, 8, 5, 5, 6, 7, 8}, + {8, 7, 6, 5, 5, 6, 7, 8}); + + testCase(5, + {1, 2, 3, 4, 5, 5, 6, 7, 8, 5, 6, 7, 8}, + {8, 7, 6, 5, 5, 6, 7, 8}); + + testCase(5, + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 5, 6, 7, 8, 7, 8}); + + testCase(5, + {1, 2, 3, 4, 5, 6, 7, 8}, + {8, 7, 6, 5, 6, 7, 8, 7, 8}); +} + } // namespace } // namespace _ (private) } // namespace kj diff --git a/c++/src/kj/exception.c++ b/c++/src/kj/exception.c++ index 2bac1b1aef..246830dea4 100644 --- a/c++/src/kj/exception.c++ +++ b/c++/src/kj/exception.c++ @@ -19,37 +19,93 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#if _WIN32 || __CYGWIN__ +#include "win32-api-version.h" +#endif + +#if (_WIN32 && _M_X64) || (__CYGWIN__ && __x86_64__) +// Currently the Win32 stack-trace code only supports x86_64. We could easily extend it to support +// i386 as well but it requires some code changes around how we read the context to start the +// trace. +#define KJ_USE_WIN32_DBGHELP 1 +#endif + #include "exception.h" #include "string.h" #include "debug.h" #include "threadlocal.h" #include "miniposix.h" +#include "function.h" +#include "main.h" #include #include #include #include +#include #ifndef _WIN32 #include #endif #include "io.h" -#if (__linux__ && __GLIBC__) || __APPLE__ +#if !KJ_NO_RTTI +#include +#endif +#if __GNUC__ +#include +#endif + +#if (__linux__ && __GLIBC__ && !__UCLIBC__) || __APPLE__ #define KJ_HAS_BACKTRACE 1 #include #endif -#if _WIN32 -#define WIN32_LEAN_AND_MEAN +#if _WIN32 || __CYGWIN__ #include #include "windows-sanity.h" #include #endif -#if (__linux__ || __APPLE__) +#if (__linux__ || __APPLE__ || __CYGWIN__) #include #include #endif +#if __CYGWIN__ +#include +#include +#endif + +#if KJ_HAS_LIBDL +#include "dlfcn.h" +#endif + +#if _MSC_VER +#include +#endif + +#if KJ_HAS_COMPILER_FEATURE(address_sanitizer) || defined(__SANITIZE_ADDRESS__) +#include +#else +static void __lsan_ignore_object(const void* p) {} +#endif +// TODO(cleanup): Remove the LSAN stuff per https://github.com/capnproto/capnproto/pull/1255 +// feedback. + +namespace { +template +inline T* lsanIgnoreObjectAndReturn(T* ptr) { + // Defensively lsan_ignore_object since the documentation doesn't explicitly specify what happens + // if you call this multiple times on the same object. + // TODO(cleanup): Remove this per https://github.com/capnproto/capnproto/pull/1255. + __lsan_ignore_object(ptr); + return ptr; +} +} + namespace kj { StringPtr KJ_STRINGIFY(LogSeverity severity) { @@ -64,10 +120,7 @@ StringPtr KJ_STRINGIFY(LogSeverity severity) { return SEVERITY_STRINGS[static_cast(severity)]; } -#if _WIN32 && _M_X64 -// Currently the Win32 stack-trace code only supports x86_64. We could easily extend it to support -// i386 as well but it requires some code changes around how we read the context to start the -// trace. +#if KJ_USE_WIN32_DBGHELP namespace { @@ -89,6 +142,13 @@ struct Dbghelp { BOOL (WINAPI *symGetLineFromAddr64)( HANDLE hProcess,DWORD64 qwAddr,PDWORD pdwDisplacement,PIMAGEHLP_LINE64 Line64); +#if __GNUC__ && !__clang__ && __GNUC__ >= 8 +// GCC 8 warns that our reinterpret_casts of function pointers below are casting between +// incompatible types. Yes, GCC, we know that. This is the nature of GetProcAddress(); it returns +// everything as `long long int (*)()` and we have to cast to the actual type. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wcast-function-type" +#endif Dbghelp() : lib(LoadLibraryA("dbghelp.dll")), symInitialize(lib == nullptr ? nullptr : @@ -110,6 +170,9 @@ struct Dbghelp { symInitialize(GetCurrentProcess(), NULL, TRUE); } } +#if __GNUC__ && !__clang__ && __GNUC__ >= 9 +#pragma GCC diagnostic pop +#endif }; const Dbghelp& getDbghelp() { @@ -119,6 +182,13 @@ const Dbghelp& getDbghelp() { ArrayPtr getStackTrace(ArrayPtr space, uint ignoreCount, HANDLE thread, CONTEXT& context) { + // NOTE: Apparently there is a function CaptureStackBackTrace() that is equivalent to glibc's + // backtrace(). Somehow I missed that when I originally wrote this. However, + // CaptureStackBackTrace() does not accept a CONTEXT parameter; it can only trace the caller. + // That's more problematic on Windows where breakHandler(), sehHandler(), and Cygwin signal + // handlers all depend on the ability to pass a CONTEXT. So we'll keep this code, which works + // after all. + const Dbghelp& dbghelp = getDbghelp(); if (dbghelp.stackWalk64 == nullptr || dbghelp.symFunctionTableAccess64 == nullptr || @@ -146,7 +216,9 @@ ArrayPtr getStackTrace(ArrayPtr space, uint ignoreCount, break; } - space[count] = reinterpret_cast(frame.AddrPC.Offset); + // Subtract 1 from each address so that we identify the calling instructions, rather than the + // return addresses (which are typically the instruction after the call). + space[count] = reinterpret_cast(frame.AddrPC.Offset - 1); } return space.slice(kj::min(ignoreCount, count), count); @@ -160,25 +232,42 @@ ArrayPtr getStackTrace(ArrayPtr space, uint ignoreCount) { return nullptr; } -#if _WIN32 && _M_X64 +#if KJ_USE_WIN32_DBGHELP CONTEXT context; RtlCaptureContext(&context); return getStackTrace(space, ignoreCount, GetCurrentThread(), context); #elif KJ_HAS_BACKTRACE size_t size = backtrace(space.begin(), space.size()); + for (auto& addr: space.slice(0, size)) { + // The addresses produced by backtrace() are return addresses, which means they point to the + // instruction immediately after the call. Invoking addr2line on these can be confusing because + // it often points to the next line. If the next instruction is inlined from another function, + // the trace can be extra-confusing, since now it claims to be in a function that was not + // actually on the call stack. If we subtract 1 from each address, though, we get a much more + // reasonable trace. This may cause the addresses to be invalid instruction pointers if the + // instructions were multi-byte, but it appears addr2line is able to cope with this. + addr = reinterpret_cast(reinterpret_cast(addr) - 1); + } return space.slice(kj::min(ignoreCount + 1, size), size); #else return nullptr; #endif } +#if (__GNUC__ && !_WIN32) || __clang__ +// Allow dependents to override the implementation of stack symbolication by making it a weak +// symbol. We prefer weak symbols over some sort of callback registration mechanism becasue this +// allows an alternate symbolication library to be easily linked into tests without changing the +// code of the test. +__attribute__((weak)) +#endif String stringifyStackTrace(ArrayPtr trace) { if (trace.size() == 0) return nullptr; if (getExceptionCallback().stackTraceMode() != ExceptionCallback::StackTraceMode::FULL) { return nullptr; } -#if _WIN32 && _M_X64 && _MSC_VER +#if KJ_USE_WIN32_DBGHELP && _MSC_VER // Try to get file/line using SymGetLineFromAddr64(). We don't bother if we aren't on MSVC since // this requires MSVC debug info. @@ -196,14 +285,15 @@ String stringifyStackTrace(ArrayPtr trace) { IMAGEHLP_LINE64 lineInfo; memset(&lineInfo, 0, sizeof(lineInfo)); lineInfo.SizeOfStruct = sizeof(lineInfo); - if (dbghelp.symGetLineFromAddr64(process, reinterpret_cast(trace[i]), NULL, &lineInfo)) { + DWORD displacement; + if (dbghelp.symGetLineFromAddr64(process, reinterpret_cast(trace[i]), &displacement, &lineInfo)) { lines[i] = kj::str('\n', lineInfo.FileName, ':', lineInfo.LineNumber); } } return strArray(lines, ""); -#elif (__linux__ || __APPLE__) && !__ANDROID__ +#elif (__linux__ || __APPLE__ || __CYGWIN__) && !__ANDROID__ // We want to generate a human-readable stack trace. // TODO(someday): It would be really great if we could avoid farming out to another process @@ -245,6 +335,16 @@ String stringifyStackTrace(ArrayPtr trace) { // The Mac OS X equivalent of addr2line is atos. // (Internally, it uses the private CoreSymbolication.framework library.) p = popen(str("xcrun atos -p ", getpid(), ' ', strTrace).cStr(), "r"); +#elif __CYGWIN__ + wchar_t exeWinPath[MAX_PATH]; + if (GetModuleFileNameW(nullptr, exeWinPath, sizeof(exeWinPath)) == 0) { + return nullptr; + } + char exePosixPath[MAX_PATH * 2]; + if (cygwin_conv_path(CCP_WIN_W_TO_POSIX, exeWinPath, exePosixPath, sizeof(exePosixPath)) < 0) { + return nullptr; + } + p = popen(str("addr2line -e '", exePosixPath, "' ", strTrace).cStr(), "r"); #endif if (p == nullptr) { @@ -284,13 +384,113 @@ String stringifyStackTrace(ArrayPtr trace) { #endif } +String stringifyStackTraceAddresses(ArrayPtr trace) { +#if KJ_HAS_LIBDL + return strArray(KJ_MAP(addr, trace) { + Dl_info info; + // Shared libraries are mapped near the end of the address space while the executable is mapped + // near the beginning. We want to print addresses in the executable as raw addresses, not + // offsets, since that's what addr2line expects for executables. For shared libraries it + // expects offsets. In any case, most frames are likely to be in the main executable so it + // makes the output cleaner if we don't repeatedly write its name. + if (reinterpret_cast(addr) >= 0x400000000000ull && dladdr(addr, &info)) { + uintptr_t offset = reinterpret_cast(addr) - + reinterpret_cast(info.dli_fbase); + return kj::str(info.dli_fname, '@', reinterpret_cast(offset)); + } else { + return kj::str(addr); + } + }, " "); +#else + // TODO(someday): Support other platforms. + return kj::strArray(trace, " "); +#endif +} + +StringPtr stringifyStackTraceAddresses(ArrayPtr trace, ArrayPtr scratch) { + // Version which writes into a pre-allocated buffer. This is safe for signal handlers to the + // extent that dladdr() is safe. + // + // TODO(cleanup): We should improve the KJ stringification framework so that there's a way to + // write this string directly into a larger message buffer with strPreallocated(). + +#if KJ_HAS_LIBDL + char* ptr = scratch.begin(); + char* limit = scratch.end() - 1; + + for (auto addr: trace) { + Dl_info info; + // Shared libraries are mapped near the end of the address space while the executable is mapped + // near the beginning. We want to print addresses in the executable as raw addresses, not + // offsets, since that's what addr2line expects for executables. For shared libraries it + // expects offsets. In any case, most frames are likely to be in the main executable so it + // makes the output cleaner if we don't repeatedly write its name. + if (reinterpret_cast(addr) >= 0x400000000000ull && dladdr(addr, &info)) { + uintptr_t offset = reinterpret_cast(addr) - + reinterpret_cast(info.dli_fbase); + ptr = _::fillLimited(ptr, limit, kj::StringPtr(info.dli_fname), "@0x"_kj, hex(offset)); + } else { + ptr = _::fillLimited(ptr, limit, toCharSequence(addr)); + } + + ptr = _::fillLimited(ptr, limit, " "_kj); + } + *ptr = '\0'; + return StringPtr(scratch.begin(), ptr); +#else + // TODO(someday): Support other platforms. + return kj::strPreallocated(scratch, kj::delimited(trace, " ")); +#endif +} + String getStackTrace() { void* space[32]; auto trace = getStackTrace(space, 2); - return kj::str(kj::strArray(trace, " "), stringifyStackTrace(trace)); + return kj::str(stringifyStackTraceAddresses(trace), stringifyStackTrace(trace)); +} + +namespace { + +#if !KJ_NO_EXCEPTIONS + +[[noreturn]] void terminateHandler() { + void* traceSpace[32]; + + // ignoreCount = 3 to ignore std::terminate entry. + auto trace = kj::getStackTrace(traceSpace, 3); + + kj::String message; + + auto eptr = std::current_exception(); + if (eptr != nullptr) { + try { + std::rethrow_exception(eptr); + } catch (const kj::Exception& exception) { + message = kj::str("*** Fatal uncaught kj::Exception: ", exception, '\n'); + } catch (const std::exception& exception) { + message = kj::str("*** Fatal uncaught std::exception: ", exception.what(), + "\nstack: ", stringifyStackTraceAddresses(trace), + stringifyStackTrace(trace), '\n'); + } catch (...) { + message = kj::str("*** Fatal uncaught exception of type: ", kj::getCaughtExceptionType(), + "\nstack: ", stringifyStackTraceAddresses(trace), + stringifyStackTrace(trace), '\n'); + } + } else { + message = kj::str("*** std::terminate() called with no exception" + "\nstack: ", stringifyStackTraceAddresses(trace), + stringifyStackTrace(trace), '\n'); + } + + kj::FdOutputStream(STDERR_FILENO).write(message.begin(), message.size()); + _exit(1); } -#if _WIN32 && _M_X64 +#endif + +} // namespace + +#if KJ_USE_WIN32_DBGHELP && !__CYGWIN__ namespace { DWORD mainThreadId = 0; @@ -307,9 +507,10 @@ BOOL WINAPI breakHandler(DWORD type) { context.ContextFlags = CONTEXT_FULL; if (GetThreadContext(thread, &context)) { void* traceSpace[32]; - auto trace = getStackTrace(traceSpace, 2, thread, context); + auto trace = getStackTrace(traceSpace, 0, thread, context); ResumeThread(thread); - auto message = kj::str("*** Received CTRL+C. stack: ", strArray(trace, " "), + auto message = kj::str("*** Received CTRL+C. stack: ", + stringifyStackTraceAddresses(trace), stringifyStackTrace(trace), '\n'); FdOutputStream(STDERR_FILENO).write(message.begin(), message.size()); } else { @@ -327,24 +528,93 @@ BOOL WINAPI breakHandler(DWORD type) { return FALSE; // still crash } +kj::StringPtr exceptionDescription(DWORD code) { + switch (code) { + case EXCEPTION_ACCESS_VIOLATION: return "access violation"; + case EXCEPTION_ARRAY_BOUNDS_EXCEEDED: return "array bounds exceeded"; + case EXCEPTION_BREAKPOINT: return "breakpoint"; + case EXCEPTION_DATATYPE_MISALIGNMENT: return "datatype misalignment"; + case EXCEPTION_FLT_DENORMAL_OPERAND: return "denormal floating point operand"; + case EXCEPTION_FLT_DIVIDE_BY_ZERO: return "floating point division by zero"; + case EXCEPTION_FLT_INEXACT_RESULT: return "inexact floating point result"; + case EXCEPTION_FLT_INVALID_OPERATION: return "invalid floating point operation"; + case EXCEPTION_FLT_OVERFLOW: return "floating point overflow"; + case EXCEPTION_FLT_STACK_CHECK: return "floating point stack overflow"; + case EXCEPTION_FLT_UNDERFLOW: return "floating point underflow"; + case EXCEPTION_ILLEGAL_INSTRUCTION: return "illegal instruction"; + case EXCEPTION_IN_PAGE_ERROR: return "page error"; + case EXCEPTION_INT_DIVIDE_BY_ZERO: return "integer divided by zero"; + case EXCEPTION_INT_OVERFLOW: return "integer overflow"; + case EXCEPTION_INVALID_DISPOSITION: return "invalid disposition"; + case EXCEPTION_NONCONTINUABLE_EXCEPTION: return "noncontinuable exception"; + case EXCEPTION_PRIV_INSTRUCTION: return "privileged instruction"; + case EXCEPTION_SINGLE_STEP: return "single step"; + case EXCEPTION_STACK_OVERFLOW: return "stack overflow"; + default: return "(unknown exception code)"; + } +} + +LONG WINAPI sehHandler(EXCEPTION_POINTERS* info) { + void* traceSpace[32]; + auto trace = getStackTrace(traceSpace, 0, GetCurrentThread(), *info->ContextRecord); + auto message = kj::str("*** Received structured exception #0x", + hex(info->ExceptionRecord->ExceptionCode), ": ", + exceptionDescription(info->ExceptionRecord->ExceptionCode), + "; stack: ", + stringifyStackTraceAddresses(trace), + stringifyStackTrace(trace), '\n'); + FdOutputStream(STDERR_FILENO).write(message.begin(), message.size()); + return EXCEPTION_EXECUTE_HANDLER; // still crash +} + } // namespace void printStackTraceOnCrash() { mainThreadId = GetCurrentThreadId(); KJ_WIN32(SetConsoleCtrlHandler(breakHandler, TRUE)); + SetUnhandledExceptionFilter(&sehHandler); + +#if !KJ_NO_EXCEPTIONS + // Also override std::terminate() handler with something nicer for KJ. + std::set_terminate(&terminateHandler); +#endif } -#elif KJ_HAS_BACKTRACE +#elif _WIN32 +// Windows, but KJ_USE_WIN32_DBGHELP is not enabled. We can't print useful stack traces, so don't +// try to catch SEH nor ctrl+C. + +void printStackTraceOnCrash() { +#if !KJ_NO_EXCEPTIONS + std::set_terminate(&terminateHandler); +#endif +} + +#else namespace { -void crashHandler(int signo, siginfo_t* info, void* context) { +[[noreturn]] void crashHandler(int signo, siginfo_t* info, void* context) { void* traceSpace[32]; +#if KJ_USE_WIN32_DBGHELP + // Win32 backtracing can't trace its way out of a Cygwin signal handler. However, Cygwin gives + // us direct access to the CONTEXT, which we can pass to the Win32 tracing functions. + ucontext_t* ucontext = reinterpret_cast(context); + // Cygwin's mcontext_t has the same layout as CONTEXT. + // TODO(someday): Figure out why this produces garbage for SIGINT from ctrl+C. It seems to work + // correctly for SIGSEGV. + CONTEXT win32Context; + static_assert(sizeof(ucontext->uc_mcontext) >= sizeof(win32Context), + "mcontext_t should be an extension of CONTEXT"); + memcpy(&win32Context, &ucontext->uc_mcontext, sizeof(win32Context)); + auto trace = getStackTrace(traceSpace, 0, GetCurrentThread(), win32Context); +#else // ignoreCount = 2 to ignore crashHandler() and signal trampoline. auto trace = getStackTrace(traceSpace, 2); +#endif auto message = kj::str("*** Received signal #", signo, ": ", strsignal(signo), - "\nstack: ", strArray(trace, " "), + "\nstack: ", stringifyStackTraceAddresses(trace), stringifyStackTrace(trace), '\n'); FdOutputStream(STDERR_FILENO).write(message.begin(), message.size()); @@ -394,9 +664,11 @@ void printStackTraceOnCrash() { // because stack traces on ctrl+c can be obnoxious for, say, command-line tools. KJ_SYSCALL(sigaction(SIGINT, &action, nullptr)); #endif -} -#else -void printStackTraceOnCrash() { + +#if !KJ_NO_EXCEPTIONS + // Also override std::terminate() handler with something nicer for KJ. + std::set_terminate(&terminateHandler); +#endif } #endif @@ -450,6 +722,29 @@ retry: return filename; } +void resetCrashHandlers() { +#ifndef _WIN32 + struct sigaction action; + memset(&action, 0, sizeof(action)); + + action.sa_handler = SIG_DFL; + KJ_SYSCALL(sigaction(SIGSEGV, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGBUS, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGFPE, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGABRT, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGILL, &action, nullptr)); + KJ_SYSCALL(sigaction(SIGSYS, &action, nullptr)); + +#ifdef KJ_DEBUG + KJ_SYSCALL(sigaction(SIGINT, &action, nullptr)); +#endif +#endif + +#if !KJ_NO_EXCEPTIONS + std::set_terminate(nullptr); +#endif +} + StringPtr KJ_STRINGIFY(Exception::Type type) { static const char* TYPE_STRINGS[] = { "failed", @@ -481,17 +776,22 @@ String KJ_STRINGIFY(const Exception& e) { for (;;) { KJ_IF_MAYBE(c, contextPtr) { contextText[contextDepth++] = - str(c->file, ":", c->line, ": context: ", c->description, "\n"); + str(trimSourceFilename(c->file), ":", c->line, ": context: ", c->description, "\n"); contextPtr = c->next; } else { break; } } + // Note that we put "remote" before "stack" because trace frames are ordered callee before + // caller, so this is the most natural presentation ordering. return str(strArray(contextText, ""), e.getFile(), ":", e.getLine(), ": ", e.getType(), e.getDescription() == nullptr ? "" : ": ", e.getDescription(), - e.getStackTrace().size() > 0 ? "\nstack: " : "", strArray(e.getStackTrace(), " "), + e.getRemoteTrace().size() > 0 ? "\nremote: " : "", + e.getRemoteTrace(), + e.getStackTrace().size() > 0 ? "\nstack: " : "", + stringifyStackTraceAddresses(e.getStackTrace()), stringifyStackTrace(e.getStackTrace())); } @@ -511,6 +811,10 @@ Exception::Exception(const Exception& other) noexcept file = ownFile.cStr(); } + if (other.remoteTrace != nullptr) { + remoteTrace = kj::str(other.remoteTrace); + } + memcpy(trace, other.trace, sizeof(trace[0]) * traceCount); KJ_IF_MAYBE(c, other.context) { @@ -531,8 +835,17 @@ void Exception::wrapContext(const char* file, int line, String&& description) { context = heap(file, line, mv(description), mv(context)); } -void Exception::extendTrace(uint ignoreCount) { - KJ_STACK_ARRAY(void*, newTraceSpace, kj::size(trace) + ignoreCount + 1, +void Exception::extendTrace(uint ignoreCount, uint limit) { + if (isFullTrace) { + // Awkward: extendTrace() was called twice without truncating in between. This should probably + // be an error, but historically we didn't check for this so I'm hesitant to make it an error + // now. We shouldn't actually extend the trace, though, as our current trace is presumably + // rooted in main() and it'd be weird to append frames "above" that. + // TODO(cleanup): Abort here and see what breaks? + return; + } + + KJ_STACK_ARRAY(void*, newTraceSpace, kj::min(kj::size(trace), limit) + ignoreCount + 1, sizeof(trace)/sizeof(trace[0]) + 8, 128); auto newTrace = kj::getStackTrace(newTraceSpace, ignoreCount + 1); @@ -543,10 +856,26 @@ void Exception::extendTrace(uint ignoreCount) { // Copy the rest into our trace. memcpy(trace + traceCount, newTrace.begin(), newTrace.asBytes().size()); traceCount += newTrace.size(); + isFullTrace = true; } } void Exception::truncateCommonTrace() { + if (isFullTrace) { + // We're truncating the common portion of the full trace, turning it back into a limited + // trace. + isFullTrace = false; + } else { + // If the trace was never extended in the first place, trying to truncate it is at best a waste + // of time and at worst might remove information for no reason. So, don't. + // + // This comes up in particular in coroutines, when the exception originated from a co_awaited + // promise. In that case we manually add the one relevant frame to the trace, rather than + // call extendTrace() just to have to truncate most of it again a moment later in the + // unhandled_exception() callback. + return; + } + if (traceCount > 0) { // Create a "reference" stack trace that is a little bit deeper than the one in the exception. void* refTraceSpace[sizeof(this->trace) / sizeof(this->trace[0]) + 4]; @@ -584,22 +913,67 @@ void Exception::truncateCommonTrace() { } void Exception::addTrace(void* ptr) { + // TODO(cleanup): Abort here if isFullTrace is true, and see what breaks. This method only makes + // sense to call on partial traces. + if (traceCount < kj::size(trace)) { trace[traceCount++] = ptr; } } +void Exception::addTraceHere() { +#if __GNUC__ + addTrace(__builtin_return_address(0)); +#elif _MSC_VER + addTrace(_ReturnAddress()); +#else + #error "please implement for your compiler" +#endif +} + +#if !KJ_NO_EXCEPTIONS + +namespace { + +KJ_THREADLOCAL_PTR(ExceptionImpl) currentException = nullptr; + +} // namespace + class ExceptionImpl: public Exception, public std::exception { public: - inline ExceptionImpl(Exception&& other): Exception(mv(other)) {} + inline ExceptionImpl(Exception&& other): Exception(mv(other)) { + insertIntoCurrentExceptions(); + } ExceptionImpl(const ExceptionImpl& other): Exception(other) { // No need to copy whatBuffer since it's just to hold the return value of what(). + insertIntoCurrentExceptions(); + } + ~ExceptionImpl() { + // Look for ourselves in the list. + for (auto* ptr = ¤tException; *ptr != nullptr; ptr = &(*ptr)->nextCurrentException) { + if (*ptr == this) { + *ptr = nextCurrentException; + return; + } + } + + // Possibly the ExceptionImpl was destroyed on a different thread than created it? That's + // pretty bad, we'd better abort. + abort(); } const char* what() const noexcept override; private: mutable String whatBuffer; + ExceptionImpl* nextCurrentException = nullptr; + + void insertIntoCurrentExceptions() { + nextCurrentException = currentException; + currentException = this; + } + + friend class InFlightExceptionIterator; }; const char* ExceptionImpl::what() const noexcept { @@ -607,6 +981,45 @@ const char* ExceptionImpl::what() const noexcept { return whatBuffer.begin(); } +InFlightExceptionIterator::InFlightExceptionIterator() + : ptr(currentException) {} + +Maybe InFlightExceptionIterator::next() { + if (ptr == nullptr) return nullptr; + + const ExceptionImpl& result = *static_cast(ptr); + ptr = result.nextCurrentException; + return result; +} + +#endif // !KJ_NO_EXCEPTIONS + +kj::Exception getDestructionReason(void* traceSeparator, kj::Exception::Type defaultType, + const char* defaultFile, int defaultLine, kj::StringPtr defaultDescription) { +#if !KJ_NO_EXCEPTIONS + InFlightExceptionIterator iter; + KJ_IF_MAYBE(e, iter.next()) { + auto copy = kj::cp(*e); + copy.truncateCommonTrace(); + return copy; + } else { +#endif + // Darn, use a generic exception. + kj::Exception exception(defaultType, defaultFile, defaultLine, + kj::heapString(defaultDescription)); + + // Let's give some context on where the PromiseFulfiller was destroyed. + exception.extendTrace(2, 16); + + // Add a separator that hopefully makes this understandable... + exception.addTrace(traceSeparator); + + return exception; +#if !KJ_NO_EXCEPTIONS + } +#endif +} + // ======================================================================================= namespace { @@ -615,12 +1028,21 @@ KJ_THREADLOCAL_PTR(ExceptionCallback) threadLocalCallback = nullptr; } // namespace -ExceptionCallback::ExceptionCallback(): next(getExceptionCallback()) { +void requireOnStack(void* ptr, kj::StringPtr description) { +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) || \ + KJ_HAS_COMPILER_FEATURE(address_sanitizer) || \ + defined(__SANITIZE_ADDRESS__) + // When using libfuzzer or ASAN, this sanity check may spurriously fail, so skip it. +#else char stackVar; - ptrdiff_t offset = reinterpret_cast(this) - &stackVar; - KJ_ASSERT(offset < 65536 && offset > -65536, - "ExceptionCallback must be allocated on the stack."); + ptrdiff_t offset = reinterpret_cast(ptr) - &stackVar; + KJ_REQUIRE(offset < 65536 && offset > -65536, + kj::str(description)); +#endif +} +ExceptionCallback::ExceptionCallback(): next(getExceptionCallback()) { + requireOnStack(this, "ExceptionCallback must be allocated on the stack."); threadLocalCallback = this; } @@ -649,6 +1071,14 @@ ExceptionCallback::StackTraceMode ExceptionCallback::stackTraceMode() { return next.stackTraceMode(); } +Function)> ExceptionCallback::getThreadInitializer() { + return next.getThreadInitializer(); +} + +namespace _ { // private + uint uncaughtExceptionCount(); // defined later in this file +} + class ExceptionCallback::RootExceptionCallback: public ExceptionCallback { public: RootExceptionCallback(): ExceptionCallback(*this) {} @@ -657,7 +1087,7 @@ public: #if KJ_NO_EXCEPTIONS logException(LogSeverity::ERROR, mv(exception)); #else - if (std::uncaught_exception()) { + if (_::uncaughtExceptionCount() > 0) { // Bad time to throw an exception. Just log instead. // // TODO(someday): We should really compare uncaughtExceptionCount() against the count at @@ -685,7 +1115,7 @@ public: StringPtr textPtr = text; - while (text != nullptr) { + while (textPtr != nullptr) { miniposix::ssize_t n = miniposix::write(STDERR_FILENO, textPtr.begin(), textPtr.size()); if (n <= 0) { // stderr is broken. Give up. @@ -703,6 +1133,14 @@ public: #endif } + Function)> getThreadInitializer() override { + return [](Function func) { + // No initialization needed since RootExceptionCallback is automatically the root callback + // for new threads. + func(); + }; + } + private: void logException(LogSeverity severity, Exception&& e) { // We intentionally go back to the top exception callback on the stack because we don't want to @@ -712,25 +1150,53 @@ private: // anyway. getExceptionCallback().logMessage(severity, e.getFile(), e.getLine(), 0, str( e.getType(), e.getDescription() == nullptr ? "" : ": ", e.getDescription(), - e.getStackTrace().size() > 0 ? "\nstack: " : "", strArray(e.getStackTrace(), " "), + e.getRemoteTrace().size() > 0 ? "\nremote: " : "", + e.getRemoteTrace(), + e.getStackTrace().size() > 0 ? "\nstack: " : "", + stringifyStackTraceAddresses(e.getStackTrace()), stringifyStackTrace(e.getStackTrace()), "\n")); } }; ExceptionCallback& getExceptionCallback() { - static ExceptionCallback::RootExceptionCallback defaultCallback; + static auto defaultCallback = lsanIgnoreObjectAndReturn( + new ExceptionCallback::RootExceptionCallback()); + // We allocate on the heap because some objects may throw in their destructors. If those objects + // had static storage, they might get fully constructed before the root callback. If they however + // then throw an exception during destruction, there would be a lifetime issue because their + // destructor would end up getting registered after the root callback's destructor. One solution + // is to just leak this pointer & allocate on first-use. The cost is that the initialization is + // mildly more expensive (+ we need to annotate sanitizers to ignore the problem). A great + // compiler annotation that would simply things would be one that allowed static variables to have + // their destruction omitted wholesale. That would allow us to avoid the heap but still have the + // same robust safety semantics leaking would give us. A practical alternative that could be + // implemented without new compilers would be to define another static root callback in + // RootExceptionCallback's destructor (+ a separate pointer to share its value with this + // function). Since this would end up getting constructed during exit unwind, it would have the + // nice property of effectively being guaranteed to be evicted last. + // + // All this being said, I came back to leaking the object is the easiest tweak here: + // * Can't go wrong + // * Easy to maintain + // * Throwing exceptions is bound to do be expensive and malloc-happy anyway, so the incremental + // cost of 1 heap allocation is minimal. + // + // TODO(cleanup): Harris has an excellent suggestion in + // https://github.com/capnproto/capnproto/pull/1255 that should ensure we initialize the root + // callback once on first use as a global & never destroy it. + ExceptionCallback* scoped = threadLocalCallback; - return scoped != nullptr ? *scoped : defaultCallback; + return scoped != nullptr ? *scoped : *defaultCallback; } void throwFatalException(kj::Exception&& exception, uint ignoreCount) { - exception.extendTrace(ignoreCount + 1); + if (ignoreCount != (uint)kj::maxValue) exception.extendTrace(ignoreCount + 1); getExceptionCallback().onFatalException(kj::mv(exception)); abort(); } void throwRecoverableException(kj::Exception&& exception, uint ignoreCount) { - exception.extendTrace(ignoreCount + 1); + if (ignoreCount != (uint)kj::maxValue) exception.extendTrace(ignoreCount + 1); getExceptionCallback().onRecoverableException(kj::mv(exception)); } @@ -820,12 +1286,83 @@ bool UnwindDetector::isUnwinding() const { return _::uncaughtExceptionCount() > uncaughtCount; } -void UnwindDetector::catchExceptionsAsSecondaryFaults(_::Runnable& runnable) const { +#if !KJ_NO_EXCEPTIONS +void UnwindDetector::catchThrownExceptionAsSecondaryFault() const { // TODO(someday): Attach the secondary exception to whatever primary exception is causing // the unwind. For now we just drop it on the floor as this is probably fine most of the // time. - runCatchingExceptions(runnable); + getCaughtExceptionAsKj(); +} +#endif + +#if __GNUC__ && !KJ_NO_RTTI +static kj::String demangleTypeName(const char* name) { + if (name == nullptr) return kj::heapString("(nil)"); + + int status; + char* buf = abi::__cxa_demangle(name, nullptr, nullptr, &status); + kj::String result = kj::heapString(buf == nullptr ? name : buf); + free(buf); + return kj::mv(result); +} + +kj::String getCaughtExceptionType() { + return demangleTypeName(abi::__cxa_current_exception_type()->name()); } +#else +kj::String getCaughtExceptionType() { + return kj::heapString("(unknown)"); +} +#endif + +namespace { + +size_t sharedSuffixLength(kj::ArrayPtr a, kj::ArrayPtr b) { + size_t result = 0; + while (a.size() > 0 && b.size() > 0 && a.back() == b.back()) { + ++result; + a = a.slice(0, a.size() - 1); + b = b.slice(0, b.size() - 1); + } + return result; +} + +} // namespace + +kj::ArrayPtr computeRelativeTrace( + kj::ArrayPtr trace, kj::ArrayPtr relativeTo) { + using miniposix::ssize_t; + + static constexpr size_t MIN_MATCH_LEN = 4; + if (trace.size() < MIN_MATCH_LEN || relativeTo.size() < MIN_MATCH_LEN) { + return trace; + } + + kj::ArrayPtr bestMatch = trace; + uint bestMatchLen = MIN_MATCH_LEN - 1; // must beat this to choose something else + + // `trace` and `relativeTrace` may have been truncated at different points. We iterate through + // truncating various suffixes from one of the two and then seeing if the remaining suffixes + // match. + for (ssize_t i = -(ssize_t)(trace.size() - MIN_MATCH_LEN); + i <= (ssize_t)(relativeTo.size() - MIN_MATCH_LEN); + i++) { + // Negative values truncate `trace`, positive values truncate `relativeTo`. + kj::ArrayPtr subtrace = trace.slice(0, trace.size() - kj::max(0, -i)); + kj::ArrayPtr subrt = relativeTo + .slice(0, relativeTo.size() - kj::max(0, i)); + + uint matchLen = sharedSuffixLength(subtrace, subrt); + if (matchLen > bestMatchLen) { + bestMatchLen = matchLen; + bestMatch = subtrace.slice(0, subtrace.size() - matchLen + 1); + } + } + + return bestMatch; +} + +#if KJ_NO_EXCEPTIONS namespace _ { // private @@ -847,34 +1384,44 @@ public: Maybe caught; }; -Maybe runCatchingExceptions(Runnable& runnable) noexcept { -#if KJ_NO_EXCEPTIONS +Maybe runCatchingExceptions(Runnable& runnable) { RecoverableExceptionCatcher catcher; runnable.run(); KJ_IF_MAYBE(e, catcher.caught) { e->truncateCommonTrace(); } return mv(catcher.caught); -#else +} + +} // namespace _ (private) + +#else // KJ_NO_EXCEPTIONS + +kj::Exception getCaughtExceptionAsKj() { try { - runnable.run(); - return nullptr; + throw; } catch (Exception& e) { e.truncateCommonTrace(); return kj::mv(e); + } catch (CanceledException) { + throw; } catch (std::bad_alloc& e) { return Exception(Exception::Type::OVERLOADED, "(unknown)", -1, str("std::bad_alloc: ", e.what())); } catch (std::exception& e) { return Exception(Exception::Type::FAILED, "(unknown)", -1, str("std::exception: ", e.what())); + } catch (TopLevelProcessContext::CleanShutdownException) { + throw; } catch (...) { - return Exception(Exception::Type::FAILED, - "(unknown)", -1, str("Unknown non-KJ exception.")); - } +#if __GNUC__ && !KJ_NO_RTTI + return Exception(Exception::Type::FAILED, "(unknown)", -1, str( + "unknown non-KJ exception of type: ", getCaughtExceptionType())); +#else + return Exception(Exception::Type::FAILED, "(unknown)", -1, str("unknown non-KJ exception")); #endif + } } - -} // namespace _ (private) +#endif // !KJ_NO_EXCEPTIONS } // namespace kj diff --git a/c++/src/kj/exception.h b/c++/src/kj/exception.h index f6c0b2daa6..be90163f93 100644 --- a/c++/src/kj/exception.h +++ b/c++/src/kj/exception.h @@ -19,20 +19,19 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_EXCEPTION_H_ -#define KJ_EXCEPTION_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "memory.h" #include "array.h" #include "string.h" +#include "windows-sanity.h" // work-around macro conflict with `ERROR` + +KJ_BEGIN_HEADER namespace kj { class ExceptionImpl; +template class Function; class Exception { // Exception thrown in case of fatal errors. @@ -81,6 +80,15 @@ class Exception { StringPtr getDescription() const { return description; } ArrayPtr getStackTrace() const { return arrayPtr(trace, traceCount); } + void setDescription(kj::String&& desc) { description = kj::mv(desc); } + + StringPtr getRemoteTrace() const { return remoteTrace; } + void setRemoteTrace(kj::String&& value) { remoteTrace = kj::mv(value); } + // Additional stack trace data originating from a remote server. If present, then + // `getStackTrace()` only traces up until entry into the RPC system, and the remote trace + // contains any trace information returned over the wire. This string is human-readable but the + // format is otherwise unspecified. + struct Context { // Describes a bit about what was going on when the exception was thrown. @@ -107,9 +115,11 @@ class Exception { // is expected that contexts will be added in reverse order as the exception passes up the // callback stack. - KJ_NOINLINE void extendTrace(uint ignoreCount); + KJ_NOINLINE void extendTrace(uint ignoreCount, uint limit = kj::maxValue); // Append the current stack trace to the exception's trace, ignoring the first `ignoreCount` // frames (see `getStackTrace()` for discussion of `ignoreCount`). + // + // If `limit` is set, limit the number of frames added to the given number. KJ_NOINLINE void truncateCommonTrace(); // Remove the part of the stack trace which the exception shares with the caller of this method. @@ -120,6 +130,9 @@ class Exception { // Append the given pointer to the backtrace, if it is not already full. This is used by the // async library to trace through the promise chain that led to the exception. + KJ_NOINLINE void addTraceHere(); + // Adds the location that called this method to the stack trace. + private: String ownFile; const char* file; @@ -127,12 +140,32 @@ class Exception { Type type; String description; Maybe> context; + String remoteTrace; void* trace[32]; uint traceCount; + bool isFullTrace = false; + // Is `trace` a full trace to the top of the stack (or as close as we could get before we ran + // out of space)? If this is false, then `trace` is instead a partial trace covering just the + // frames between where the exception was thrown and where it was caught. + // + // extendTrace() transitions this to true, and truncateCommonTrace() changes it back to false. + // + // In theory, an exception should only hold a full trace when it is in the process of being + // thrown via the C++ exception handling mechanism -- extendTrace() is called before the throw + // and truncateCommonTrace() after it is caught. Note that when exceptions propagate through + // async promises, the trace is extended one frame at a time instead, so isFullTrace should + // remain false. + friend class ExceptionImpl; }; +struct CanceledException { }; +// This exception is thrown to force-unwind a stack in order to immediately cancel whatever that +// stack was doing. It is used in the implementation of fibers in particular. Application code +// should almost never catch this exception, unless you need to modify stack unwinding for some +// reason. kj::runCatchingExceptions() does not catch it. + StringPtr KJ_STRINGIFY(Exception::Type type); String KJ_STRINGIFY(const Exception& e); @@ -167,7 +200,7 @@ class ExceptionCallback { public: ExceptionCallback(); - KJ_DISALLOW_COPY(ExceptionCallback); + KJ_DISALLOW_COPY_AND_MOVE(ExceptionCallback); virtual ~ExceptionCallback() noexcept(false); virtual void onRecoverableException(Exception&& exception); @@ -216,6 +249,11 @@ class ExceptionCallback { virtual StackTraceMode stackTraceMode(); // Returns the current preferred stack trace mode. + virtual Function)> getThreadInitializer(); + // Called just before a new thread is spawned using kj::Thread. Returns a function which should + // be invoked inside the new thread to initialize the thread's ExceptionCallback. The initializer + // function itself receives, as its parameter, the thread's main function, which it must call. + protected: ExceptionCallback& next; @@ -224,6 +262,8 @@ class ExceptionCallback { class RootExceptionCallback; friend ExceptionCallback& getExceptionCallback(); + + friend class Thread; }; ExceptionCallback& getExceptionCallback(); @@ -242,7 +282,7 @@ KJ_NOINLINE void throwRecoverableException(kj::Exception&& exception, uint ignor namespace _ { class Runnable; } template -Maybe runCatchingExceptions(Func&& func) noexcept; +Maybe runCatchingExceptions(Func&& func); // Executes the given function (usually, a lambda returning nothing) catching any exceptions that // are thrown. Returns the Exception if there was one, or null if the operation completed normally. // Non-KJ exceptions will be wrapped. @@ -250,6 +290,20 @@ Maybe runCatchingExceptions(Func&& func) noexcept; // If exception are disabled (e.g. with -fno-exceptions), this will still detect whether any // recoverable exceptions occurred while running the function and will return those. +#if !KJ_NO_EXCEPTIONS + +kj::Exception getCaughtExceptionAsKj(); +// Call from the catch block of a try/catch to get a `kj::Exception` representing the exception +// that was caught, the same way that `kj::runCatchingExceptions` would when catching an exception. +// This is sometimes useful if `runCatchingExceptions()` doesn't quite fit your use case. You can +// call this from any catch block, including `catch (...)`. +// +// Some exception types will actually be rethrown by this function, rather than returned. The most +// common example is `CanceledException`, whose purpose is to unwind the stack and is not meant to +// be caught. + +#endif // !KJ_NO_EXCEPTIONS + class UnwindDetector { // Utility for detecting when a destructor is called due to unwind. Useful for: // - Avoiding throwing exceptions in this case, which would terminate the program. @@ -276,9 +330,13 @@ class UnwindDetector { private: uint uncaughtCount; - void catchExceptionsAsSecondaryFaults(_::Runnable& runnable) const; +#if !KJ_NO_EXCEPTIONS + void catchThrownExceptionAsSecondaryFault() const; +#endif }; +#if KJ_NO_EXCEPTIONS + namespace _ { // private class Runnable { @@ -289,7 +347,7 @@ class Runnable { template class RunnableImpl: public Runnable { public: - RunnableImpl(Func&& func): func(kj::mv(func)) {} + RunnableImpl(Func&& func): func(kj::fwd(func)) {} void run() override { func(); } @@ -297,24 +355,43 @@ class RunnableImpl: public Runnable { Func func; }; -Maybe runCatchingExceptions(Runnable& runnable) noexcept; +Maybe runCatchingExceptions(Runnable& runnable); } // namespace _ (private) +#endif // KJ_NO_EXCEPTIONS + template -Maybe runCatchingExceptions(Func&& func) noexcept { - _::RunnableImpl> runnable(kj::fwd(func)); +Maybe runCatchingExceptions(Func&& func) { +#if KJ_NO_EXCEPTIONS + _::RunnableImpl runnable(kj::fwd(func)); return _::runCatchingExceptions(runnable); +#else + try { + func(); + return nullptr; + } catch (...) { + return getCaughtExceptionAsKj(); + } +#endif } template void UnwindDetector::catchExceptionsIfUnwinding(Func&& func) const { +#if KJ_NO_EXCEPTIONS + // Can't possibly be unwinding... + func(); +#else if (isUnwinding()) { - _::RunnableImpl> runnable(kj::fwd(func)); - catchExceptionsAsSecondaryFaults(runnable); + try { + func(); + } catch (...) { + catchThrownExceptionAsSecondaryFault(); + } } else { func(); } +#endif } #define KJ_ON_SCOPE_SUCCESS(code) \ @@ -346,6 +423,15 @@ String stringifyStackTrace(ArrayPtr); // Convert the stack trace to a string with file names and line numbers. This may involve executing // suprocesses. +String stringifyStackTraceAddresses(ArrayPtr trace); +StringPtr stringifyStackTraceAddresses(ArrayPtr trace, ArrayPtr scratch); +// Construct a string containing just enough information about a stack trace to be able to convert +// it to file and line numbers later using offline tools. This produces a sequence of +// space-separated code location identifiers. Each identifier may be an absolute address +// (hex number starting with 0x) or may be a module-relative address "@0x". The +// latter case is preferred when ASLR is in effect and has loaded different modules at different +// addresses. + String getStackTrace(); // Get a stack trace right now and stringify it. Useful for debugging. @@ -354,10 +440,65 @@ void printStackTraceOnCrash(); // a stack trace. You should call this as early as possible on program startup. Programs using // KJ_MAIN get this automatically. +void resetCrashHandlers(); +// Resets all signal handlers set by printStackTraceOnCrash(). + kj::StringPtr trimSourceFilename(kj::StringPtr filename); // Given a source code file name, trim off noisy prefixes like "src/" or // "/ekam-provider/canonical/". +kj::String getCaughtExceptionType(); +// Utility function which attempts to return the human-readable type name of the exception +// currently being thrown. This can be called inside a catch block, including a catch (...) block, +// for the purpose of error logging. This function is best-effort; on some platforms it may simply +// return "(unknown)". + +#if !KJ_NO_EXCEPTIONS + +class InFlightExceptionIterator { + // A class that can be used to iterate over exceptions that are in-flight in the current thread, + // meaning they are either uncaught, or caught by a catch block that is current executing. + // + // This is meant for debugging purposes, and the results are best-effort. The C++ standard + // library does not provide any way to inspect uncaught exceptions, so this class can only + // discover KJ exceptions thrown using throwFatalException() or throwRecoverableException(). + // All KJ code uses those two functions to throw exceptions, but if your own code uses a bare + // `throw`, or if the standard library throws an exception, these cannot be inspected. + // + // This class is safe to use in a signal handler. + +public: + InFlightExceptionIterator(); + + Maybe next(); + +private: + const Exception* ptr; +}; + +#endif // !KJ_NO_EXCEPTIONS + +kj::Exception getDestructionReason(void* traceSeparator, + kj::Exception::Type defaultType, const char* defaultFile, int defaultLine, + kj::StringPtr defaultDescription); +// Returns an exception that attempts to capture why a destructor has been invoked. If a KJ +// exception is currently in-flight (see InFlightExceptionIterator), then that exception is +// returned. Otherwise, an exception is constructed using the current stack trace and the type, +// file, line, and description provided. In the latter case, `traceSeparator` is appended to the +// stack trace; this should be a pointer to some dummy symbol which acts as a separator between the +// original stack trace and any new trace frames added later. + +kj::ArrayPtr computeRelativeTrace( + kj::ArrayPtr trace, kj::ArrayPtr relativeTo); +// Given two traces expected to have started from the same root, try to find the part of `trace` +// that is different from `relativeTo`, considering that either or both traces might be truncated. +// +// This is useful for debugging, when reporting several related traces at once. + +void requireOnStack(void* ptr, kj::StringPtr description); +// Throw an exception if `ptr` does not appear to point to something near the top of the stack. +// Used as a safety check for types that must be stack-allocated, like ExceptionCallback. + } // namespace kj -#endif // KJ_EXCEPTION_H_ +KJ_END_HEADER diff --git a/c++/src/kj/filesystem-disk-generic-test.c++ b/c++/src/kj/filesystem-disk-generic-test.c++ new file mode 100644 index 0000000000..08c54b0e34 --- /dev/null +++ b/c++/src/kj/filesystem-disk-generic-test.c++ @@ -0,0 +1,69 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if __linux__ + +// This test compiles filesystem-disk-unix.c++ with various features #undefed, causing it to take +// different code paths, then runs filesystem-disk-test.c++ against that. +// +// This test is only intended to run on Linux, but is intended to make the code behave like it +// would on a generic flavor of Unix. +// +// At present this test only runs under Ekam builds. Integrating it into other builds would be +// awkward since it #includes filesystem-disk-unix.c++, so it cannot link against that file, but +// needs to link against the rest of KJ. Ekam "just figures it out", but other build systems would +// require a lot of work here. + +#include "filesystem.h" +#include "debug.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "vector.h" +#include "miniposix.h" + +#undef __linux__ +#undef O_CLOEXEC +#undef O_DIRECTORY +#undef O_TMPFILE +#undef FIOCLEX +#undef DT_UNKNOWN +#undef F_DUPFD_CLOEXEC +#undef FALLOC_FL_PUNCH_HOLE +#undef FICLONE +#undef FICLONERANGE +#undef SEEK_HOLE +#undef SEEK_DATA +#undef RENAME_EXCHANGE + +#define HOLES_NOT_SUPPORTED + +#include "filesystem-disk-unix.c++" +#include "filesystem-disk-test.c++" + +#endif // __linux__ diff --git a/c++/src/kj/filesystem-disk-old-kernel-test.c++ b/c++/src/kj/filesystem-disk-old-kernel-test.c++ new file mode 100644 index 0000000000..88d24e74fd --- /dev/null +++ b/c++/src/kj/filesystem-disk-old-kernel-test.c++ @@ -0,0 +1,133 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if __linux__ && __x86_64__ && defined(__has_include) +#if __has_include() && \ + __has_include() && \ + __has_include() && \ + __has_include() && \ + __has_include() +// This test re-runs filesystem-disk-test.c++ with newfangled Linux kernel features disabled. +// +// This test must be compiled as a separate program, since it alters the calling process by +// enabling seccomp to disable the kernel features. +// +// At present this test only runs under Ekam builds. It *could* reasonably easily be added to the +// autotools or cmake builds, but would require compiling a separate test binary, which is a bit +// weird, and may lead to spurious error reports on systems that don't support seccomp for whatever +// reason. + +#include +#include +#include +#include +#include +#include +#include + +#ifdef SECCOMP_SET_MODE_FILTER + +namespace { + +#if 0 +// Source code of the seccomp filter: + + ld [0] /* offsetof(struct seccomp_data, nr) */ + jeq #8, lseek /* __NR_lseek */ + jeq #16, inval /* __NR_ioctl */ + jeq #40, nosys /* __NR_sendfile */ + jeq #257, openat /* __NR_openat */ + jeq #285, notsup /* __NR_fallocate */ + jeq #316, nosys /* __NR_renameat2 */ + jmp good + +openat: + ld [32] /* offsetof(struct seccomp_data, args[2]), aka flags */ + and #4259840 /* O_TMPFILE */ + jeq #4259840, notsup + jmp good + +lseek: + ld [32] /* offsetof(struct seccomp_data, args[2]), aka whence */ + jeq #3, inval /* SEEK_DATA */ + jeq #4, inval /* SEEK_HOLE */ + jmp good + +inval: ret #0x00050016 /* SECCOMP_RET_ERRNO | EINVAL */ +nosys: ret #0x00050026 /* SECCOMP_RET_ERRNO | ENOSYS */ +notsup: ret #0x0005005f /* SECCOMP_RET_ERRNO | EOPNOTSUPP */ +good: ret #0x7fff0000 /* SECCOMP_RET_ALLOW */ + +#endif + +struct SetupSeccompForFilesystemTest { + SetupSeccompForFilesystemTest() { + struct sock_filter filter[] { + { 0x20, 0, 0, 0000000000 }, + { 0x15, 10, 0, 0x00000008 }, + { 0x15, 13, 0, 0x00000010 }, + { 0x15, 13, 0, 0x00000028 }, + { 0x15, 3, 0, 0x00000101 }, + { 0x15, 12, 0, 0x0000011d }, + { 0x15, 10, 0, 0x0000013c }, + { 0x05, 0, 0, 0x0000000b }, + { 0x20, 0, 0, 0x00000020 }, + { 0x54, 0, 0, 0x00410000 }, + { 0x15, 7, 0, 0x00410000 }, + { 0x05, 0, 0, 0x00000007 }, + { 0x20, 0, 0, 0x00000020 }, + { 0x15, 2, 0, 0x00000003 }, + { 0x15, 1, 0, 0x00000004 }, + { 0x05, 0, 0, 0x00000003 }, + { 0x06, 0, 0, 0x00050016 }, + { 0x06, 0, 0, 0x00050026 }, + { 0x06, 0, 0, 0x0005005f }, + { 0x06, 0, 0, 0x7fff0000 }, + }; + + struct sock_fprog prog { sizeof(filter) / sizeof(filter[0]), filter }; + + KJ_SYSCALL(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0)); + KJ_SYSCALL(syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, 0, &prog)); + } +}; + +SetupSeccompForFilesystemTest setupSeccompForFilesystemTest; + +} // namespace + +#define HOLES_NOT_SUPPORTED + +// OK, now run all the regular filesystem tests! +#include "filesystem-disk-test.c++" + +#endif +#endif +#endif + +#if __linux__ && !__x86_64__ +// HACK: We may be cross-compiling. Ekam's cross-compiling is currently hacky -- if a test is a +// test on the host platform then it needs to be a test on all other targets, too. So add a dummy +// test here. +// TODO(cleanup): Make Ekam cross-compiling better. +#include +KJ_TEST("old kernel test -- not supported on this architecture") {} +#endif diff --git a/c++/src/kj/filesystem-disk-test.c++ b/c++/src/kj/filesystem-disk-test.c++ new file mode 100644 index 0000000000..259964ad3c --- /dev/null +++ b/c++/src/kj/filesystem-disk-test.c++ @@ -0,0 +1,993 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "debug.h" +#include "filesystem.h" +#include "string.h" +#include "test.h" +#include "encoding.h" +#include +#include +#if _WIN32 +#include +#include "windows-sanity.h" +#else +#include +#include +#include +#include +#include +#include +#endif + +namespace kj { +namespace { + +bool isWine() KJ_UNUSED; + +#if _WIN32 + +bool detectWine() { + HMODULE hntdll = GetModuleHandle("ntdll.dll"); + if(hntdll == NULL) return false; + return GetProcAddress(hntdll, "wine_get_version") != nullptr; +} + +bool isWine() { + static bool result = detectWine(); + return result; +} + +template +static auto newTemp(Func&& create) + -> Decay())))> { + wchar_t wtmpdir[MAX_PATH + 1]; + DWORD len = GetTempPathW(kj::size(wtmpdir), wtmpdir); + KJ_ASSERT(len < kj::size(wtmpdir)); + auto tmpdir = decodeWideString(arrayPtr(wtmpdir, len)); + + static uint counter = 0; + for (;;) { + auto path = kj::str(tmpdir, "kj-filesystem-test.", GetCurrentProcessId(), ".", counter++); + KJ_IF_MAYBE(result, create(encodeWideString(path, true))) { + return kj::mv(*result); + } + } +} + +static Own newTempFile() { + return newTemp([](Array candidatePath) -> Maybe> { + HANDLE handle; + KJ_WIN32_HANDLE_ERRORS(handle = CreateFileW( + candidatePath.begin(), + FILE_GENERIC_READ | FILE_GENERIC_WRITE, + 0, + NULL, + CREATE_NEW, + FILE_ATTRIBUTE_TEMPORARY | FILE_FLAG_DELETE_ON_CLOSE, + NULL)) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + return nullptr; + default: + KJ_FAIL_WIN32("CreateFileW", error); + } + return newDiskFile(AutoCloseHandle(handle)); + }); +} + +static Array join16(ArrayPtr path, const wchar_t* file) { + // Assumes `path` ends with a NUL terminator (and `file` is of course NUL terminated as well). + + size_t len = wcslen(file) + 1; + auto result = kj::heapArray(path.size() + len); + memcpy(result.begin(), path.begin(), path.asBytes().size() - sizeof(wchar_t)); + result[path.size() - 1] = '\\'; + memcpy(result.begin() + path.size(), file, len * sizeof(wchar_t)); + return result; +} + +class TempDir { +public: + TempDir(): filename(newTemp([](Array candidatePath) -> Maybe> { + KJ_WIN32_HANDLE_ERRORS(CreateDirectoryW(candidatePath.begin(), NULL)) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + return nullptr; + default: + KJ_FAIL_WIN32("CreateDirectoryW", error); + } + return kj::mv(candidatePath); + })) {} + + Own get() { + HANDLE handle; + KJ_WIN32(handle = CreateFileW( + filename.begin(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, + NULL, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, // apparently, this flag is required for directories + NULL)); + return newDiskDirectory(AutoCloseHandle(handle)); + } + + ~TempDir() noexcept(false) { + recursiveDelete(filename); + } + +private: + Array filename; + + static void recursiveDelete(ArrayPtr path) { + // Recursively delete the temp dir, verifying that no .kj-tmp. files were left over. + // + // Mostly copied from rmrfChildren() in filesystem-win32.c++. + + auto glob = join16(path, L"\\*"); + + WIN32_FIND_DATAW data; + HANDLE handle = FindFirstFileW(glob.begin(), &data); + if (handle == INVALID_HANDLE_VALUE) { + auto error = GetLastError(); + if (error == ERROR_FILE_NOT_FOUND) return; + KJ_FAIL_WIN32("FindFirstFile", error, path) { return; } + } + KJ_DEFER(KJ_WIN32(FindClose(handle)) { break; }); + + do { + // Ignore "." and "..", ugh. + if (data.cFileName[0] == L'.') { + if (data.cFileName[1] == L'\0' || + (data.cFileName[1] == L'.' && data.cFileName[2] == L'\0')) { + continue; + } + } + + String utf8Name = decodeWideString(arrayPtr(data.cFileName, wcslen(data.cFileName))); + KJ_EXPECT(!utf8Name.startsWith(".kj-tmp."), "temp file not cleaned up", utf8Name); + + auto child = join16(path, data.cFileName); + if ((data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && + !(data.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT)) { + recursiveDelete(child); + } else { + KJ_WIN32(DeleteFileW(child.begin())); + } + } while (FindNextFileW(handle, &data)); + + auto error = GetLastError(); + if (error != ERROR_NO_MORE_FILES) { + KJ_FAIL_WIN32("FindNextFile", error, path) { return; } + } + + uint retryCount = 0; + retry: + KJ_WIN32_HANDLE_ERRORS(RemoveDirectoryW(path.begin())) { + case ERROR_DIR_NOT_EMPTY: + if (retryCount++ < 10) { + Sleep(10); + goto retry; + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_WIN32("RemoveDirectory", error) { break; } + } + } +}; + +#else + +bool isWine() { return false; } + +#if __APPLE__ || __CYGWIN__ +#define HOLES_NOT_SUPPORTED 1 +#endif + +#if __ANDROID__ +#define VAR_TMP "/data/local/tmp" +#else +#define VAR_TMP "/var/tmp" +#endif + +static Own newTempFile() { + const char* tmpDir = getenv("TEST_TMPDIR"); + auto filename = str(tmpDir != nullptr ? tmpDir : VAR_TMP, "/kj-filesystem-test.XXXXXX"); + int fd; + KJ_SYSCALL(fd = mkstemp(filename.begin())); + KJ_DEFER(KJ_SYSCALL(unlink(filename.cStr()))); + return newDiskFile(AutoCloseFd(fd)); +} + +class TempDir { +public: + TempDir() { + const char* tmpDir = getenv("TEST_TMPDIR"); + filename = str(tmpDir != nullptr ? tmpDir : VAR_TMP, "/kj-filesystem-test.XXXXXX"); + if (mkdtemp(filename.begin()) == nullptr) { + KJ_FAIL_SYSCALL("mkdtemp", errno, filename); + } + } + + Own get() { + int fd; + KJ_SYSCALL(fd = open(filename.cStr(), O_RDONLY)); + return newDiskDirectory(AutoCloseFd(fd)); + } + + ~TempDir() noexcept(false) { + recursiveDelete(filename); + } + +private: + String filename; + + static void recursiveDelete(StringPtr path) { + // Recursively delete the temp dir, verifying that no .kj-tmp. files were left over. + + { + DIR* dir = opendir(path.cStr()); + KJ_ASSERT(dir != nullptr); + KJ_DEFER(closedir(dir)); + + for (;;) { + auto entry = readdir(dir); + if (entry == nullptr) break; + + StringPtr name = entry->d_name; + if (name == "." || name == "..") continue; + + auto subPath = kj::str(path, '/', entry->d_name); + + KJ_EXPECT(!name.startsWith(".kj-tmp."), "temp file not cleaned up", subPath); + + struct stat stats; + KJ_SYSCALL(lstat(subPath.cStr(), &stats)); + + if (S_ISDIR(stats.st_mode)) { + recursiveDelete(subPath); + } else { + KJ_SYSCALL(unlink(subPath.cStr())); + } + } + } + + KJ_SYSCALL(rmdir(path.cStr())); + } +}; + +#endif // _WIN32, else + +KJ_TEST("DiskFile") { + auto file = newTempFile(); + + KJ_EXPECT(file->readAllText() == ""); + + // mmaping empty file should work + KJ_EXPECT(file->mmap(0, 0).size() == 0); + KJ_EXPECT(file->mmapPrivate(0, 0).size() == 0); + KJ_EXPECT(file->mmapWritable(0, 0)->get().size() == 0); + + file->writeAll("foo"); + KJ_EXPECT(file->readAllText() == "foo"); + + file->write(3, StringPtr("bar").asBytes()); + KJ_EXPECT(file->readAllText() == "foobar"); + + file->write(3, StringPtr("baz").asBytes()); + KJ_EXPECT(file->readAllText() == "foobaz"); + + file->write(9, StringPtr("qux").asBytes()); + KJ_EXPECT(file->readAllText() == kj::StringPtr("foobaz\0\0\0qux", 12)); + + file->truncate(6); + KJ_EXPECT(file->readAllText() == "foobaz"); + + file->truncate(18); + KJ_EXPECT(file->readAllText() == kj::StringPtr("foobaz\0\0\0\0\0\0\0\0\0\0\0\0", 18)); + + // empty mappings work, even if useless + KJ_EXPECT(file->mmap(0, 0).size() == 0); + KJ_EXPECT(file->mmapPrivate(0, 0).size() == 0); + KJ_EXPECT(file->mmapWritable(0, 0)->get().size() == 0); + KJ_EXPECT(file->mmap(2, 0).size() == 0); + KJ_EXPECT(file->mmapPrivate(2, 0).size() == 0); + KJ_EXPECT(file->mmapWritable(2, 0)->get().size() == 0); + + { + auto mapping = file->mmap(0, 18); + auto privateMapping = file->mmapPrivate(0, 18); + auto writableMapping = file->mmapWritable(0, 18); + + KJ_EXPECT(mapping.size() == 18); + KJ_EXPECT(privateMapping.size() == 18); + KJ_EXPECT(writableMapping->get().size() == 18); + + KJ_EXPECT(writableMapping->get().begin() != mapping.begin()); + KJ_EXPECT(privateMapping.begin() != mapping.begin()); + KJ_EXPECT(writableMapping->get().begin() != privateMapping.begin()); + + KJ_EXPECT(kj::str(mapping.slice(0, 6).asChars()) == "foobaz"); + KJ_EXPECT(kj::str(writableMapping->get().slice(0, 6).asChars()) == "foobaz"); + KJ_EXPECT(kj::str(privateMapping.slice(0, 6).asChars()) == "foobaz"); + + privateMapping[0] = 'F'; + KJ_EXPECT(kj::str(mapping.slice(0, 6).asChars()) == "foobaz"); + KJ_EXPECT(kj::str(writableMapping->get().slice(0, 6).asChars()) == "foobaz"); + KJ_EXPECT(kj::str(privateMapping.slice(0, 6).asChars()) == "Foobaz"); + + writableMapping->get()[1] = 'D'; + writableMapping->changed(writableMapping->get().slice(1, 2)); + KJ_EXPECT(kj::str(mapping.slice(0, 6).asChars()) == "fDobaz"); + KJ_EXPECT(kj::str(writableMapping->get().slice(0, 6).asChars()) == "fDobaz"); + KJ_EXPECT(kj::str(privateMapping.slice(0, 6).asChars()) == "Foobaz"); + + file->write(0, StringPtr("qux").asBytes()); + KJ_EXPECT(kj::str(mapping.slice(0, 6).asChars()) == "quxbaz"); + KJ_EXPECT(kj::str(writableMapping->get().slice(0, 6).asChars()) == "quxbaz"); + KJ_EXPECT(kj::str(privateMapping.slice(0, 6).asChars()) == "Foobaz"); + + file->write(12, StringPtr("corge").asBytes()); + KJ_EXPECT(kj::str(mapping.slice(12, 17).asChars()) == "corge"); + +#if !_WIN32 && !__CYGWIN__ // Windows doesn't allow the file size to change while mapped. + // Can shrink. + file->truncate(6); + KJ_EXPECT(kj::str(mapping.slice(12, 17).asChars()) == kj::StringPtr("\0\0\0\0\0", 5)); + + // Can regrow. + file->truncate(18); + KJ_EXPECT(kj::str(mapping.slice(12, 17).asChars()) == kj::StringPtr("\0\0\0\0\0", 5)); + + // Can even regrow past previous capacity. + file->truncate(100); +#endif + } + + file->truncate(6); + + KJ_EXPECT(file->readAllText() == "quxbaz"); + file->zero(3, 3); + KJ_EXPECT(file->readAllText() == StringPtr("qux\0\0\0", 6)); +} + +KJ_TEST("DiskFile::copy()") { + auto source = newTempFile(); + source->writeAll("foobarbaz"); + + auto dest = newTempFile(); + dest->writeAll("quxcorge"); + + KJ_EXPECT(dest->copy(3, *source, 6, kj::maxValue) == 3); + KJ_EXPECT(dest->readAllText() == "quxbazge"); + + KJ_EXPECT(dest->copy(0, *source, 3, 4) == 4); + KJ_EXPECT(dest->readAllText() == "barbazge"); + + KJ_EXPECT(dest->copy(0, *source, 128, kj::maxValue) == 0); + + KJ_EXPECT(dest->copy(4, *source, 3, 0) == 0); + + String bigString = strArray(repeat("foobar", 10000), ""); + source->truncate(bigString.size() + 1000); + source->write(123, bigString.asBytes()); + + dest->copy(321, *source, 123, bigString.size()); + KJ_EXPECT(dest->readAllText().slice(321) == bigString); +} + +KJ_TEST("DiskDirectory") { + TempDir tempDir; + auto dir = tempDir.get(); + + KJ_EXPECT(dir->listNames() == nullptr); + KJ_EXPECT(dir->listEntries() == nullptr); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::MODIFY) == nullptr); + + { + auto file = dir->openFile(Path("foo"), WriteMode::CREATE); + file->writeAll("foobar"); + } + + KJ_EXPECT(dir->exists(Path("foo"))); + + { + auto stats = dir->lstat(Path("foo")); + KJ_EXPECT(stats.type == FsNode::Type::FILE); + KJ_EXPECT(stats.size == 6); + } + + { + auto list = dir->listNames(); + KJ_ASSERT(list.size() == 1); + KJ_EXPECT(list[0] == "foo"); + } + + { + auto list = dir->listEntries(); + KJ_ASSERT(list.size() == 1); + KJ_EXPECT(list[0].name == "foo"); + KJ_EXPECT(list[0].type == FsNode::Type::FILE); + } + + KJ_EXPECT(dir->openFile(Path("foo"))->readAllText() == "foobar"); + + KJ_EXPECT(dir->tryOpenFile(Path({"foo", "bar"}), WriteMode::MODIFY) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path({"bar", "baz"}), WriteMode::MODIFY) == nullptr); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("parent is not a directory", + dir->tryOpenFile(Path({"bar", "baz"}), WriteMode::CREATE)); + + { + auto file = dir->openFile(Path({"bar", "baz"}), WriteMode::CREATE | WriteMode::CREATE_PARENT); + file->writeAll("bazqux"); + } + + KJ_EXPECT(dir->openFile(Path({"bar", "baz"}))->readAllText() == "bazqux"); + + { + auto stats = dir->lstat(Path("bar")); + KJ_EXPECT(stats.type == FsNode::Type::DIRECTORY); + } + + { + auto list = dir->listNames(); + KJ_ASSERT(list.size() == 2); + KJ_EXPECT(list[0] == "bar"); + KJ_EXPECT(list[1] == "foo"); + } + + { + auto list = dir->listEntries(); + KJ_ASSERT(list.size() == 2); + KJ_EXPECT(list[0].name == "bar"); + KJ_EXPECT(list[0].type == FsNode::Type::DIRECTORY); + KJ_EXPECT(list[1].name == "foo"); + KJ_EXPECT(list[1].type == FsNode::Type::FILE); + } + + { + auto subdir = dir->openSubdir(Path("bar")); + + KJ_EXPECT(subdir->openFile(Path("baz"))->readAllText() == "bazqux"); + } + + auto subdir = dir->openSubdir(Path("corge"), WriteMode::CREATE); + + subdir->openFile(Path("grault"), WriteMode::CREATE)->writeAll("garply"); + + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "garply"); + + dir->openFile(Path({"corge", "grault"}), WriteMode::CREATE | WriteMode::MODIFY) + ->write(0, StringPtr("rag").asBytes()); + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "ragply"); + + KJ_EXPECT(dir->openSubdir(Path("corge"))->listNames().size() == 1); + + { + auto replacer = + dir->replaceFile(Path({"corge", "grault"}), WriteMode::CREATE | WriteMode::MODIFY); + replacer->get().writeAll("rag"); + + // temp file not in list + KJ_EXPECT(dir->openSubdir(Path("corge"))->listNames().size() == 1); + + // Don't commit. + } + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "ragply"); + + { + auto replacer = + dir->replaceFile(Path({"corge", "grault"}), WriteMode::CREATE | WriteMode::MODIFY); + replacer->get().writeAll("rag"); + + // temp file not in list + KJ_EXPECT(dir->openSubdir(Path("corge"))->listNames().size() == 1); + + replacer->commit(); + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "rag"); + } + + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "rag"); + + { + auto appender = dir->appendFile(Path({"corge", "grault"}), WriteMode::MODIFY); + appender->write("waldo", 5); + appender->write("fred", 4); + } + + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "ragwaldofred"); + + KJ_EXPECT(dir->exists(Path("foo"))); + dir->remove(Path("foo")); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(!dir->tryRemove(Path("foo"))); + + KJ_EXPECT(dir->exists(Path({"bar", "baz"}))); + dir->remove(Path({"bar", "baz"})); + KJ_EXPECT(!dir->exists(Path({"bar", "baz"}))); + KJ_EXPECT(dir->exists(Path("bar"))); + KJ_EXPECT(!dir->tryRemove(Path({"bar", "baz"}))); + +#if _WIN32 + // On Windows, we can't delete a directory while we still have it open. + subdir = nullptr; +#endif + + KJ_EXPECT(dir->exists(Path("corge"))); + KJ_EXPECT(dir->exists(Path({"corge", "grault"}))); + dir->remove(Path("corge")); + KJ_EXPECT(!dir->exists(Path("corge"))); + KJ_EXPECT(!dir->exists(Path({"corge", "grault"}))); + KJ_EXPECT(!dir->tryRemove(Path("corge"))); +} + +#if !_WIN32 // Creating symlinks on Win32 requires admin privileges prior to Windows 10. +KJ_TEST("DiskDirectory symlinks") { + TempDir tempDir; + auto dir = tempDir.get(); + + dir->symlink(Path("foo"), "bar/qux/../baz", WriteMode::CREATE); + + KJ_EXPECT(!dir->trySymlink(Path("foo"), "bar/qux/../baz", WriteMode::CREATE)); + + { + auto stats = dir->lstat(Path("foo")); + KJ_EXPECT(stats.type == FsNode::Type::SYMLINK); + } + + KJ_EXPECT(dir->readlink(Path("foo")) == "bar/qux/../baz"); + + // Broken link into non-existing directory cannot be opened in any mode. + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::CREATE) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::MODIFY) == nullptr); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("parent is not a directory", + dir->tryOpenFile(Path("foo"), WriteMode::CREATE | WriteMode::MODIFY)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("parent is not a directory", + dir->tryOpenFile(Path("foo"), + WriteMode::CREATE | WriteMode::MODIFY | WriteMode::CREATE_PARENT)); + + // Create the directory. + auto subdir = dir->openSubdir(Path("bar"), WriteMode::CREATE); + subdir->openSubdir(Path("qux"), WriteMode::CREATE); + + // Link still points to non-existing file so cannot be open in most modes. + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::CREATE) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::MODIFY) == nullptr); + + // But... CREATE | MODIFY works. + dir->openFile(Path("foo"), WriteMode::CREATE | WriteMode::MODIFY) + ->writeAll("foobar"); + + KJ_EXPECT(dir->openFile(Path({"bar", "baz"}))->readAllText() == "foobar"); + KJ_EXPECT(dir->openFile(Path("foo"))->readAllText() == "foobar"); + KJ_EXPECT(dir->openFile(Path("foo"), WriteMode::MODIFY)->readAllText() == "foobar"); + + // operations that modify the symlink + dir->symlink(Path("foo"), "corge", WriteMode::MODIFY); + KJ_EXPECT(dir->openFile(Path({"bar", "baz"}))->readAllText() == "foobar"); + KJ_EXPECT(dir->readlink(Path("foo")) == "corge"); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(dir->lstat(Path("foo")).type == FsNode::Type::SYMLINK); + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + + dir->remove(Path("foo")); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); +} +#endif + +KJ_TEST("DiskDirectory link") { + TempDir tempDirSrc; + TempDir tempDirDst; + + auto src = tempDirSrc.get(); + auto dst = tempDirDst.get(); + + src->openFile(Path("foo"), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + + dst->transfer(Path("link"), WriteMode::CREATE, *src, Path("foo"), TransferMode::LINK); + + KJ_EXPECT(dst->openFile(Path("link"))->readAllText() == "foobar"); + + // Writing the old location modifies the new. + src->openFile(Path("foo"), WriteMode::MODIFY)->writeAll("bazqux"); + KJ_EXPECT(dst->openFile(Path("link"))->readAllText() == "bazqux"); + + // Replacing the old location doesn't modify the new. + { + auto replacer = src->replaceFile(Path("foo"), WriteMode::MODIFY); + replacer->get().writeAll("corge"); + replacer->commit(); + } + KJ_EXPECT(src->openFile(Path("foo"))->readAllText() == "corge"); + KJ_EXPECT(dst->openFile(Path("link"))->readAllText() == "bazqux"); +} + +KJ_TEST("DiskDirectory copy") { + TempDir tempDirSrc; + TempDir tempDirDst; + + auto src = tempDirSrc.get(); + auto dst = tempDirDst.get(); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + + dst->transfer(Path("link"), WriteMode::CREATE, *src, Path("foo"), TransferMode::COPY); + + KJ_EXPECT(src->openFile(Path({"foo", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(src->openFile(Path({"foo", "baz", "qux"}))->readAllText() == "bazqux"); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); + + KJ_EXPECT(dst->exists(Path({"link", "bar"}))); + src->remove(Path({"foo", "bar"})); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); +} + +KJ_TEST("DiskDirectory copy-replace") { + TempDir tempDirSrc; + TempDir tempDirDst; + + auto src = tempDirSrc.get(); + auto dst = tempDirDst.get(); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + + dst->openFile(Path({"link", "corge"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("abcd"); + + // CREATE fails. + KJ_EXPECT(!dst->tryTransfer(Path("link"), WriteMode::CREATE, + *src, Path("foo"), TransferMode::COPY)); + + // Verify nothing changed. + KJ_EXPECT(dst->openFile(Path({"link", "corge"}))->readAllText() == "abcd"); + KJ_EXPECT(!dst->exists(Path({"foo", "bar"}))); + + // Now try MODIFY. + dst->transfer(Path("link"), WriteMode::MODIFY, *src, Path("foo"), TransferMode::COPY); + + KJ_EXPECT(src->openFile(Path({"foo", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(src->openFile(Path({"foo", "baz", "qux"}))->readAllText() == "bazqux"); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); + KJ_EXPECT(!dst->exists(Path({"link", "corge"}))); + + KJ_EXPECT(dst->exists(Path({"link", "bar"}))); + src->remove(Path({"foo", "bar"})); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); +} + +KJ_TEST("DiskDirectory move") { + TempDir tempDirSrc; + TempDir tempDirDst; + + auto src = tempDirSrc.get(); + auto dst = tempDirDst.get(); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + + dst->transfer(Path("link"), WriteMode::CREATE, *src, Path("foo"), TransferMode::MOVE); + + KJ_EXPECT(!src->exists(Path({"foo"}))); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); +} + +KJ_TEST("DiskDirectory move-replace") { + TempDir tempDirSrc; + TempDir tempDirDst; + + auto src = tempDirSrc.get(); + auto dst = tempDirDst.get(); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + + dst->openFile(Path({"link", "corge"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("abcd"); + + // CREATE fails. + KJ_EXPECT(!dst->tryTransfer(Path("link"), WriteMode::CREATE, + *src, Path("foo"), TransferMode::MOVE)); + + // Verify nothing changed. + KJ_EXPECT(dst->openFile(Path({"link", "corge"}))->readAllText() == "abcd"); + KJ_EXPECT(!dst->exists(Path({"foo", "bar"}))); + KJ_EXPECT(src->exists(Path({"foo"}))); + + // Now try MODIFY. + dst->transfer(Path("link"), WriteMode::MODIFY, *src, Path("foo"), TransferMode::MOVE); + + KJ_EXPECT(!src->exists(Path({"foo"}))); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); +} + +KJ_TEST("DiskDirectory createTemporary") { + TempDir tempDir; + auto dir = tempDir.get(); + auto file = dir->createTemporary(); + file->writeAll("foobar"); + KJ_EXPECT(file->readAllText() == "foobar"); + KJ_EXPECT(dir->listNames() == nullptr); +} + +#if !__CYGWIN__ // TODO(someday): Figure out why this doesn't work on Cygwin. +KJ_TEST("DiskDirectory replaceSubdir()") { + TempDir tempDir; + auto dir = tempDir.get(); + + { + auto replacer = dir->replaceSubdir(Path("foo"), WriteMode::CREATE); + replacer->get().openFile(Path("bar"), WriteMode::CREATE)->writeAll("original"); + KJ_EXPECT(replacer->get().openFile(Path("bar"))->readAllText() == "original"); + KJ_EXPECT(!dir->exists(Path({"foo", "bar"}))); + + replacer->commit(); + KJ_EXPECT(replacer->get().openFile(Path("bar"))->readAllText() == "original"); + KJ_EXPECT(dir->openFile(Path({"foo", "bar"}))->readAllText() == "original"); + } + + { + // CREATE fails -- already exists. + auto replacer = dir->replaceSubdir(Path("foo"), WriteMode::CREATE); + replacer->get().openFile(Path("corge"), WriteMode::CREATE)->writeAll("bazqux"); + KJ_EXPECT(dir->listNames().size() == 1 && dir->listNames()[0] == "foo"); + KJ_EXPECT(!replacer->tryCommit()); + } + + // Unchanged. + KJ_EXPECT(dir->openFile(Path({"foo", "bar"}))->readAllText() == "original"); + KJ_EXPECT(!dir->exists(Path({"foo", "corge"}))); + + { + // MODIFY succeeds. + auto replacer = dir->replaceSubdir(Path("foo"), WriteMode::MODIFY); + replacer->get().openFile(Path("corge"), WriteMode::CREATE)->writeAll("bazqux"); + KJ_EXPECT(dir->listNames().size() == 1 && dir->listNames()[0] == "foo"); + replacer->commit(); + } + + // Replaced with new contents. + KJ_EXPECT(!dir->exists(Path({"foo", "bar"}))); + KJ_EXPECT(dir->openFile(Path({"foo", "corge"}))->readAllText() == "bazqux"); +} +#endif // !__CYGWIN__ + +KJ_TEST("DiskDirectory replace directory with file") { + TempDir tempDir; + auto dir = tempDir.get(); + + dir->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + + { + // CREATE fails -- already exists. + auto replacer = dir->replaceFile(Path("foo"), WriteMode::CREATE); + replacer->get().writeAll("bazqux"); + KJ_EXPECT(!replacer->tryCommit()); + } + + // Still a directory. + KJ_EXPECT(dir->lstat(Path("foo")).type == FsNode::Type::DIRECTORY); + + { + // MODIFY succeeds. + auto replacer = dir->replaceFile(Path("foo"), WriteMode::MODIFY); + replacer->get().writeAll("bazqux"); + replacer->commit(); + } + + // Replaced with file. + KJ_EXPECT(dir->openFile(Path("foo"))->readAllText() == "bazqux"); +} + +KJ_TEST("DiskDirectory replace file with directory") { + TempDir tempDir; + auto dir = tempDir.get(); + + dir->openFile(Path("foo"), WriteMode::CREATE) + ->writeAll("foobar"); + + { + // CREATE fails -- already exists. + auto replacer = dir->replaceSubdir(Path("foo"), WriteMode::CREATE); + replacer->get().openFile(Path("bar"), WriteMode::CREATE)->writeAll("bazqux"); + KJ_EXPECT(dir->listNames().size() == 1 && dir->listNames()[0] == "foo"); + KJ_EXPECT(!replacer->tryCommit()); + } + + // Still a file. + KJ_EXPECT(dir->openFile(Path("foo"))->readAllText() == "foobar"); + + { + // MODIFY succeeds. + auto replacer = dir->replaceSubdir(Path("foo"), WriteMode::MODIFY); + replacer->get().openFile(Path("bar"), WriteMode::CREATE)->writeAll("bazqux"); + KJ_EXPECT(dir->listNames().size() == 1 && dir->listNames()[0] == "foo"); + replacer->commit(); + } + + // Replaced with directory. + KJ_EXPECT(dir->openFile(Path({"foo", "bar"}))->readAllText() == "bazqux"); +} + +#if !defined(HOLES_NOT_SUPPORTED) && (CAPNP_DEBUG_TYPES || CAPNP_EXPENSIVE_TESTS) +// Not all filesystems support sparse files, and if they do, they don't necessarily support +// copying them in a way that preserves holes. We don't want the capnp test suite to fail just +// because it was run on the wrong filesystem. We could design the test to check first if the +// filesystem supports holes, but the code to do that would be almost the same as the code being +// tested... Instead, we've marked this test so it only runs when building this library using +// defines that only the Cap'n Proto maintainers use. So, we run the test ourselves but we don't +// make other people run it. + +KJ_TEST("DiskFile holes") { + if (isWine()) { + // WINE doesn't support sparse files. + return; + } + + TempDir tempDir; + auto dir = tempDir.get(); + + auto file = dir->openFile(Path("holes"), WriteMode::CREATE); + +#if _WIN32 + FILE_SET_SPARSE_BUFFER sparseInfo; + memset(&sparseInfo, 0, sizeof(sparseInfo)); + sparseInfo.SetSparse = TRUE; + DWORD dummy; + KJ_WIN32(DeviceIoControl( + KJ_ASSERT_NONNULL(file->getWin32Handle()), + FSCTL_SET_SPARSE, &sparseInfo, sizeof(sparseInfo), + NULL, 0, &dummy, NULL)); +#endif + + file->writeAll("foobar"); + file->write(1 << 20, StringPtr("foobar").asBytes()); + + // Some filesystems, like BTRFS, report zero `spaceUsed` until synced. + file->datasync(); + + // Allow for block sizes as low as 512 bytes and as high as 64k. Since we wrote two locations, + // two blocks should be used. + auto meta = file->stat(); +#if __FreeBSD__ + // On FreeBSD with ZFS it seems to report 512 bytes used even if I write more than 512 random + // (i.e. non-compressible) bytes. I couldn't figure it out so I'm giving up for now. Maybe it's + // a bug in the system? + KJ_EXPECT(meta.spaceUsed >= 512, meta.spaceUsed); +#else + KJ_EXPECT(meta.spaceUsed >= 2 * 512, meta.spaceUsed); +#endif + KJ_EXPECT(meta.spaceUsed <= 2 * 65536); + + byte buf[7]; + +#if !_WIN32 // Win32 CopyFile() does NOT preserve sparseness. + { + // Copy doesn't fill in holes. + dir->transfer(Path("copy"), WriteMode::CREATE, Path("holes"), TransferMode::COPY); + auto copy = dir->openFile(Path("copy")); + KJ_EXPECT(copy->stat().spaceUsed == meta.spaceUsed); + KJ_EXPECT(copy->read(0, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == "foobar"); + + KJ_EXPECT(copy->read(1 << 20, buf) == 6); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == "foobar"); + + KJ_EXPECT(copy->read(1 << 19, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == StringPtr("\0\0\0\0\0\0", 6)); + } +#endif + + file->truncate(1 << 21); + file->datasync(); + KJ_EXPECT(file->stat().spaceUsed == meta.spaceUsed); + KJ_EXPECT(file->read(1 << 20, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == "foobar"); + +#if !_WIN32 // Win32 CopyFile() does NOT preserve sparseness. + { + dir->transfer(Path("copy"), WriteMode::MODIFY, Path("holes"), TransferMode::COPY); + auto copy = dir->openFile(Path("copy")); + KJ_EXPECT(copy->stat().spaceUsed == meta.spaceUsed); + KJ_EXPECT(copy->read(0, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == "foobar"); + + KJ_EXPECT(copy->read(1 << 20, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == "foobar"); + + KJ_EXPECT(copy->read(1 << 19, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == StringPtr("\0\0\0\0\0\0", 6)); + } +#endif + + // Try punching a hole with zero(). +#if _WIN32 + uint64_t blockSize = 4096; // TODO(someday): Actually ask the OS. +#else + struct stat stats; + KJ_SYSCALL(fstat(KJ_ASSERT_NONNULL(file->getFd()), &stats)); + uint64_t blockSize = stats.st_blksize; +#endif + file->zero(1 << 20, blockSize); + file->datasync(); +#if !_WIN32 && !__FreeBSD__ + // TODO(someday): This doesn't work on Windows. I don't know why. We're definitely using the + // proper ioctl. Oh well. It also doesn't work on FreeBSD-ZFS, due to the bug(?) mentioned + // earlier -- the size is just always reported as 512. + KJ_EXPECT(file->stat().spaceUsed < meta.spaceUsed); +#endif + KJ_EXPECT(file->read(1 << 20, buf) == 7); + KJ_EXPECT(StringPtr(reinterpret_cast(buf), 6) == StringPtr("\0\0\0\0\0\0", 6)); +} +#endif + +#if !_WIN32 // Only applies to Unix. +// Ensure the current path is correctly computed. +// +// See issue #1425. +KJ_TEST("DiskFilesystem::computeCurrentPath") { + TempDir tempDir; + auto dir = tempDir.get(); + + // Paths can be PATH_MAX, but the segments which make up that path typically + // can't exceed 255 bytes. + auto maxPathSegment = std::string(255, 'a'); + + // Create a path which exceeds the 256 byte buffer used in + // computeCurrentPath. + auto subdir = dir->openSubdir(Path({ + maxPathSegment, + maxPathSegment, + "some_path_longer_than_256_bytes" + }), WriteMode::CREATE | WriteMode::CREATE_PARENT); + + auto origDir = open(".", O_RDONLY); + KJ_SYSCALL(fchdir(KJ_ASSERT_NONNULL(subdir->getFd()))); + KJ_DEFER(KJ_SYSCALL(fchdir(origDir))); + + // Test computeCurrentPath indirectly. + newDiskFilesystem(); +} +#endif + +} // namespace +} // namespace kj diff --git a/c++/src/kj/filesystem-disk-unix.c++ b/c++/src/kj/filesystem-disk-unix.c++ new file mode 100644 index 0000000000..67d7bf22c7 --- /dev/null +++ b/c++/src/kj/filesystem-disk-unix.c++ @@ -0,0 +1,1788 @@ +// Copyright (c) 2015 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if !_WIN32 + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#ifndef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 +// Request 64-bit off_t. (The code will still work if we get 32-bit off_t as long as actual files +// are under 4GB.) +#endif + +#include "filesystem.h" +#include "debug.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "vector.h" +#include "miniposix.h" +#include + +#if __linux__ +#include +#include +#include +#endif + +namespace kj { +namespace { + +#define HIDDEN_PREFIX ".kj-tmp." +// Prefix for temp files which should be hidden when listing a directory. +// +// If you change this, make sure to update the unit test. + +#ifdef O_CLOEXEC +#define MAYBE_O_CLOEXEC O_CLOEXEC +#else +#define MAYBE_O_CLOEXEC 0 +#endif + +#ifdef O_DIRECTORY +#define MAYBE_O_DIRECTORY O_DIRECTORY +#else +#define MAYBE_O_DIRECTORY 0 +#endif + +#if __APPLE__ +// Mac OSX defines SEEK_HOLE, but it doesn't work. ("Inappropriate ioctl for device", it says.) +#undef SEEK_HOLE +#endif + +#if __BIONIC__ +// No no DTTOIF function +#undef DT_UNKNOWN +#endif + +static void setCloexec(int fd) KJ_UNUSED; +static void setCloexec(int fd) { + // Set the O_CLOEXEC flag on the given fd. + // + // We try to avoid the need to call this by taking advantage of syscall flags that set it + // atomically on new file descriptors. Unfortunately some platforms do not support such syscalls. + +#ifdef FIOCLEX + // Yay, we can set the flag in one call. + KJ_SYSCALL_HANDLE_ERRORS(ioctl(fd, FIOCLEX)) { + case EINVAL: + case EOPNOTSUPP: + break; + default: + KJ_FAIL_SYSCALL("ioctl(fd, FIOCLEX)", error) { break; } + break; + } else { + // success + return; + } +#endif + + // Sadness, we must resort to read/modify/write. + // + // (On many platforms, FD_CLOEXEC is the only flag modifiable via F_SETFD and therefore we could + // skip the read... but it seems dangerous to assume that's true of all platforms, and anyway + // most platforms support FIOCLEX.) + int flags; + KJ_SYSCALL(flags = fcntl(fd, F_GETFD)); + if (!(flags & FD_CLOEXEC)) { + KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC)); + } +} + +static Date toKjDate(struct timespec tv) { + return tv.tv_sec * SECONDS + tv.tv_nsec * NANOSECONDS + UNIX_EPOCH; +} + +static FsNode::Type modeToType(mode_t mode) { + switch (mode & S_IFMT) { + case S_IFREG : return FsNode::Type::FILE; + case S_IFDIR : return FsNode::Type::DIRECTORY; + case S_IFLNK : return FsNode::Type::SYMLINK; + case S_IFBLK : return FsNode::Type::BLOCK_DEVICE; + case S_IFCHR : return FsNode::Type::CHARACTER_DEVICE; + case S_IFIFO : return FsNode::Type::NAMED_PIPE; + case S_IFSOCK: return FsNode::Type::SOCKET; + default: return FsNode::Type::OTHER; + } +} + +static FsNode::Metadata statToMetadata(struct stat& stats) { + // Probably st_ino and st_dev are usually under 32 bits, so mix by rotating st_dev left 32 bits + // and XOR. + uint64_t d = stats.st_dev; + uint64_t hash = ((d << 32) | (d >> 32)) ^ stats.st_ino; + + return FsNode::Metadata { + modeToType(stats.st_mode), + implicitCast(stats.st_size), + implicitCast(stats.st_blocks * 512u), +#if __APPLE__ + toKjDate(stats.st_mtimespec), +#else + toKjDate(stats.st_mtim), +#endif + implicitCast(stats.st_nlink), + hash + }; +} + +static bool rmrf(int fd, StringPtr path); + +static void rmrfChildrenAndClose(int fd) { + // Assumes fd is seeked to beginning. + + DIR* dir = fdopendir(fd); + if (dir == nullptr) { + close(fd); + KJ_FAIL_SYSCALL("fdopendir", errno); + }; + KJ_DEFER(closedir(dir)); + + for (;;) { + errno = 0; + struct dirent* entry = readdir(dir); + if (entry == nullptr) { + int error = errno; + if (error == 0) { + break; + } else { + KJ_FAIL_SYSCALL("readdir", error); + } + } + + if (entry->d_name[0] == '.' && + (entry->d_name[1] == '\0' || + (entry->d_name[1] == '.' && + entry->d_name[2] == '\0'))) { + // ignore . and .. + } else { +#ifdef DT_UNKNOWN // d_type is not available on all platforms. + if (entry->d_type == DT_DIR) { + int subdirFd; + KJ_SYSCALL(subdirFd = openat( + fd, entry->d_name, O_RDONLY | MAYBE_O_DIRECTORY | MAYBE_O_CLOEXEC | O_NOFOLLOW)); + rmrfChildrenAndClose(subdirFd); + KJ_SYSCALL(unlinkat(fd, entry->d_name, AT_REMOVEDIR)); + } else if (entry->d_type != DT_UNKNOWN) { + KJ_SYSCALL(unlinkat(fd, entry->d_name, 0)); + } else { +#endif + KJ_ASSERT(rmrf(fd, entry->d_name)); +#ifdef DT_UNKNOWN + } +#endif + } + } +} + +static bool rmrf(int fd, StringPtr path) { + struct stat stats; + KJ_SYSCALL_HANDLE_ERRORS(fstatat(fd, path.cStr(), &stats, AT_SYMLINK_NOFOLLOW)) { + case ENOENT: + case ENOTDIR: + // Doesn't exist. + return false; + default: + KJ_FAIL_SYSCALL("lstat(path)", error, path) { return false; } + } + + if (S_ISDIR(stats.st_mode)) { + int subdirFd; + KJ_SYSCALL(subdirFd = openat( + fd, path.cStr(), O_RDONLY | MAYBE_O_DIRECTORY | MAYBE_O_CLOEXEC | O_NOFOLLOW)) { + return false; + } + rmrfChildrenAndClose(subdirFd); + KJ_SYSCALL(unlinkat(fd, path.cStr(), AT_REMOVEDIR)) { return false; } + } else { + KJ_SYSCALL(unlinkat(fd, path.cStr(), 0)) { return false; } + } + + return true; +} + +struct MmapRange { + uint64_t offset; + uint64_t size; +}; + +static MmapRange getMmapRange(uint64_t offset, uint64_t size) { + // Comes up with an offset and size to pass to mmap(), given an offset and size requested by + // the caller, and considering the fact that mappings must start at a page boundary. + // + // The offset is rounded down to the nearest page boundary, and the size is increased to + // compensate. Note that the endpoint of the mapping is *not* rounded up to a page boundary, as + // mmap() does not actually require this, and it causes trouble on some systems (notably Cygwin). + +#ifndef _SC_PAGESIZE +#define _SC_PAGESIZE _SC_PAGE_SIZE +#endif + static const uint64_t pageSize = sysconf(_SC_PAGESIZE); + uint64_t pageMask = pageSize - 1; + + uint64_t realOffset = offset & ~pageMask; + + return { realOffset, offset + size - realOffset }; +} + +class MmapDisposer: public ArrayDisposer { +protected: + void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, + size_t capacity, void (*destroyElement)(void*)) const { + auto range = getMmapRange(reinterpret_cast(firstElement), + elementSize * elementCount); + KJ_SYSCALL(munmap(reinterpret_cast(range.offset), range.size)) { break; } + } +}; + +constexpr MmapDisposer mmapDisposer = MmapDisposer(); + +class DiskHandle { + // We need to implement each of ReadableFile, AppendableFile, File, ReadableDirectory, and + // Directory for disk handles. There is a lot of implementation overlap between these, especially + // stat(), sync(), etc. We can't have everything inherit from a common DiskFsNode that implements + // these because then we get diamond inheritance which means we need to make all our inheritance + // virtual which means downcasting requires RTTI which violates our goal of supporting compiling + // with no RTTI. So instead we have the DiskHandle class which implements all the methods without + // inheriting anything, and then we have DiskFile, DiskDirectory, etc. hold this and delegate to + // it. Ugly, but works. + +public: + DiskHandle(AutoCloseFd&& fd): fd(kj::mv(fd)) {} + + // OsHandle ------------------------------------------------------------------ + + AutoCloseFd clone() const { + int fd2; +#ifdef F_DUPFD_CLOEXEC + KJ_SYSCALL_HANDLE_ERRORS(fd2 = fcntl(fd, F_DUPFD_CLOEXEC, 3)) { + case EINVAL: + case EOPNOTSUPP: + // fall back + break; + default: + KJ_FAIL_SYSCALL("fnctl(fd, F_DUPFD_CLOEXEC, 3)", error) { break; } + break; + } else { + return AutoCloseFd(fd2); + } +#endif + + KJ_SYSCALL(fd2 = ::dup(fd)); + AutoCloseFd result(fd2); + setCloexec(result); + return result; + } + + int getFd() const { + return fd.get(); + } + + void setFd(AutoCloseFd newFd) { + // Used for one hack in DiskFilesystem's constructor... + fd = kj::mv(newFd); + } + + // FsNode -------------------------------------------------------------------- + + FsNode::Metadata stat() const { + struct stat stats; + KJ_SYSCALL(::fstat(fd, &stats)); + return statToMetadata(stats); + } + + void sync() const { +#if __APPLE__ + // For whatever reason, fsync() on OSX only flushes kernel buffers. It does not flush hardware + // disk buffers. This makes it not very useful. But OSX documents fcntl F_FULLFSYNC which does + // the right thing. Why they don't just make fsync() do the right thing, I do not know. + KJ_SYSCALL(fcntl(fd, F_FULLFSYNC)); +#else + KJ_SYSCALL(fsync(fd)); +#endif + } + + void datasync() const { + // The presence of the _POSIX_SYNCHRONIZED_IO define is supposed to tell us that fdatasync() + // exists. But Apple defines this yet doesn't offer fdatasync(). Thanks, Apple. +#if _POSIX_SYNCHRONIZED_IO && !__APPLE__ + KJ_SYSCALL(fdatasync(fd)); +#else + this->sync(); +#endif + } + + // ReadableFile -------------------------------------------------------------- + + size_t read(uint64_t offset, ArrayPtr buffer) const { + // pread() probably never returns short reads unless it hits EOF. Unfortunately, though, per + // spec we are not allowed to assume this. + + size_t total = 0; + while (buffer.size() > 0) { + ssize_t n; + KJ_SYSCALL(n = pread(fd, buffer.begin(), buffer.size(), offset)); + if (n == 0) break; + total += n; + offset += n; + buffer = buffer.slice(n, buffer.size()); + } + return total; + } + + Array mmap(uint64_t offset, uint64_t size) const { + if (size == 0) return nullptr; // zero-length mmap() returns EINVAL, so avoid it + auto range = getMmapRange(offset, size); + const void* mapping = ::mmap(NULL, range.size, PROT_READ, MAP_SHARED, fd, range.offset); + if (mapping == MAP_FAILED) { + KJ_FAIL_SYSCALL("mmap", errno); + } + return Array(reinterpret_cast(mapping) + (offset - range.offset), + size, mmapDisposer); + } + + Array mmapPrivate(uint64_t offset, uint64_t size) const { + if (size == 0) return nullptr; // zero-length mmap() returns EINVAL, so avoid it + auto range = getMmapRange(offset, size); + void* mapping = ::mmap(NULL, range.size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, range.offset); + if (mapping == MAP_FAILED) { + KJ_FAIL_SYSCALL("mmap", errno); + } + return Array(reinterpret_cast(mapping) + (offset - range.offset), + size, mmapDisposer); + } + + // File ---------------------------------------------------------------------- + + void write(uint64_t offset, ArrayPtr data) const { + // pwrite() probably never returns short writes unless there's no space left on disk. + // Unfortunately, though, per spec we are not allowed to assume this. + + while (data.size() > 0) { + ssize_t n; + KJ_SYSCALL(n = pwrite(fd, data.begin(), data.size(), offset)); + KJ_ASSERT(n > 0, "pwrite() returned zero?"); + offset += n; + data = data.slice(n, data.size()); + } + } + + void zero(uint64_t offset, uint64_t size) const { + // If FALLOC_FL_PUNCH_HOLE is defined, use it to efficiently zero the area. + // + // A fallocate() wrapper was only added to Android's Bionic C library as of API level 21, + // but FALLOC_FL_PUNCH_HOLE is apparently defined in the headers before that, so we'll + // have to explicitly test for that case. +#if defined(FALLOC_FL_PUNCH_HOLE) && !(__ANDROID__ && __BIONIC__ && __ANDROID_API__ < 21) + KJ_SYSCALL_HANDLE_ERRORS( + fallocate(fd, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, offset, size)) { + case EOPNOTSUPP: + // fall back to below + break; + default: + KJ_FAIL_SYSCALL("fallocate(FALLOC_FL_PUNCH_HOLE)", error) { return; } + } else { + return; + } +#endif + + static const byte ZEROS[4096] = { 0 }; + +#if __APPLE__ || __CYGWIN__ || (defined(__ANDROID__) && __ANDROID_API__ < 24) + // Mac & Cygwin & Android API levels 23 and lower doesn't have pwritev(). + while (size > sizeof(ZEROS)) { + write(offset, ZEROS); + size -= sizeof(ZEROS); + offset += sizeof(ZEROS); + } + write(offset, kj::arrayPtr(ZEROS, size)); +#else + // Use a 4k buffer of zeros amplified by iov to write zeros with as few syscalls as possible. + size_t count = (size + sizeof(ZEROS) - 1) / sizeof(ZEROS); + const size_t iovmax = miniposix::iovMax(); + KJ_STACK_ARRAY(struct iovec, iov, kj::min(iovmax, count), 16, 256); + + for (auto& item: iov) { + item.iov_base = const_cast(ZEROS); + item.iov_len = sizeof(ZEROS); + } + + while (size > 0) { + size_t iovCount; + if (size >= iov.size() * sizeof(ZEROS)) { + iovCount = iov.size(); + } else { + iovCount = size / sizeof(ZEROS); + size_t rem = size % sizeof(ZEROS); + if (rem > 0) { + iov[iovCount++].iov_len = rem; + } + } + + ssize_t n; + KJ_SYSCALL(n = pwritev(fd, iov.begin(), count, offset)); + KJ_ASSERT(n > 0, "pwrite() returned zero?"); + + offset += n; + size -= n; + } +#endif + } + + void truncate(uint64_t size) const { + KJ_SYSCALL(ftruncate(fd, size)); + } + + class WritableFileMappingImpl final: public WritableFileMapping { + public: + WritableFileMappingImpl(Array bytes): bytes(kj::mv(bytes)) {} + + ArrayPtr get() const override { + // const_cast OK because WritableFileMapping does indeed provide a writable view despite + // being const itself. + return arrayPtr(const_cast(bytes.begin()), bytes.size()); + } + + void changed(ArrayPtr slice) const override { + KJ_REQUIRE(slice.begin() >= bytes.begin() && slice.end() <= bytes.end(), + "byte range is not part of this mapping"); + if (slice.size() == 0) return; + + // msync() requires page-alignment, apparently, so use getMmapRange() to accomplish that. + auto range = getMmapRange(reinterpret_cast(slice.begin()), slice.size()); + KJ_SYSCALL(msync(reinterpret_cast(range.offset), range.size, MS_ASYNC)); + } + + void sync(ArrayPtr slice) const override { + KJ_REQUIRE(slice.begin() >= bytes.begin() && slice.end() <= bytes.end(), + "byte range is not part of this mapping"); + if (slice.size() == 0) return; + + // msync() requires page-alignment, apparently, so use getMmapRange() to accomplish that. + auto range = getMmapRange(reinterpret_cast(slice.begin()), slice.size()); + KJ_SYSCALL(msync(reinterpret_cast(range.offset), range.size, MS_SYNC)); + } + + private: + Array bytes; + }; + + Own mmapWritable(uint64_t offset, uint64_t size) const { + if (size == 0) { + // zero-length mmap() returns EINVAL, so avoid it + return heap(nullptr); + } + auto range = getMmapRange(offset, size); + void* mapping = ::mmap(NULL, range.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, range.offset); + if (mapping == MAP_FAILED) { + KJ_FAIL_SYSCALL("mmap", errno); + } + auto array = Array(reinterpret_cast(mapping) + (offset - range.offset), + size, mmapDisposer); + return heap(kj::mv(array)); + } + + size_t copyChunk(uint64_t offset, int fromFd, uint64_t fromOffset, uint64_t size) const { + // Copies a range of bytes from `fromFd` to this file in the most efficient way possible for + // the OS. Only returns less than `size` if EOF. Does not account for holes. + +#if __linux__ + { + KJ_SYSCALL(lseek(fd, offset, SEEK_SET)); + off_t fromPos = fromOffset; + off_t end = fromOffset + size; + while (fromPos < end) { + ssize_t n; + KJ_SYSCALL_HANDLE_ERRORS(n = sendfile(fd, fromFd, &fromPos, end - fromPos)) { + case EINVAL: + case ENOSYS: + goto sendfileNotAvailable; + default: + KJ_FAIL_SYSCALL("sendfile", error) { return fromPos - fromOffset; } + } + if (n == 0) break; + } + return fromPos - fromOffset; + } + + sendfileNotAvailable: +#endif + uint64_t total = 0; + while (size > 0) { + byte buffer[4096]; + ssize_t n; + KJ_SYSCALL(n = pread(fromFd, buffer, kj::min(sizeof(buffer), size), fromOffset)); + if (n == 0) break; + write(offset, arrayPtr(buffer, n)); + fromOffset += n; + offset += n; + total += n; + size -= n; + } + return total; + } + + kj::Maybe copy(uint64_t offset, const ReadableFile& from, + uint64_t fromOffset, uint64_t size) const { + KJ_IF_MAYBE(otherFd, from.getFd()) { +#ifdef FICLONE + if (offset == 0 && fromOffset == 0 && size == kj::maxValue && stat().size == 0) { + if (ioctl(fd, FICLONE, *otherFd) >= 0) { + return stat().size; + } + } else if (size > 0) { // src_length = 0 has special meaning for the syscall, so avoid. + struct file_clone_range range; + memset(&range, 0, sizeof(range)); + range.src_fd = *otherFd; + range.dest_offset = offset; + range.src_offset = fromOffset; + range.src_length = size == kj::maxValue ? 0 : size; + if (ioctl(fd, FICLONERANGE, &range) >= 0) { + // TODO(someday): What does FICLONERANGE actually do if the range goes past EOF? The docs + // don't say. Maybe it only copies the parts that exist. Maybe it punches holes for the + // rest. Where does the destination file's EOF marker end up? Who knows? + return kj::min(from.stat().size - fromOffset, size); + } + } else { + // size == 0 + return size_t(0); + } + + // ioctl failed. Almost all failures documented for these are of the form "the operation is + // not supported for the filesystem(s) specified", so fall back to other approaches. +#endif + + off_t toPos = offset; + off_t fromPos = fromOffset; + off_t end = size == kj::maxValue ? off_t(kj::maxValue) : off_t(fromOffset + size); + + for (;;) { + // Handle data. + { + // Find out how much data there is before the next hole. + off_t nextHole; +#ifdef SEEK_HOLE + KJ_SYSCALL_HANDLE_ERRORS(nextHole = lseek(*otherFd, fromPos, SEEK_HOLE)) { + case EINVAL: + // SEEK_HOLE probably not supported. Assume no holes. + nextHole = end; + break; + case ENXIO: + // Past EOF. Stop here. + return fromPos - fromOffset; + default: + KJ_FAIL_SYSCALL("lseek(fd, pos, SEEK_HOLE)", error) { return fromPos - fromOffset; } + } +#else + // SEEK_HOLE not supported. Assume no holes. + nextHole = end; +#endif + + // Copy the next chunk of data. + off_t copyTo = kj::min(end, nextHole); + size_t amount = copyTo - fromPos; + if (amount > 0) { + size_t n = copyChunk(toPos, *otherFd, fromPos, amount); + fromPos += n; + toPos += n; + + if (n < amount) { + return fromPos - fromOffset; + } + } + + if (fromPos == end) { + return fromPos - fromOffset; + } + } + +#ifdef SEEK_HOLE + // Handle hole. + { + // Find out how much hole there is before the next data. + off_t nextData; + KJ_SYSCALL_HANDLE_ERRORS(nextData = lseek(*otherFd, fromPos, SEEK_DATA)) { + case EINVAL: + // SEEK_DATA probably not supported. But we should only have gotten here if we + // were expecting a hole. + KJ_FAIL_ASSERT("can't determine hole size; SEEK_DATA not supported"); + break; + case ENXIO: + // No more data. Set to EOF. + KJ_SYSCALL(nextData = lseek(*otherFd, 0, SEEK_END)); + if (nextData > end) { + end = nextData; + } + break; + default: + KJ_FAIL_SYSCALL("lseek(fd, pos, SEEK_HOLE)", error) { return fromPos - fromOffset; } + } + + // Write zeros. + off_t zeroTo = kj::min(end, nextData); + off_t amount = zeroTo - fromPos; + if (amount > 0) { + zero(toPos, amount); + toPos += amount; + fromPos = zeroTo; + } + + if (fromPos == end) { + return fromPos - fromOffset; + } + } +#endif + } + } + + // Indicates caller should call File::copy() default implementation. + return nullptr; + } + + // ReadableDirectory --------------------------------------------------------- + + template + auto list(bool needTypes, Func&& func) const + -> Array(), instance()))>> { + // Seek to start of directory. + KJ_SYSCALL(lseek(fd, 0, SEEK_SET)); + + // Unfortunately, fdopendir() takes ownership of the file descriptor. Therefore we need to + // make a duplicate. + int duped; + KJ_SYSCALL(duped = dup(fd)); + DIR* dir = fdopendir(duped); + if (dir == nullptr) { + close(duped); + KJ_FAIL_SYSCALL("fdopendir", errno); + } + + KJ_DEFER(closedir(dir)); + typedef Decay(), instance()))> Entry; + kj::Vector entries; + + for (;;) { + errno = 0; + struct dirent* entry = readdir(dir); + if (entry == nullptr) { + int error = errno; + if (error == 0) { + break; + } else { + KJ_FAIL_SYSCALL("readdir", error); + } + } + + kj::StringPtr name = entry->d_name; + if (name != "." && name != ".." && !name.startsWith(HIDDEN_PREFIX)) { +#ifdef DT_UNKNOWN // d_type is not available on all platforms. + if (entry->d_type != DT_UNKNOWN) { + entries.add(func(name, modeToType(DTTOIF(entry->d_type)))); + } else { +#endif + if (needTypes) { + // Unknown type. Fall back to stat. + struct stat stats; + KJ_SYSCALL(fstatat(fd, name.cStr(), &stats, AT_SYMLINK_NOFOLLOW)); + entries.add(func(name, modeToType(stats.st_mode))); + } else { + entries.add(func(name, FsNode::Type::OTHER)); + } +#ifdef DT_UNKNOWN + } +#endif + } + } + + auto result = entries.releaseAsArray(); + std::sort(result.begin(), result.end()); + return result; + } + + Array listNames() const { + return list(false, [](StringPtr name, FsNode::Type type) { return heapString(name); }); + } + + Array listEntries() const { + return list(true, [](StringPtr name, FsNode::Type type) { + return ReadableDirectory::Entry { type, heapString(name), }; + }); + } + + bool exists(PathPtr path) const { + KJ_SYSCALL_HANDLE_ERRORS(faccessat(fd, path.toString().cStr(), F_OK, 0)) { + case ENOENT: + case ENOTDIR: + return false; + default: + KJ_FAIL_SYSCALL("faccessat(fd, path)", error, path) { return false; } + } + return true; + } + + Maybe tryLstat(PathPtr path) const { + struct stat stats; + KJ_SYSCALL_HANDLE_ERRORS(fstatat(fd, path.toString().cStr(), &stats, AT_SYMLINK_NOFOLLOW)) { + case ENOENT: + case ENOTDIR: + return nullptr; + default: + KJ_FAIL_SYSCALL("faccessat(fd, path)", error, path) { return nullptr; } + } + return statToMetadata(stats); + } + + Maybe> tryOpenFile(PathPtr path) const { + int newFd; + KJ_SYSCALL_HANDLE_ERRORS(newFd = openat( + fd, path.toString().cStr(), O_RDONLY | MAYBE_O_CLOEXEC)) { + case ENOENT: + case ENOTDIR: + return nullptr; + default: + KJ_FAIL_SYSCALL("openat(fd, path, O_RDONLY)", error, path) { return nullptr; } + } + + kj::AutoCloseFd result(newFd); +#ifndef O_CLOEXEC + setCloexec(result); +#endif + + return newDiskReadableFile(kj::mv(result)); + } + + Maybe tryOpenSubdirInternal(PathPtr path) const { + int newFd; + KJ_SYSCALL_HANDLE_ERRORS(newFd = openat( + fd, path.toString().cStr(), O_RDONLY | MAYBE_O_CLOEXEC | MAYBE_O_DIRECTORY)) { + case ENOENT: + return nullptr; + case ENOTDIR: + // Could mean that a parent is not a directory, which we treat as "doesn't exist". + // Could also mean that the specified file is not a directory, which should throw. + // Check using exists(). + if (!exists(path)) { + return nullptr; + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_SYSCALL("openat(fd, path, O_DIRECTORY)", error, path) { return nullptr; } + } + + kj::AutoCloseFd result(newFd); +#ifndef O_CLOEXEC + setCloexec(result); +#endif + + return kj::mv(result); + } + + Maybe> tryOpenSubdir(PathPtr path) const { + return tryOpenSubdirInternal(path).map(newDiskReadableDirectory); + } + + Maybe tryReadlink(PathPtr path) const { + size_t trySize = 256; + for (;;) { + KJ_STACK_ARRAY(char, buf, trySize, 256, 4096); + ssize_t n = readlinkat(fd, path.toString().cStr(), buf.begin(), buf.size()); + if (n < 0) { + int error = errno; + switch (error) { + case EINTR: + continue; + case ENOENT: + case ENOTDIR: + case EINVAL: // not a link + return nullptr; + default: + KJ_FAIL_SYSCALL("readlinkat(fd, path)", error, path) { return nullptr; } + } + } + + if (n >= buf.size()) { + // Didn't give it enough space. Better retry with a bigger buffer. + trySize *= 2; + continue; + } + + return heapString(buf.begin(), n); + } + } + + // Directory ----------------------------------------------------------------- + + bool tryMkdir(PathPtr path, WriteMode mode, bool noThrow) const { + // Internal function to make a directory. + + auto filename = path.toString(); + mode_t acl = has(mode, WriteMode::PRIVATE) ? 0700 : 0777; + + KJ_SYSCALL_HANDLE_ERRORS(mkdirat(fd, filename.cStr(), acl)) { + case EEXIST: { + // Apparently this path exists. + if (!has(mode, WriteMode::MODIFY)) { + // Require exclusive create. + return false; + } + + // MODIFY is allowed, so we just need to check whether the existing entry is a directory. + struct stat stats; + KJ_SYSCALL_HANDLE_ERRORS(fstatat(fd, filename.cStr(), &stats, 0)) { + default: + // mkdir() says EEXIST but we can't stat it. Maybe it's a dangling link, or maybe + // we can't access it for some reason. Assume failure. + // + // TODO(someday): Maybe we should be creating the directory at the target of the + // link? + goto failed; + } + return (stats.st_mode & S_IFMT) == S_IFDIR; + } + case ENOENT: + if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && + tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryMkdir(path, mode - WriteMode::CREATE_PARENT, noThrow); + } else { + goto failed; + } + default: + failed: + if (noThrow) { + // Caller requested no throwing. + return false; + } else { + KJ_FAIL_SYSCALL("mkdirat(fd, path)", error, path); + } + } + + return true; + } + + kj::Maybe createNamedTemporary( + PathPtr finalName, WriteMode mode, Function tryCreate) const { + // Create a temporary file which will eventually replace `finalName`. + // + // Calls `tryCreate` to actually create the temporary, passing in the desired path. tryCreate() + // is expected to behave like a syscall, returning a negative value and setting `errno` on + // error. tryCreate() MUST fail with EEXIST if the path exists -- this is not checked in + // advance, since it needs to be checked atomically. In the case of EEXIST, tryCreate() will + // be called again with a new path. + // + // Returns the temporary path that succeeded. Only returns nullptr if there was an exception + // but we're compiled with -fno-exceptions. + + if (finalName.size() == 0) { + KJ_FAIL_REQUIRE("can't replace self") { break; } + return nullptr; + } + + static uint counter = 0; + static const pid_t pid = getpid(); + String pathPrefix; + if (finalName.size() > 1) { + pathPrefix = kj::str(finalName.parent(), '/'); + } + auto path = kj::str(pathPrefix, HIDDEN_PREFIX, pid, '.', counter++, '.', + finalName.basename()[0], ".partial"); + + KJ_SYSCALL_HANDLE_ERRORS(tryCreate(path)) { + case EEXIST: + return createNamedTemporary(finalName, mode, kj::mv(tryCreate)); + case ENOENT: + if (has(mode, WriteMode::CREATE_PARENT) && finalName.size() > 1 && + tryMkdir(finalName.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + mode = mode - WriteMode::CREATE_PARENT; + return createNamedTemporary(finalName, mode, kj::mv(tryCreate)); + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_SYSCALL("create(path)", error, path) { break; } + return nullptr; + } + + return kj::mv(path); + } + + bool tryReplaceNode(PathPtr path, WriteMode mode, Function tryCreate) const { + // Replaces the given path with an object created by calling tryCreate(). + // + // tryCreate() must behave like a syscall which creates the node at the path passed to it, + // returning a negative value on error. If the path passed to tryCreate already exists, it + // MUST fail with EEXIST. + // + // When `mode` includes MODIFY, replaceNode() reacts to EEXIST by creating the node in a + // temporary location and then rename()ing it into place. + + if (path.size() == 0) { + KJ_FAIL_REQUIRE("can't replace self") { return false; } + } + + auto filename = path.toString(); + + if (has(mode, WriteMode::CREATE)) { + // First try just cerating the node in-place. + KJ_SYSCALL_HANDLE_ERRORS(tryCreate(filename)) { + case EEXIST: + // Target exists. + if (has(mode, WriteMode::MODIFY)) { + // Fall back to MODIFY path, below. + break; + } else { + return false; + } + case ENOENT: + if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && + tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryReplaceNode(path, mode - WriteMode::CREATE_PARENT, kj::mv(tryCreate)); + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_SYSCALL("create(path)", error, path) { return false; } + } else { + // Success. + return true; + } + } + + // Either we don't have CREATE mode or the target already exists. We need to perform a + // replacement instead. + + KJ_IF_MAYBE(tempPath, createNamedTemporary(path, mode, kj::mv(tryCreate))) { + if (tryCommitReplacement(filename, fd, *tempPath, mode)) { + return true; + } else { + KJ_SYSCALL_HANDLE_ERRORS(unlinkat(fd, tempPath->cStr(), 0)) { + case ENOENT: + // meh + break; + default: + KJ_FAIL_SYSCALL("unlinkat(fd, tempPath, 0)", error, *tempPath); + } + return false; + } + } else { + // threw, but exceptions are disabled + return false; + } + } + + Maybe tryOpenFileInternal(PathPtr path, WriteMode mode, bool append) const { + uint flags = O_RDWR | MAYBE_O_CLOEXEC; + mode_t acl = 0666; + if (has(mode, WriteMode::CREATE)) { + flags |= O_CREAT; + } + if (!has(mode, WriteMode::MODIFY)) { + if (!has(mode, WriteMode::CREATE)) { + // Neither CREATE nor MODIFY -- impossible to satisfy preconditions. + return nullptr; + } + flags |= O_EXCL; + } + if (append) { + flags |= O_APPEND; + } + if (has(mode, WriteMode::EXECUTABLE)) { + acl = 0777; + } + if (has(mode, WriteMode::PRIVATE)) { + acl &= 0700; + } + + auto filename = path.toString(); + + int newFd; + KJ_SYSCALL_HANDLE_ERRORS(newFd = openat(fd, filename.cStr(), flags, acl)) { + case ENOENT: + if (has(mode, WriteMode::CREATE)) { + // Either: + // - The file is a broken symlink. + // - A parent directory didn't exist. + if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && + tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryOpenFileInternal(path, mode - WriteMode::CREATE_PARENT, append); + } + + // Check for broken link. + if (!has(mode, WriteMode::MODIFY) && + faccessat(fd, filename.cStr(), F_OK, AT_SYMLINK_NOFOLLOW) >= 0) { + // Yep. We treat this as already-exists, which means in CREATE-only mode this is a + // simple failure. + return nullptr; + } + + KJ_FAIL_REQUIRE("parent is not a directory", path) { return nullptr; } + } else { + // MODIFY-only mode. ENOENT = doesn't exist = return null. + return nullptr; + } + case ENOTDIR: + if (!has(mode, WriteMode::CREATE)) { + // MODIFY-only mode. ENOTDIR = parent not a directory = doesn't exist = return null. + return nullptr; + } + goto failed; + case EEXIST: + if (!has(mode, WriteMode::MODIFY)) { + // CREATE-only mode. EEXIST = already exists = return null. + return nullptr; + } + goto failed; + default: + failed: + KJ_FAIL_SYSCALL("openat(fd, path, O_RDWR | ...)", error, path) { return nullptr; } + } + + kj::AutoCloseFd result(newFd); +#ifndef O_CLOEXEC + setCloexec(result); +#endif + + return kj::mv(result); + } + + bool tryCommitReplacement(StringPtr toPath, int fromDirFd, StringPtr fromPath, WriteMode mode, + int* errorReason = nullptr) const { + if (has(mode, WriteMode::CREATE) && has(mode, WriteMode::MODIFY)) { + // Always clobber. Try it. + KJ_SYSCALL_HANDLE_ERRORS(renameat(fromDirFd, fromPath.cStr(), fd.get(), toPath.cStr())) { + case EISDIR: + case ENOTDIR: + case ENOTEMPTY: + case EEXIST: + // Failed because target exists and due to the various weird quirks of rename(), it + // can't remove it for us. On Linux we can try an exchange instead. On others we have + // to move the target out of the way. + break; + default: + if (errorReason == nullptr) { + KJ_FAIL_SYSCALL("rename(fromPath, toPath)", error, fromPath, toPath) { return false; } + } else { + *errorReason = error; + return false; + } + } else { + return true; + } + } + +#if __linux__ && defined(RENAME_EXCHANGE) && defined(SYS_renameat2) + // Try to use Linux's renameat2() to atomically check preconditions and apply. + + if (has(mode, WriteMode::MODIFY)) { + // Use an exchange to implement modification. + // + // We reach this branch when performing a MODIFY-only, or when performing a CREATE | MODIFY + // in which we determined above that there's a node of a different type blocking the + // exchange. + + KJ_SYSCALL_HANDLE_ERRORS(syscall(SYS_renameat2, + fromDirFd, fromPath.cStr(), fd.get(), toPath.cStr(), RENAME_EXCHANGE)) { + case ENOSYS: // Syscall not supported by kernel. + case EINVAL: // Maybe we screwed up, or maybe the syscall is not supported by the + // filesystem. Unfortunately, there's no way to tell, so assume the latter. + // ZFS in particular apparently produces EINVAL. + break; // fall back to traditional means + case ENOENT: + // Presumably because the target path doesn't exist. + if (has(mode, WriteMode::CREATE)) { + KJ_FAIL_ASSERT("rename(tmp, path) claimed path exists but " + "renameat2(fromPath, toPath, EXCHANGE) said it doest; concurrent modification?", + fromPath, toPath) { return false; } + } else { + // Assume target doesn't exist. + return false; + } + default: + if (errorReason == nullptr) { + KJ_FAIL_SYSCALL("renameat2(fromPath, toPath, EXCHANGE)", error, fromPath, toPath) { + return false; + } + } else { + *errorReason = error; + return false; + } + } else { + // Successful swap! Delete swapped-out content. + rmrf(fromDirFd, fromPath); + return true; + } + } else if (has(mode, WriteMode::CREATE)) { + KJ_SYSCALL_HANDLE_ERRORS(syscall(SYS_renameat2, + fromDirFd, fromPath.cStr(), fd.get(), toPath.cStr(), RENAME_NOREPLACE)) { + case ENOSYS: // Syscall not supported by kernel. + case EINVAL: // Maybe we screwed up, or maybe the syscall is not supported by the + // filesystem. Unfortunately, there's no way to tell, so assume the latter. + // ZFS in particular apparently produces EINVAL. + break; // fall back to traditional means + case EEXIST: + return false; + default: + if (errorReason == nullptr) { + KJ_FAIL_SYSCALL("renameat2(fromPath, toPath, NOREPLACE)", error, fromPath, toPath) { + return false; + } + } else { + *errorReason = error; + return false; + } + } else { + return true; + } + } +#endif + + // We're unable to do what we wanted atomically. :( + + if (has(mode, WriteMode::CREATE) && has(mode, WriteMode::MODIFY)) { + // We failed to atomically delete the target previously. So now we need to do two calls in + // rapid succession to move the old file away then move the new one into place. + + // Find out what kind of file exists at the target path. + struct stat stats; + KJ_SYSCALL(fstatat(fd, toPath.cStr(), &stats, AT_SYMLINK_NOFOLLOW)) { return false; } + + // Create a temporary location to move the existing object to. Note that rename() allows a + // non-directory to replace a non-directory, and allows a directory to replace an empty + // directory. So we have to create the right type. + Path toPathParsed = Path::parse(toPath); + String away; + KJ_IF_MAYBE(awayPath, createNamedTemporary(toPathParsed, WriteMode::CREATE, + [&](StringPtr candidatePath) { + if (S_ISDIR(stats.st_mode)) { + return mkdirat(fd, candidatePath.cStr(), 0700); + } else { +#if __APPLE__ || __FreeBSD__ + // - No mknodat() on OSX, gotta open() a file, ugh. + // - On a modern FreeBSD, mknodat() is reserved strictly for device nodes, + // you cannot create a regular file using it (EINVAL). + int newFd = openat(fd, candidatePath.cStr(), + O_RDWR | O_CREAT | O_EXCL | MAYBE_O_CLOEXEC, 0700); + if (newFd >= 0) close(newFd); + return newFd; +#else + return mknodat(fd, candidatePath.cStr(), S_IFREG | 0600, dev_t()); +#endif + } + })) { + away = kj::mv(*awayPath); + } else { + // Already threw. + return false; + } + + // OK, now move the target object to replace the thing we just created. + KJ_SYSCALL(renameat(fd, toPath.cStr(), fd, away.cStr())) { + // Something went wrong. Remove the thing we just created. + unlinkat(fd, away.cStr(), S_ISDIR(stats.st_mode) ? AT_REMOVEDIR : 0); + return false; + } + + // Now move the source object to the target location. + KJ_SYSCALL_HANDLE_ERRORS(renameat(fromDirFd, fromPath.cStr(), fd, toPath.cStr())) { + default: + // Try to put things back where they were. If this fails, though, then we have little + // choice but to leave things broken. + KJ_SYSCALL_HANDLE_ERRORS(renameat(fd, away.cStr(), fd, toPath.cStr())) { + default: break; + } + + if (errorReason == nullptr) { + KJ_FAIL_SYSCALL("rename(fromPath, toPath)", error, fromPath, toPath) { + return false; + } + } else { + *errorReason = error; + return false; + } + } + + // OK, success. Delete the old content. + rmrf(fd, away); + return true; + } else { + // Only one of CREATE or MODIFY is specified, so we need to verify non-atomically that the + // corresponding precondition (must-not-exist or must-exist, respectively) is held. + if (has(mode, WriteMode::CREATE)) { + struct stat stats; + KJ_SYSCALL_HANDLE_ERRORS(fstatat(fd.get(), toPath.cStr(), &stats, AT_SYMLINK_NOFOLLOW)) { + case ENOENT: + case ENOTDIR: + break; // doesn't exist; continue + default: + KJ_FAIL_SYSCALL("fstatat(fd, toPath)", error, toPath) { return false; } + } else { + return false; // already exists; fail + } + } else if (has(mode, WriteMode::MODIFY)) { + struct stat stats; + KJ_SYSCALL_HANDLE_ERRORS(fstatat(fd.get(), toPath.cStr(), &stats, AT_SYMLINK_NOFOLLOW)) { + case ENOENT: + case ENOTDIR: + return false; // doesn't exist; fail + default: + KJ_FAIL_SYSCALL("fstatat(fd, toPath)", error, toPath) { return false; } + } else { + // already exists; continue + } + } else { + // Neither CREATE nor MODIFY. + return false; + } + + // Start over in create-and-modify mode. + return tryCommitReplacement(toPath, fromDirFd, fromPath, + WriteMode::CREATE | WriteMode::MODIFY, + errorReason); + } + } + + template + class ReplacerImpl final: public Directory::Replacer { + public: + ReplacerImpl(Own&& object, const DiskHandle& handle, + String&& tempPath, String&& path, WriteMode mode) + : Directory::Replacer(mode), + object(kj::mv(object)), handle(handle), + tempPath(kj::mv(tempPath)), path(kj::mv(path)) {} + + ~ReplacerImpl() noexcept(false) { + if (!committed) { + rmrf(handle.fd, tempPath); + } + } + + const T& get() override { + return *object; + } + + bool tryCommit() override { + KJ_ASSERT(!committed, "already committed") { return false; } + return committed = handle.tryCommitReplacement(path, handle.fd, tempPath, + Directory::Replacer::mode); + } + + private: + Own object; + const DiskHandle& handle; + String tempPath; + String path; + bool committed = false; // true if *successfully* committed (in which case tempPath is gone) + }; + + template + class BrokenReplacer final: public Directory::Replacer { + // For recovery path when exceptions are disabled. + + public: + BrokenReplacer(Own inner) + : Directory::Replacer(WriteMode::CREATE | WriteMode::MODIFY), + inner(kj::mv(inner)) {} + + const T& get() override { return *inner; } + bool tryCommit() override { return false; } + + private: + Own inner; + }; + + Maybe> tryOpenFile(PathPtr path, WriteMode mode) const { + return tryOpenFileInternal(path, mode, false).map(newDiskFile); + } + + Own> replaceFile(PathPtr path, WriteMode mode) const { + mode_t acl = 0666; + if (has(mode, WriteMode::EXECUTABLE)) { + acl = 0777; + } + if (has(mode, WriteMode::PRIVATE)) { + acl &= 0700; + } + + int newFd_; + KJ_IF_MAYBE(temp, createNamedTemporary(path, mode, + [&](StringPtr candidatePath) { + return newFd_ = openat(fd, candidatePath.cStr(), + O_RDWR | O_CREAT | O_EXCL | MAYBE_O_CLOEXEC, acl); + })) { + AutoCloseFd newFd(newFd_); +#ifndef O_CLOEXEC + setCloexec(newFd); +#endif + return heap>(newDiskFile(kj::mv(newFd)), *this, kj::mv(*temp), + path.toString(), mode); + } else { + // threw, but exceptions are disabled + return heap>(newInMemoryFile(nullClock())); + } + } + + Own createTemporary() const { + int newFd_; + +#if __linux__ && defined(O_TMPFILE) + // Use syscall() to work around glibc bug with O_TMPFILE: + // https://sourceware.org/bugzilla/show_bug.cgi?id=17523 + KJ_SYSCALL_HANDLE_ERRORS(newFd_ = syscall( + SYS_openat, fd.get(), ".", O_RDWR | O_TMPFILE, 0700)) { + case EOPNOTSUPP: + case EINVAL: + case EISDIR: + // Maybe not supported by this kernel / filesystem. Fall back to below. + break; + default: + KJ_FAIL_SYSCALL("open(O_TMPFILE)", error) { break; } + break; + } else { + AutoCloseFd newFd(newFd_); +#ifndef O_CLOEXEC + setCloexec(newFd); +#endif + return newDiskFile(kj::mv(newFd)); + } +#endif + + KJ_IF_MAYBE(temp, createNamedTemporary(Path("unnamed"), WriteMode::CREATE, + [&](StringPtr path) { + return newFd_ = openat(fd, path.cStr(), O_RDWR | O_CREAT | O_EXCL | MAYBE_O_CLOEXEC, 0600); + })) { + AutoCloseFd newFd(newFd_); +#ifndef O_CLOEXEC + setCloexec(newFd); +#endif + auto result = newDiskFile(kj::mv(newFd)); + KJ_SYSCALL(unlinkat(fd, temp->cStr(), 0)) { break; } + return kj::mv(result); + } else { + // threw, but exceptions are disabled + return newInMemoryFile(nullClock()); + } + } + + Maybe> tryAppendFile(PathPtr path, WriteMode mode) const { + return tryOpenFileInternal(path, mode, true).map(newDiskAppendableFile); + } + + Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const { + // Must create before open. + if (has(mode, WriteMode::CREATE)) { + if (!tryMkdir(path, mode, false)) return nullptr; + } + + return tryOpenSubdirInternal(path).map(newDiskDirectory); + } + + Own> replaceSubdir(PathPtr path, WriteMode mode) const { + mode_t acl = has(mode, WriteMode::PRIVATE) ? 0700 : 0777; + + KJ_IF_MAYBE(temp, createNamedTemporary(path, mode, + [&](StringPtr candidatePath) { + return mkdirat(fd, candidatePath.cStr(), acl); + })) { + int subdirFd_; + KJ_SYSCALL_HANDLE_ERRORS(subdirFd_ = openat( + fd, temp->cStr(), O_RDONLY | MAYBE_O_CLOEXEC | MAYBE_O_DIRECTORY)) { + default: + KJ_FAIL_SYSCALL("open(just-created-temporary)", error); + return heap>(newInMemoryDirectory(nullClock())); + } + + AutoCloseFd subdirFd(subdirFd_); +#ifndef O_CLOEXEC + setCloexec(subdirFd); +#endif + return heap>( + newDiskDirectory(kj::mv(subdirFd)), *this, kj::mv(*temp), path.toString(), mode); + } else { + // threw, but exceptions are disabled + return heap>(newInMemoryDirectory(nullClock())); + } + } + + bool trySymlink(PathPtr linkpath, StringPtr content, WriteMode mode) const { + return tryReplaceNode(linkpath, mode, [&](StringPtr candidatePath) { + return symlinkat(content.cStr(), fd, candidatePath.cStr()); + }); + } + + bool tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode, const Directory& self) const { + KJ_REQUIRE(toPath.size() > 0, "can't replace self") { return false; } + + if (mode == TransferMode::LINK) { + KJ_IF_MAYBE(fromFd, fromDirectory.getFd()) { + // Other is a disk directory, so we can hopefully do an efficient move/link. + return tryReplaceNode(toPath, toMode, [&](StringPtr candidatePath) { + return linkat(*fromFd, fromPath.toString().cStr(), fd, candidatePath.cStr(), 0); + }); + }; + } else if (mode == TransferMode::MOVE) { + KJ_IF_MAYBE(fromFd, fromDirectory.getFd()) { + KJ_ASSERT(mode == TransferMode::MOVE); + + int error = 0; + if (tryCommitReplacement(toPath.toString(), *fromFd, fromPath.toString(), toMode, + &error)) { + return true; + } else switch (error) { + case 0: + // Plain old WriteMode precondition failure. + return false; + case EXDEV: + // Can't move between devices. Fall back to default implementation, which does + // copy/delete. + break; + case ENOENT: + // Either the destination directory doesn't exist or the source path doesn't exist. + // Unfortunately we don't really know. If CREATE_PARENT was provided, try creating + // the parent directory. Otherwise, we don't actually need to distinguish between + // these two errors; just return false. + if (has(toMode, WriteMode::CREATE) && has(toMode, WriteMode::CREATE_PARENT) && + toPath.size() > 0 && tryMkdir(toPath.parent(), + WriteMode::CREATE | WriteMode::MODIFY | WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryTransfer(toPath, toMode - WriteMode::CREATE_PARENT, + fromDirectory, fromPath, mode, self); + } + return false; + default: + KJ_FAIL_SYSCALL("rename(fromPath, toPath)", error, fromPath, toPath) { + return false; + } + } + } + } + + // OK, we can't do anything efficient using the OS. Fall back to default implementation. + return self.Directory::tryTransfer(toPath, toMode, fromDirectory, fromPath, mode); + } + + bool tryRemove(PathPtr path) const { + return rmrf(fd, path.toString()); + } + +protected: + AutoCloseFd fd; +}; + +#define FSNODE_METHODS(classname) \ + Maybe getFd() const override { return DiskHandle::getFd(); } \ + \ + Own cloneFsNode() const override { \ + return heap(DiskHandle::clone()); \ + } \ + \ + Metadata stat() const override { return DiskHandle::stat(); } \ + void sync() const override { DiskHandle::sync(); } \ + void datasync() const override { DiskHandle::datasync(); } + +class DiskReadableFile final: public ReadableFile, public DiskHandle { +public: + DiskReadableFile(AutoCloseFd&& fd): DiskHandle(kj::mv(fd)) {} + + FSNODE_METHODS(DiskReadableFile); + + size_t read(uint64_t offset, ArrayPtr buffer) const override { + return DiskHandle::read(offset, buffer); + } + Array mmap(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmap(offset, size); + } + Array mmapPrivate(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmapPrivate(offset, size); + } +}; + +class DiskAppendableFile final: public AppendableFile, public DiskHandle, public FdOutputStream { +public: + DiskAppendableFile(AutoCloseFd&& fd) + : DiskHandle(kj::mv(fd)), + FdOutputStream(DiskHandle::fd.get()) {} + + FSNODE_METHODS(DiskAppendableFile); + + void write(const void* buffer, size_t size) override { + FdOutputStream::write(buffer, size); + } + void write(ArrayPtr> pieces) override { + FdOutputStream::write(pieces); + } +}; + +class DiskFile final: public File, public DiskHandle { +public: + DiskFile(AutoCloseFd&& fd): DiskHandle(kj::mv(fd)) {} + + FSNODE_METHODS(DiskFile); + + size_t read(uint64_t offset, ArrayPtr buffer) const override { + return DiskHandle::read(offset, buffer); + } + Array mmap(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmap(offset, size); + } + Array mmapPrivate(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmapPrivate(offset, size); + } + + void write(uint64_t offset, ArrayPtr data) const override { + DiskHandle::write(offset, data); + } + void zero(uint64_t offset, uint64_t size) const override { + DiskHandle::zero(offset, size); + } + void truncate(uint64_t size) const override { + DiskHandle::truncate(size); + } + Own mmapWritable(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmapWritable(offset, size); + } + size_t copy(uint64_t offset, const ReadableFile& from, + uint64_t fromOffset, uint64_t size) const override { + KJ_IF_MAYBE(result, DiskHandle::copy(offset, from, fromOffset, size)) { + return *result; + } else { + return File::copy(offset, from, fromOffset, size); + } + } +}; + +class DiskReadableDirectory final: public ReadableDirectory, public DiskHandle { +public: + DiskReadableDirectory(AutoCloseFd&& fd): DiskHandle(kj::mv(fd)) {} + + FSNODE_METHODS(DiskReadableDirectory); + + Array listNames() const override { return DiskHandle::listNames(); } + Array listEntries() const override { return DiskHandle::listEntries(); } + bool exists(PathPtr path) const override { return DiskHandle::exists(path); } + Maybe tryLstat(PathPtr path) const override { + return DiskHandle::tryLstat(path); + } + Maybe> tryOpenFile(PathPtr path) const override { + return DiskHandle::tryOpenFile(path); + } + Maybe> tryOpenSubdir(PathPtr path) const override { + return DiskHandle::tryOpenSubdir(path); + } + Maybe tryReadlink(PathPtr path) const override { return DiskHandle::tryReadlink(path); } +}; + +class DiskDirectory final: public Directory, public DiskHandle { +public: + DiskDirectory(AutoCloseFd&& fd): DiskHandle(kj::mv(fd)) {} + + FSNODE_METHODS(DiskDirectory); + + Array listNames() const override { return DiskHandle::listNames(); } + Array listEntries() const override { return DiskHandle::listEntries(); } + bool exists(PathPtr path) const override { return DiskHandle::exists(path); } + Maybe tryLstat(PathPtr path) const override { + return DiskHandle::tryLstat(path); + } + Maybe> tryOpenFile(PathPtr path) const override { + return DiskHandle::tryOpenFile(path); + } + Maybe> tryOpenSubdir(PathPtr path) const override { + return DiskHandle::tryOpenSubdir(path); + } + Maybe tryReadlink(PathPtr path) const override { return DiskHandle::tryReadlink(path); } + + Maybe> tryOpenFile(PathPtr path, WriteMode mode) const override { + return DiskHandle::tryOpenFile(path, mode); + } + Own> replaceFile(PathPtr path, WriteMode mode) const override { + return DiskHandle::replaceFile(path, mode); + } + Own createTemporary() const override { + return DiskHandle::createTemporary(); + } + Maybe> tryAppendFile(PathPtr path, WriteMode mode) const override { + return DiskHandle::tryAppendFile(path, mode); + } + Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const override { + return DiskHandle::tryOpenSubdir(path, mode); + } + Own> replaceSubdir(PathPtr path, WriteMode mode) const override { + return DiskHandle::replaceSubdir(path, mode); + } + bool trySymlink(PathPtr linkpath, StringPtr content, WriteMode mode) const override { + return DiskHandle::trySymlink(linkpath, content, mode); + } + bool tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const override { + return DiskHandle::tryTransfer(toPath, toMode, fromDirectory, fromPath, mode, *this); + } + // tryTransferTo() not implemented because we have nothing special we can do. + bool tryRemove(PathPtr path) const override { + return DiskHandle::tryRemove(path); + } +}; + +class DiskFilesystem final: public Filesystem { +public: + DiskFilesystem() + : root(openDir("/")), + current(openDir(".")), + currentPath(computeCurrentPath()) { + // We sometimes like to use qemu-user to test arm64 binaries cross-compiled from an x64 host + // machine. But, because it intercepts and rewrites system calls from userspace rather than + // emulating a whole kernel, it has a lot of quirks. One quirk that hits kj::Filesystem pretty + // badly is that open("/") actually returns a file descriptor for "/usr/aarch64-linux-gnu". + // Attempts to openat() any files within there then don't work. We can detect this problem and + // correct for it here. + struct stat realRoot, fsRoot; + KJ_SYSCALL_HANDLE_ERRORS(stat("/dev/..", &realRoot)) { + default: + // stat("/dev/..") failed? Give up. + return; + } + KJ_SYSCALL(fstat(root.DiskHandle::getFd(), &fsRoot)); + if (realRoot.st_ino != fsRoot.st_ino) { + KJ_LOG(WARNING, "root dir file descriptor is broken, probably because of qemu; compensating"); + root.setFd(openDir("/dev/..")); + } + } + + const Directory& getRoot() const override { + return root; + } + + const Directory& getCurrent() const override { + return current; + } + + PathPtr getCurrentPath() const override { + return currentPath; + } + +private: + DiskDirectory root; + DiskDirectory current; + Path currentPath; + + static AutoCloseFd openDir(const char* dir) { + int newFd; + KJ_SYSCALL(newFd = open(dir, O_RDONLY | MAYBE_O_CLOEXEC | MAYBE_O_DIRECTORY)); + AutoCloseFd result(newFd); +#ifndef O_CLOEXEC + setCloexec(result); +#endif + return result; + } + + static Path computeCurrentPath() { + // If env var PWD is set and points to the current directory, use it. This captures the current + // path according to the user's shell, which may differ from the kernel's idea in the presence + // of symlinks. + const char* pwd = getenv("PWD"); + if (pwd != nullptr) { + Path result = nullptr; + struct stat pwdStat, dotStat; + KJ_IF_MAYBE(e, kj::runCatchingExceptions([&]() { + KJ_ASSERT(pwd[0] == '/') { return; } + result = Path::parse(pwd + 1); + KJ_SYSCALL(lstat(result.toString(true).cStr(), &pwdStat), result) { return; } + KJ_SYSCALL(lstat(".", &dotStat)) { return; } + })) { + // failed, give up on PWD + KJ_LOG(WARNING, "PWD environment variable seems invalid", pwd, *e); + } else { + if (pwdStat.st_ino == dotStat.st_ino && + pwdStat.st_dev == dotStat.st_dev) { + return kj::mv(result); + } else { + KJ_LOG(WARNING, "PWD environment variable doesn't match current directory", pwd); + } + } + } + + size_t size = 256; + retry: + KJ_STACK_ARRAY(char, buf, size, 256, 4096); + if (getcwd(buf.begin(), size) == nullptr) { + int error = errno; + if (error == ERANGE) { + size *= 2; + goto retry; + } else { + KJ_FAIL_SYSCALL("getcwd()", error); + } + } + + StringPtr path = buf.begin(); + + // On Linux, the path will start with "(unreachable)" if the working directory is not a subdir + // of the root directory, which is possible via chroot() or mount namespaces. + KJ_ASSERT(!path.startsWith("(unreachable)"), + "working directory is not reachable from root", path); + KJ_ASSERT(path.startsWith("/"), "current directory is not absolute", path); + + return Path::parse(path.slice(1)); + } +}; + +} // namespace + +Own newDiskReadableFile(kj::AutoCloseFd fd) { + return heap(kj::mv(fd)); +} +Own newDiskAppendableFile(kj::AutoCloseFd fd) { + return heap(kj::mv(fd)); +} +Own newDiskFile(kj::AutoCloseFd fd) { + return heap(kj::mv(fd)); +} +Own newDiskReadableDirectory(kj::AutoCloseFd fd) { + return heap(kj::mv(fd)); +} +Own newDiskDirectory(kj::AutoCloseFd fd) { + return heap(kj::mv(fd)); +} + +Own newDiskFilesystem() { + return heap(); +} + +} // namespace kj + +#endif // !_WIN32 diff --git a/c++/src/kj/filesystem-disk-win32.c++ b/c++/src/kj/filesystem-disk-win32.c++ new file mode 100644 index 0000000000..be761894f6 --- /dev/null +++ b/c++/src/kj/filesystem-disk-win32.c++ @@ -0,0 +1,1617 @@ +// Copyright (c) 2015 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if _WIN32 +// For Unix implementation, see filesystem-disk-unix.c++. + +// Request Vista-level APIs. +#include "win32-api-version.h" + +#include "filesystem.h" +#include "debug.h" +#include "encoding.h" +#include "vector.h" +#include +#include + +#include +#include +#include "windows-sanity.h" + +namespace kj { + +static Own newDiskReadableDirectory(AutoCloseHandle fd, Path&& path); +static Own newDiskDirectory(AutoCloseHandle fd, Path&& path); + +static AutoCloseHandle* getHandlePointerHack(File& file) { return nullptr; } +static AutoCloseHandle* getHandlePointerHack(Directory& dir); +static Path* getPathPointerHack(File& file) { return nullptr; } +static Path* getPathPointerHack(Directory& dir); + +namespace { + +struct REPARSE_DATA_BUFFER { + // From ntifs.h, which is part of the driver development kit so not necessarily available I + // guess. + ULONG ReparseTag; + USHORT ReparseDataLength; + USHORT Reserved; + union { + struct { + USHORT SubstituteNameOffset; + USHORT SubstituteNameLength; + USHORT PrintNameOffset; + USHORT PrintNameLength; + ULONG Flags; + WCHAR PathBuffer[1]; + } SymbolicLinkReparseBuffer; + struct { + USHORT SubstituteNameOffset; + USHORT SubstituteNameLength; + USHORT PrintNameOffset; + USHORT PrintNameLength; + WCHAR PathBuffer[1]; + } MountPointReparseBuffer; + struct { + UCHAR DataBuffer[1]; + } GenericReparseBuffer; + }; +}; + +#define HIDDEN_PREFIX ".kj-tmp." +// Prefix for temp files which should be hidden when listing a directory. +// +// If you change this, make sure to update the unit test. + +static constexpr int64_t WIN32_EPOCH_OFFSET = 116444736000000000ull; +// Number of 100ns intervals from Jan 1, 1601 to Jan 1, 1970. + +static Date toKjDate(FILETIME t) { + int64_t value = (static_cast(t.dwHighDateTime) << 32) | t.dwLowDateTime; + return (value - WIN32_EPOCH_OFFSET) * (100 * kj::NANOSECONDS) + UNIX_EPOCH; +} + +static FsNode::Type modeToType(DWORD attrs, DWORD reparseTag) { + if ((attrs & FILE_ATTRIBUTE_REPARSE_POINT) && + reparseTag == IO_REPARSE_TAG_SYMLINK) { + return FsNode::Type::SYMLINK; + } + if (attrs & FILE_ATTRIBUTE_DIRECTORY) return FsNode::Type::DIRECTORY; + return FsNode::Type::FILE; +} + +static FsNode::Metadata statToMetadata(const BY_HANDLE_FILE_INFORMATION& stats) { + uint64_t size = (implicitCast(stats.nFileSizeHigh) << 32) | stats.nFileSizeLow; + + // Assume file index is usually a small number, i.e. nFileIndexHigh is usually 0. So we try to + // put the serial number in the upper 32 bits and the index in the lower. + uint64_t hash = ((uint64_t(stats.dwVolumeSerialNumber) << 32) + ^ (uint64_t(stats.nFileIndexHigh) << 32)) + | (uint64_t(stats.nFileIndexLow)); + + return FsNode::Metadata { + modeToType(stats.dwFileAttributes, 0), + size, + // In theory, spaceUsed should be based on GetCompressedFileSize(), but requiring an extra + // syscall for something rarely used would be sad. + size, + toKjDate(stats.ftLastWriteTime), + stats.nNumberOfLinks, + hash + }; +} + +static FsNode::Metadata statToMetadata(const WIN32_FIND_DATAW& stats) { + uint64_t size = (implicitCast(stats.nFileSizeHigh) << 32) | stats.nFileSizeLow; + + return FsNode::Metadata { + modeToType(stats.dwFileAttributes, stats.dwReserved0), + size, + // In theory, spaceUsed should be based on GetCompressedFileSize(), but requiring an extra + // syscall for something rarely used would be sad. + size, + toKjDate(stats.ftLastWriteTime), + // We can't get the number of links without opening the file, apparently. Meh. + 1, + // We can't produce a reliable hashCode without opening the file. + 0 + }; +} + +static Array join16(ArrayPtr path, const wchar_t* file) { + // Assumes `path` ends with a NUL terminator (and `file` is of course NUL terminated as well). + + size_t len = wcslen(file) + 1; + auto result = kj::heapArray(path.size() + len); + memcpy(result.begin(), path.begin(), path.asBytes().size() - sizeof(wchar_t)); + result[path.size() - 1] = '\\'; + memcpy(result.begin() + path.size(), file, len * sizeof(wchar_t)); + return result; +} + +static String dbgStr(ArrayPtr wstr) { + if (wstr.size() > 0 && wstr[wstr.size() - 1] == L'\0') { + wstr = wstr.slice(0, wstr.size() - 1); + } + return decodeWideString(wstr); +} + +static void rmrfChildren(ArrayPtr path) { + auto glob = join16(path, L"*"); + + WIN32_FIND_DATAW data; + // TODO(security): If `path` is a reparse point (symlink), this will follow it and delete the + // contents. We check for reparse points before recursing, but there is still a TOCTOU race + // condition. + // + // Apparently, there is a whole different directory-listing API we could be using here: + // `GetFileInformationByHandleEx()`, with the `FileIdBothDirectoryInfo` flag. This lets us + // list the contents of a directory from its already-open handle -- it's probably how we should + // do directory listing in general! If we open a file with FILE_FLAG_OPEN_REPARSE_POINT, then + // the handle will represent the reparse point itself, and attempting to list it will produce + // no entries. I had no idea this API existed when I wrote much of this code; I wish I had + // because it seems much cleaner than the ancient FindFirstFile/FindNextFile API! + HANDLE handle = FindFirstFileW(glob.begin(), &data); + if (handle == INVALID_HANDLE_VALUE) { + auto error = GetLastError(); + if (error == ERROR_FILE_NOT_FOUND) return; + KJ_FAIL_WIN32("FindFirstFile", error, dbgStr(glob)) { return; } + } + KJ_DEFER(KJ_WIN32(FindClose(handle)) { break; }); + + do { + // Ignore "." and "..", ugh. + if (data.cFileName[0] == L'.') { + if (data.cFileName[1] == L'\0' || + (data.cFileName[1] == L'.' && data.cFileName[2] == L'\0')) { + continue; + } + } + + auto child = join16(path, data.cFileName); + // For rmrf purposes, we assume any "reparse points" are symlink-like, even if they aren't + // actually the "symbolic link" reparse type, because we don't want to recursively delete any + // shared content. + if ((data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && + !(data.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT)) { + rmrfChildren(child); + uint retryCount = 0; + retry: + KJ_WIN32_HANDLE_ERRORS(RemoveDirectoryW(child.begin())) { + case ERROR_DIR_NOT_EMPTY: + // On Windows, deleting a file actually only schedules it for deletion. Under heavy + // load it may take a bit for the deletion to go through. Or, if another process has + // the file open, it may not be deleted until that process closes it. + // + // We'll repeatedly retry for up to 100ms, then give up. This is awful but there's no + // way to tell for sure if the system is just being slow or if someone has the file + // open. + if (retryCount++ < 10) { + Sleep(10); + goto retry; + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_WIN32("RemoveDirectory", error, dbgStr(child)) { break; } + } + } else { + KJ_WIN32(DeleteFileW(child.begin())); + } + } while (FindNextFileW(handle, &data)); + + auto error = GetLastError(); + if (error != ERROR_NO_MORE_FILES) { + KJ_FAIL_WIN32("FindNextFile", error, dbgStr(path)) { return; } + } +} + +static bool rmrf(ArrayPtr path) { + // Figure out whether this is a file or a directory. + // + // We use FindFirstFileW() because in the case of symlinks it will return info about the + // symlink rather than info about the target. + WIN32_FIND_DATAW data; + HANDLE handle = FindFirstFileW(path.begin(), &data); + if (handle == INVALID_HANDLE_VALUE) { + auto error = GetLastError(); + if (error == ERROR_FILE_NOT_FOUND) return false; + KJ_FAIL_WIN32("FindFirstFile", error, dbgStr(path)); + } + KJ_WIN32(FindClose(handle)); + + // For remove purposes, we assume any "reparse points" are symlink-like, even if they aren't + // actually the "symbolic link" reparse type, because we don't want to recursively delete any + // shared content. + if ((data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && + !(data.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT)) { + // directory + rmrfChildren(path); + KJ_WIN32(RemoveDirectoryW(path.begin()), dbgStr(path)); + } else { + KJ_WIN32(DeleteFileW(path.begin()), dbgStr(path)); + } + + return true; +} + +static Path getPathFromHandle(HANDLE handle) { + DWORD tryLen = MAX_PATH; + for (;;) { + auto temp = kj::heapArray(tryLen + 1); + DWORD len = GetFinalPathNameByHandleW(handle, temp.begin(), tryLen, 0); + if (len == 0) { + KJ_FAIL_WIN32("GetFinalPathNameByHandleW", GetLastError()); + } + if (len < temp.size()) { + return Path::parseWin32Api(temp.slice(0, len)); + } + // Try again with new length. + tryLen = len; + } +} + +struct MmapRange { + uint64_t offset; + uint64_t size; +}; + +static size_t getAllocationGranularity() { + SYSTEM_INFO info; + GetSystemInfo(&info); + return info.dwAllocationGranularity; +}; + +static MmapRange getMmapRange(uint64_t offset, uint64_t size) { + // Rounds the given offset down to the nearest page boundary, and adjusts the size up to match. + // (This is somewhat different from Unix: we do NOT round the size up to an even multiple of + // pages.) + static const uint64_t pageSize = getAllocationGranularity(); + uint64_t pageMask = pageSize - 1; + + uint64_t realOffset = offset & ~pageMask; + + uint64_t end = offset + size; + + return { realOffset, end - realOffset }; +} + +class MmapDisposer: public ArrayDisposer { +protected: + void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, + size_t capacity, void (*destroyElement)(void*)) const { + auto range = getMmapRange(reinterpret_cast(firstElement), + elementSize * elementCount); + void* mapping = reinterpret_cast(range.offset); + if (mapping != nullptr) { + KJ_ASSERT(UnmapViewOfFile(mapping)) { break; } + } + } +}; + +#if _MSC_VER && _MSC_VER < 1910 && !defined(__clang__) +// TODO(msvc): MSVC 2015 can't initialize a constexpr's vtable correctly. +const MmapDisposer mmapDisposer = MmapDisposer(); +#else +constexpr MmapDisposer mmapDisposer = MmapDisposer(); +#endif + +void* win32Mmap(HANDLE handle, MmapRange range, DWORD pageProtect, DWORD access) { + HANDLE mappingHandle; + KJ_WIN32(mappingHandle = CreateFileMappingW(handle, NULL, pageProtect, 0, 0, NULL)); + KJ_DEFER(KJ_WIN32(CloseHandle(mappingHandle)) { break; }); + + void* mapping = MapViewOfFile(mappingHandle, access, + static_cast(range.offset >> 32), static_cast(range.offset), range.size); + if (mapping == nullptr) { + KJ_FAIL_WIN32("MapViewOfFile", GetLastError()); + } + + // It's unclear from the documentation whether mappings will always start at a multiple of the + // allocation granularity, but we depend on that later, so check it... + KJ_ASSERT(getMmapRange(reinterpret_cast(mapping), 0).size == 0); + + return mapping; +} + +class DiskHandle { + // We need to implement each of ReadableFile, AppendableFile, File, ReadableDirectory, and + // Directory for disk handles. There is a lot of implementation overlap between these, especially + // stat(), sync(), etc. We can't have everything inherit from a common DiskFsNode that implements + // these because then we get diamond inheritance which means we need to make all our inheritance + // virtual which means downcasting requires RTTI which violates our goal of supporting compiling + // with no RTTI. So instead we have the DiskHandle class which implements all the methods without + // inheriting anything, and then we have DiskFile, DiskDirectory, etc. hold this and delegate to + // it. Ugly, but works. + +public: + DiskHandle(AutoCloseHandle&& handle, Maybe dirPath) + : handle(kj::mv(handle)), dirPath(kj::mv(dirPath)) {} + + AutoCloseHandle handle; + kj::Maybe dirPath; // needed for directories, empty for files + + Array nativePath(PathPtr path) const { + return KJ_ASSERT_NONNULL(dirPath).append(path).forWin32Api(true); + } + + // OsHandle ------------------------------------------------------------------ + + AutoCloseHandle clone() const { + HANDLE newHandle; + KJ_WIN32(DuplicateHandle(GetCurrentProcess(), handle, GetCurrentProcess(), &newHandle, + 0, FALSE, DUPLICATE_SAME_ACCESS)); + return AutoCloseHandle(newHandle); + } + + HANDLE getWin32Handle() const { + return handle.get(); + } + + // FsNode -------------------------------------------------------------------- + + FsNode::Metadata stat() const { + BY_HANDLE_FILE_INFORMATION stats; + KJ_WIN32(GetFileInformationByHandle(handle, &stats)); + auto metadata = statToMetadata(stats); + + // Get space usage, e.g. for sparse files. Apparently the correct way to do this is to query + // "compression". + FILE_COMPRESSION_INFO compInfo; + KJ_WIN32_HANDLE_ERRORS(GetFileInformationByHandleEx( + handle, FileCompressionInfo, &compInfo, sizeof(compInfo))) { + case ERROR_CALL_NOT_IMPLEMENTED: + // Probably WINE. + break; + default: + KJ_FAIL_WIN32("GetFileInformationByHandleEx(FileCompressionInfo)", error) { break; } + break; + } else { + metadata.spaceUsed = compInfo.CompressedFileSize.QuadPart; + } + + return metadata; + } + + void sync() const { KJ_WIN32(FlushFileBuffers(handle)); } + void datasync() const { KJ_WIN32(FlushFileBuffers(handle)); } + + // ReadableFile -------------------------------------------------------------- + + size_t read(uint64_t offset, ArrayPtr buffer) const { + // ReadFile() probably never returns short reads unless it hits EOF. Unfortunately, though, + // this is not documented, and it's unclear whether we can rely on it. + + size_t total = 0; + while (buffer.size() > 0) { + // Apparently, the way to fake pread() on Windows is to provide an OVERLAPPED structure even + // though we're not doing overlapped I/O. + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(overlapped)); + overlapped.Offset = static_cast(offset); + overlapped.OffsetHigh = static_cast(offset >> 32); + + DWORD n; + KJ_WIN32_HANDLE_ERRORS(ReadFile(handle, buffer.begin(), buffer.size(), &n, &overlapped)) { + case ERROR_HANDLE_EOF: + // The documentation claims this shouldn't happen for synchronous reads, but it seems + // to happen for me, at least under WINE. + n = 0; + break; + default: + KJ_FAIL_WIN32("ReadFile", offset, buffer.size()) { return total; } + } + if (n == 0) break; + total += n; + offset += n; + buffer = buffer.slice(n, buffer.size()); + } + return total; + } + + Array mmap(uint64_t offset, uint64_t size) const { + if (size == 0) return nullptr; // Windows won't allow zero-length mappings + auto range = getMmapRange(offset, size); + const void* mapping = win32Mmap(handle, range, PAGE_READONLY, FILE_MAP_READ); + return Array(reinterpret_cast(mapping) + (offset - range.offset), + size, mmapDisposer); + } + + Array mmapPrivate(uint64_t offset, uint64_t size) const { + if (size == 0) return nullptr; // Windows won't allow zero-length mappings + auto range = getMmapRange(offset, size); + void* mapping = win32Mmap(handle, range, PAGE_READONLY, FILE_MAP_COPY); + return Array(reinterpret_cast(mapping) + (offset - range.offset), + size, mmapDisposer); + } + + // File ---------------------------------------------------------------------- + + void write(uint64_t offset, ArrayPtr data) const { + // WriteFile() probably never returns short writes unless there's no space left on disk. + // Unfortunately, though, this is not documented, and it's unclear whether we can rely on it. + + while (data.size() > 0) { + // Apparently, the way to fake pwrite() on Windows is to provide an OVERLAPPED structure even + // though we're not doing overlapped I/O. + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(overlapped)); + overlapped.Offset = static_cast(offset); + overlapped.OffsetHigh = static_cast(offset >> 32); + + DWORD n; + KJ_WIN32(WriteFile(handle, data.begin(), data.size(), &n, &overlapped)); + KJ_ASSERT(n > 0, "WriteFile() returned zero?"); + offset += n; + data = data.slice(n, data.size()); + } + } + + void zero(uint64_t offset, uint64_t size) const { + FILE_ZERO_DATA_INFORMATION info; + memset(&info, 0, sizeof(info)); + info.FileOffset.QuadPart = offset; + info.BeyondFinalZero.QuadPart = offset + size; + + DWORD dummy; + KJ_WIN32_HANDLE_ERRORS(DeviceIoControl(handle, FSCTL_SET_ZERO_DATA, &info, + sizeof(info), NULL, 0, &dummy, NULL)) { + case ERROR_NOT_SUPPORTED: { + // Dang. Let's do it the hard way. + static const byte ZEROS[4096] = { 0 }; + + while (size > sizeof(ZEROS)) { + write(offset, ZEROS); + size -= sizeof(ZEROS); + offset += sizeof(ZEROS); + } + write(offset, kj::arrayPtr(ZEROS, size)); + break; + } + + default: + KJ_FAIL_WIN32("DeviceIoControl(FSCTL_SET_ZERO_DATA)", error); + break; + } + } + + void truncate(uint64_t size) const { + // SetEndOfFile() would require seeking the file. It looks like SetFileInformationByHandle() + // lets us avoid this! + FILE_END_OF_FILE_INFO info; + memset(&info, 0, sizeof(info)); + info.EndOfFile.QuadPart = size; + KJ_WIN32_HANDLE_ERRORS( + SetFileInformationByHandle(handle, FileEndOfFileInfo, &info, sizeof(info))) { + case ERROR_CALL_NOT_IMPLEMENTED: { + // Wine doesn't implement this. :( + + LONG currentHigh = 0; + LONG currentLow = SetFilePointer(handle, 0, ¤tHigh, FILE_CURRENT); + if (currentLow == INVALID_SET_FILE_POINTER) { + KJ_FAIL_WIN32("SetFilePointer", GetLastError()); + } + uint64_t current = (uint64_t(currentHigh) << 32) | uint64_t((ULONG)currentLow); + + LONG endLow = size & 0x00000000ffffffffull; + LONG endHigh = size >> 32; + if (SetFilePointer(handle, endLow, &endHigh, FILE_BEGIN) == INVALID_SET_FILE_POINTER) { + KJ_FAIL_WIN32("SetFilePointer", GetLastError()); + } + + KJ_WIN32(SetEndOfFile(handle)); + + if (current < size) { + if (SetFilePointer(handle, currentLow, ¤tHigh, FILE_BEGIN) == + INVALID_SET_FILE_POINTER) { + KJ_FAIL_WIN32("SetFilePointer", GetLastError()); + } + } + + break; + } + default: + KJ_FAIL_WIN32("SetFileInformationByHandle", error); + } + } + + class WritableFileMappingImpl final: public WritableFileMapping { + public: + WritableFileMappingImpl(Array bytes): bytes(kj::mv(bytes)) {} + + ArrayPtr get() const override { + // const_cast OK because WritableFileMapping does indeed provide a writable view despite + // being const itself. + return arrayPtr(const_cast(bytes.begin()), bytes.size()); + } + + void changed(ArrayPtr slice) const override { + KJ_REQUIRE(slice.begin() >= bytes.begin() && slice.end() <= bytes.end(), + "byte range is not part of this mapping"); + + // Nothing needed here -- NT tracks dirty pages. + } + + void sync(ArrayPtr slice) const override { + KJ_REQUIRE(slice.begin() >= bytes.begin() && slice.end() <= bytes.end(), + "byte range is not part of this mapping"); + + // Zero is treated specially by FlushViewOfFile(), so check for it. (This also handles the + // case where `bytes` is actually empty and not a real mapping.) + if (slice.size() > 0) { + KJ_WIN32(FlushViewOfFile(slice.begin(), slice.size())); + } + } + + private: + Array bytes; + }; + + Own mmapWritable(uint64_t offset, uint64_t size) const { + if (size == 0) { + // Windows won't allow zero-length mappings + return heap(nullptr); + } + auto range = getMmapRange(offset, size); + void* mapping = win32Mmap(handle, range, PAGE_READWRITE, FILE_MAP_ALL_ACCESS); + auto array = Array(reinterpret_cast(mapping) + (offset - range.offset), + size, mmapDisposer); + return heap(kj::mv(array)); + } + + // copy() is not optimized on Windows. + + // ReadableDirectory --------------------------------------------------------- + + template + auto list(bool needTypes, Func&& func) const + -> Array(), instance()))>> { + PathPtr path = KJ_ASSERT_NONNULL(dirPath); + auto glob = join16(path.forWin32Api(true), L"*"); + + // TODO(someday): Use GetFileInformationByHandleEx() with FileIdBothDirectoryInfo to enumerate + // directories instead. It's much cleaner. + WIN32_FIND_DATAW data; + HANDLE handle = FindFirstFileW(glob.begin(), &data); + if (handle == INVALID_HANDLE_VALUE) { + auto error = GetLastError(); + if (error == ERROR_FILE_NOT_FOUND) return nullptr; + KJ_FAIL_WIN32("FindFirstFile", error, dbgStr(glob)); + } + KJ_DEFER(KJ_WIN32(FindClose(handle)) { break; }); + + typedef Decay(), instance()))> Entry; + kj::Vector entries; + + do { + auto name = decodeUtf16( + arrayPtr(reinterpret_cast(data.cFileName), wcslen(data.cFileName))); + if (name != "." && name != ".." && !name.startsWith(HIDDEN_PREFIX)) { + entries.add(func(name, modeToType(data.dwFileAttributes, data.dwReserved0))); + } + } while (FindNextFileW(handle, &data)); + + auto error = GetLastError(); + if (error != ERROR_NO_MORE_FILES) { + KJ_FAIL_WIN32("FindNextFile", error, path); + } + + auto result = entries.releaseAsArray(); + std::sort(result.begin(), result.end()); + return result; + } + + Array listNames() const { + return list(false, [](StringPtr name, FsNode::Type type) { return heapString(name); }); + } + + Array listEntries() const { + return list(true, [](StringPtr name, FsNode::Type type) { + return ReadableDirectory::Entry { type, heapString(name), }; + }); + } + + bool exists(PathPtr path) const { + DWORD result = GetFileAttributesW(nativePath(path).begin()); + if (result == INVALID_FILE_ATTRIBUTES) { + auto error = GetLastError(); + switch (error) { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return false; + default: + KJ_FAIL_WIN32("GetFileAttributesEx(path)", error, path) { return false; } + } + } else { + return true; + } + } + + Maybe tryLstat(PathPtr path) const { + // We use FindFirstFileW() because in the case of symlinks it will return info about the + // symlink rather than info about the target. + WIN32_FIND_DATAW data; + HANDLE handle = FindFirstFileW(nativePath(path).begin(), &data); + if (handle == INVALID_HANDLE_VALUE) { + auto error = GetLastError(); + if (error == ERROR_FILE_NOT_FOUND) return nullptr; + KJ_FAIL_WIN32("FindFirstFile", error, path); + } else { + KJ_WIN32(FindClose(handle)); + return statToMetadata(data); + } + } + + Maybe> tryOpenFile(PathPtr path) const { + HANDLE newHandle; + KJ_WIN32_HANDLE_ERRORS(newHandle = CreateFileW( + nativePath(path).begin(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + NULL, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + NULL)) { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return nullptr; + default: + KJ_FAIL_WIN32("CreateFile(path, OPEN_EXISTING)", error, path) { return nullptr; } + } + + return newDiskReadableFile(kj::AutoCloseHandle(newHandle)); + } + + Maybe tryOpenSubdirInternal(PathPtr path) const { + HANDLE newHandle; + KJ_WIN32_HANDLE_ERRORS(newHandle = CreateFileW( + nativePath(path).begin(), + GENERIC_READ, + // When opening directories, we do NOT use FILE_SHARE_DELETE, because we need the directory + // path to remain valid. + // + // TODO(someday): Use NtCreateFile() and related "internal" APIs that allow for + // openat()-like behavior? + FILE_SHARE_READ | FILE_SHARE_WRITE, + NULL, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, // apparently, this flag is required for directories + NULL)) { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return nullptr; + default: + KJ_FAIL_WIN32("CreateFile(directoryPath, OPEN_EXISTING)", error, path) { return nullptr; } + } + + kj::AutoCloseHandle ownHandle(newHandle); + + BY_HANDLE_FILE_INFORMATION info; + KJ_WIN32(GetFileInformationByHandle(ownHandle, &info)); + + KJ_REQUIRE(info.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY, "not a directory", path); + return kj::mv(ownHandle); + } + + Maybe> tryOpenSubdir(PathPtr path) const { + return tryOpenSubdirInternal(path).map([&](AutoCloseHandle&& handle) { + return newDiskReadableDirectory(kj::mv(handle), KJ_ASSERT_NONNULL(dirPath).append(path)); + }); + } + + Maybe tryReadlink(PathPtr path) const { + // Windows symlinks work differently from Unix. Generally they are set up by the system + // administrator and apps are expected to treat them transparently. Hence, on Windows, we act + // as if nothing is a symlink by always returning null here. + // TODO(someday): If we want to treat Windows symlinks more like Unix ones, start by reverting + // the comment that added this comment. + return nullptr; + } + + // Directory ----------------------------------------------------------------- + + static LPSECURITY_ATTRIBUTES makeSecAttr(WriteMode mode) { + if (has(mode, WriteMode::PRIVATE)) { + KJ_UNIMPLEMENTED("WriteMode::PRIVATE on Win32 is not implemented"); + } + + return nullptr; + } + + bool tryMkdir(PathPtr path, WriteMode mode, bool noThrow) const { + // Internal function to make a directory. + + auto filename = nativePath(path); + + KJ_WIN32_HANDLE_ERRORS(CreateDirectoryW(filename.begin(), makeSecAttr(mode))) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: { + // Apparently this path exists. + if (!has(mode, WriteMode::MODIFY)) { + // Require exclusive create. + return false; + } + + // MODIFY is allowed, so we just need to check whether the existing entry is a directory. + DWORD attr = GetFileAttributesW(filename.begin()); + if (attr == INVALID_FILE_ATTRIBUTES) { + // CreateDirectory() says it already exists but we can't get attributes. Maybe it's a + // dangling link, or maybe we can't access it for some reason. Assume failure. + // + // TODO(someday): Maybe we should be creating the directory at the target of the + // link? + goto failed; + } + return attr & FILE_ATTRIBUTE_DIRECTORY; + } + case ERROR_PATH_NOT_FOUND: + if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && + tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryMkdir(path, mode - WriteMode::CREATE_PARENT, noThrow); + } else { + goto failed; + } + default: + failed: + if (noThrow) { + // Caller requested no throwing. + return false; + } else { + KJ_FAIL_WIN32("CreateDirectory", error, path); + } + } + + return true; + } + + kj::Maybe> createNamedTemporary( + PathPtr finalName, WriteMode mode, Path& kjTempPath, + Function tryCreate) const { + // Create a temporary file which will eventually replace `finalName`. + // + // Calls `tryCreate` to actually create the temporary, passing in the desired path. tryCreate() + // is expected to behave like a win32 call, returning a BOOL and setting `GetLastError()` on + // error. tryCreate() MUST fail with ERROR_{FILE,ALREADY}_EXISTS if the path exists -- this is + // not checked in advance, since it needs to be checked atomically. In the case of + // ERROR_*_EXISTS, tryCreate() will be called again with a new path. + // + // Returns the temporary path that succeeded. Only returns nullptr if there was an exception + // but we're compiled with -fno-exceptions. + // + // The optional parameter `kjTempPath` is filled in with the KJ Path of the temporary. + + if (finalName.size() == 0) { + KJ_FAIL_REQUIRE("can't replace self") { break; } + return nullptr; + } + + static uint counter = 0; + static const DWORD pid = GetCurrentProcessId(); + auto tempName = kj::str(HIDDEN_PREFIX, pid, '.', counter++, '.', + finalName.basename()[0], ".partial"); + kjTempPath = finalName.parent().append(tempName); + auto path = nativePath(kjTempPath); + + KJ_WIN32_HANDLE_ERRORS(tryCreate(path.begin())) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + // Try again with a new counter value. + return createNamedTemporary(finalName, mode, kj::mv(tryCreate)); + case ERROR_PATH_NOT_FOUND: + if (has(mode, WriteMode::CREATE_PARENT) && finalName.size() > 1 && + tryMkdir(finalName.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + mode = mode - WriteMode::CREATE_PARENT; + return createNamedTemporary(finalName, mode, kj::mv(tryCreate)); + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_WIN32("create(path)", error, path) { break; } + return nullptr; + } + + return kj::mv(path); + } + + kj::Maybe> createNamedTemporary( + PathPtr finalName, WriteMode mode, Function tryCreate) const { + Path dummy = nullptr; + return createNamedTemporary(finalName, mode, dummy, kj::mv(tryCreate)); + } + + bool tryReplaceNode(PathPtr path, WriteMode mode, + Function tryCreate) const { + // Replaces the given path with an object created by calling tryCreate(). + // + // tryCreate() must behave like a win32 call which creates the node at the path passed to it, + // returning FALSE error. If the path passed to tryCreate already exists, it MUST fail with + // ERROR_{FILE,ALREADY}_EXISTS. + // + // When `mode` includes MODIFY, replaceNode() reacts to ERROR_*_EXISTS by creating the + // node in a temporary location and then rename()ing it into place. + + if (path.size() == 0) { + KJ_FAIL_REQUIRE("can't replace self") { return false; } + } + + auto filename = nativePath(path); + + if (has(mode, WriteMode::CREATE)) { + // First try just cerating the node in-place. + KJ_WIN32_HANDLE_ERRORS(tryCreate(filename.begin())) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + // Target exists. + if (has(mode, WriteMode::MODIFY)) { + // Fall back to MODIFY path, below. + break; + } else { + return false; + } + case ERROR_PATH_NOT_FOUND: + if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && + tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryReplaceNode(path, mode - WriteMode::CREATE_PARENT, kj::mv(tryCreate)); + } + KJ_FALLTHROUGH; + default: + KJ_FAIL_WIN32("create(path)", error, path) { return false; } + } else { + // Success. + return true; + } + } + + // Either we don't have CREATE mode or the target already exists. We need to perform a + // replacement instead. + + KJ_IF_MAYBE(tempPath, createNamedTemporary(path, mode, kj::mv(tryCreate))) { + if (tryCommitReplacement(path, *tempPath, mode)) { + return true; + } else { + KJ_WIN32_HANDLE_ERRORS(DeleteFileW(tempPath->begin())) { + case ERROR_FILE_NOT_FOUND: + // meh + break; + default: + KJ_FAIL_WIN32("DeleteFile(tempPath)", error, dbgStr(*tempPath)); + } + return false; + } + } else { + // threw, but exceptions are disabled + return false; + } + } + + Maybe tryOpenFileInternal(PathPtr path, WriteMode mode, bool append) const { + DWORD disposition; + if (has(mode, WriteMode::MODIFY)) { + if (has(mode, WriteMode::CREATE)) { + disposition = OPEN_ALWAYS; + } else { + disposition = OPEN_EXISTING; + } + } else { + if (has(mode, WriteMode::CREATE)) { + disposition = CREATE_NEW; + } else { + // Neither CREATE nor MODIFY -- impossible to satisfy preconditions. + return nullptr; + } + } + + DWORD access = GENERIC_READ | GENERIC_WRITE; + if (append) { + // FILE_GENERIC_WRITE includes both FILE_APPEND_DATA and FILE_WRITE_DATA, but we only want + // the former. There are also a zillion other bits that we need, annoyingly. + access = (FILE_READ_ATTRIBUTES | FILE_GENERIC_WRITE) & ~FILE_WRITE_DATA; + } + + auto filename = path.toString(); + + HANDLE newHandle; + KJ_WIN32_HANDLE_ERRORS(newHandle = CreateFileW( + nativePath(path).begin(), + access, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + makeSecAttr(mode), + disposition, + FILE_ATTRIBUTE_NORMAL, + NULL)) { + case ERROR_PATH_NOT_FOUND: + if (has(mode, WriteMode::CREATE)) { + // A parent directory didn't exist. Maybe create it. + if (has(mode, WriteMode::CREATE_PARENT) && path.size() > 0 && + tryMkdir(path.parent(), WriteMode::CREATE | WriteMode::MODIFY | + WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryOpenFileInternal(path, mode - WriteMode::CREATE_PARENT, append); + } + + KJ_FAIL_REQUIRE("parent is not a directory", path) { return nullptr; } + } else { + // MODIFY-only mode. ERROR_PATH_NOT_FOUND = parent path doesn't exist = return null. + return nullptr; + } + case ERROR_FILE_NOT_FOUND: + if (!has(mode, WriteMode::CREATE)) { + // MODIFY-only mode. ERROR_FILE_NOT_FOUND = doesn't exist = return null. + return nullptr; + } + goto failed; + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + if (!has(mode, WriteMode::MODIFY)) { + // CREATE-only mode. ERROR_ALREADY_EXISTS = already exists = return null. + return nullptr; + } + goto failed; + default: + failed: + KJ_FAIL_WIN32("CreateFile", error, path) { return nullptr; } + } + + return kj::AutoCloseHandle(newHandle); + } + + bool tryCommitReplacement( + PathPtr toPath, ArrayPtr fromPath, + WriteMode mode, kj::Maybe pathForCreatingParents = nullptr) const { + // Try to use MoveFileEx() to replace `toPath` with `fromPath`. + + auto wToPath = nativePath(toPath); + + DWORD flags = has(mode, WriteMode::MODIFY) ? MOVEFILE_REPLACE_EXISTING : 0; + + if (!has(mode, WriteMode::CREATE)) { + // Non-atomically verify that target exists. There's no way to make this atomic. + DWORD result = GetFileAttributesW(wToPath.begin()); + if (result == INVALID_FILE_ATTRIBUTES) { + auto error = GetLastError(); + switch (error) { + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return false; + default: + KJ_FAIL_WIN32("GetFileAttributesEx(toPath)", error, toPath) { return false; } + } + } + } + + KJ_WIN32_HANDLE_ERRORS(MoveFileExW(fromPath.begin(), wToPath.begin(), flags)) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + // We must not be in MODIFY mode. + return false; + case ERROR_PATH_NOT_FOUND: + KJ_IF_MAYBE(p, pathForCreatingParents) { + if (has(mode, WriteMode::CREATE_PARENT) && + p->size() > 0 && tryMkdir(p->parent(), + WriteMode::CREATE | WriteMode::MODIFY | WriteMode::CREATE_PARENT, true)) { + // Retry, but make sure we don't try to create the parent again. + return tryCommitReplacement(toPath, fromPath, mode - WriteMode::CREATE_PARENT); + } + } + goto default_; + + case ERROR_ACCESS_DENIED: { + // This often means that the target already exists and cannot be replaced, e.g. because + // it is a directory. Move it out of the way first, then move our replacement in, then + // delete the old thing. + + if (has(mode, WriteMode::MODIFY)) { + KJ_IF_MAYBE(tempName, + createNamedTemporary(toPath, WriteMode::CREATE, [&](const wchar_t* tempName2) { + return MoveFileW(wToPath.begin(), tempName2); + })) { + KJ_WIN32_HANDLE_ERRORS(MoveFileW(fromPath.begin(), wToPath.begin())) { + default: + // Try to move back. + MoveFileW(tempName->begin(), wToPath.begin()); + KJ_FAIL_WIN32("MoveFile", error, dbgStr(fromPath), dbgStr(wToPath)) { + return false; + } + } + + // Succeeded, delete temporary. + rmrf(*tempName); + return true; + } else { + // createNamedTemporary() threw exception but exceptions are disabled. + return false; + } + } else { + // Not MODIFY, so no overwrite allowed. If the file really does exist, we need to return + // false. + if (GetFileAttributesW(wToPath.begin()) != INVALID_FILE_ATTRIBUTES) { + return false; + } + } + + goto default_; + } + + default: + default_: + KJ_FAIL_WIN32("MoveFileEx", error, dbgStr(wToPath), dbgStr(fromPath)) { return false; } + } + + return true; + } + + template + class ReplacerImpl final: public Directory::Replacer { + public: + ReplacerImpl(Own&& object, const DiskHandle& parentDirectory, + Array&& tempPath, Path&& path, WriteMode mode) + : Directory::Replacer(mode), + object(kj::mv(object)), parentDirectory(parentDirectory), + tempPath(kj::mv(tempPath)), path(kj::mv(path)) {} + + ~ReplacerImpl() noexcept(false) { + if (!committed) { + object = Own(); // Force close of handle before trying to delete. + + if (kj::isSameType()) { + KJ_WIN32(DeleteFileW(tempPath.begin())) { break; } + } else { + rmrfChildren(tempPath); + KJ_WIN32(RemoveDirectoryW(tempPath.begin())) { break; } + } + } + } + + const T& get() override { + return *object; + } + + bool tryCommit() override { + KJ_ASSERT(!committed, "already committed") { return false; } + + // For directories, we intentionally don't use FILE_SHARE_DELETE on our handle because if the + // directory name changes our paths would be wrong. But, this means we can't rename the + // directory here to commit it. So, we need to close the handle and then re-open it + // afterwards. Ick. + AutoCloseHandle* objectHandle = getHandlePointerHack(*object); + if (kj::isSameType()) { + *objectHandle = nullptr; + } + KJ_DEFER({ + if (kj::isSameType()) { + HANDLE newHandle = nullptr; + KJ_WIN32(newHandle = CreateFileW( + committed ? parentDirectory.nativePath(path).begin() : tempPath.begin(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, + NULL, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, // apparently, this flag is required for directories + NULL)) { return; } + *objectHandle = AutoCloseHandle(newHandle); + *getPathPointerHack(*object) = KJ_ASSERT_NONNULL(parentDirectory.dirPath).append(path); + } + }); + + return committed = parentDirectory.tryCommitReplacement( + path, tempPath, Directory::Replacer::mode); + } + + private: + Own object; + const DiskHandle& parentDirectory; + Array tempPath; + Path path; + bool committed = false; // true if *successfully* committed (in which case tempPath is gone) + }; + + template + class BrokenReplacer final: public Directory::Replacer { + // For recovery path when exceptions are disabled. + + public: + BrokenReplacer(Own inner) + : Directory::Replacer(WriteMode::CREATE | WriteMode::MODIFY), + inner(kj::mv(inner)) {} + + const T& get() override { return *inner; } + bool tryCommit() override { return false; } + + private: + Own inner; + }; + + Maybe> tryOpenFile(PathPtr path, WriteMode mode) const { + return tryOpenFileInternal(path, mode, false).map(newDiskFile); + } + + Own> replaceFile(PathPtr path, WriteMode mode) const { + HANDLE newHandle_; + KJ_IF_MAYBE(temp, createNamedTemporary(path, mode, + [&](const wchar_t* candidatePath) { + newHandle_ = CreateFileW( + candidatePath, + GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + makeSecAttr(mode), + CREATE_NEW, + FILE_ATTRIBUTE_NORMAL, + NULL); + return newHandle_ != INVALID_HANDLE_VALUE; + })) { + AutoCloseHandle newHandle(newHandle_); + return heap>(newDiskFile(kj::mv(newHandle)), *this, kj::mv(*temp), + path.clone(), mode); + } else { + // threw, but exceptions are disabled + return heap>(newInMemoryFile(nullClock())); + } + } + + Own createTemporary() const { + HANDLE newHandle_; + KJ_IF_MAYBE(temp, createNamedTemporary(Path("unnamed"), WriteMode::CREATE, + [&](const wchar_t* candidatePath) { + newHandle_ = CreateFileW( + candidatePath, + GENERIC_READ | GENERIC_WRITE, + 0, + NULL, // TODO(someday): makeSecAttr(WriteMode::PRIVATE), when it's implemented + CREATE_NEW, + FILE_ATTRIBUTE_TEMPORARY | FILE_FLAG_DELETE_ON_CLOSE, + NULL); + return newHandle_ != INVALID_HANDLE_VALUE; + })) { + AutoCloseHandle newHandle(newHandle_); + return newDiskFile(kj::mv(newHandle)); + } else { + // threw, but exceptions are disabled + return newInMemoryFile(nullClock()); + } + } + + Maybe> tryAppendFile(PathPtr path, WriteMode mode) const { + return tryOpenFileInternal(path, mode, true).map(newDiskAppendableFile); + } + + Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const { + // Must create before open. + if (has(mode, WriteMode::CREATE)) { + if (!tryMkdir(path, mode, false)) return nullptr; + } + + return tryOpenSubdirInternal(path).map([&](AutoCloseHandle&& handle) { + return newDiskDirectory(kj::mv(handle), KJ_ASSERT_NONNULL(dirPath).append(path)); + }); + } + + Own> replaceSubdir(PathPtr path, WriteMode mode) const { + Path kjTempPath = nullptr; + KJ_IF_MAYBE(temp, createNamedTemporary(path, mode, kjTempPath, + [&](const wchar_t* candidatePath) { + return CreateDirectoryW(candidatePath, makeSecAttr(mode)); + })) { + HANDLE subdirHandle_; + KJ_WIN32_HANDLE_ERRORS(subdirHandle_ = CreateFileW( + temp->begin(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, + NULL, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, // apparently, this flag is required for directories + NULL)) { + default: + KJ_FAIL_WIN32("CreateFile(just-created-temporary, OPEN_EXISTING)", error, path) { + goto fail; + } + } + + AutoCloseHandle subdirHandle(subdirHandle_); + return heap>( + newDiskDirectory(kj::mv(subdirHandle), + KJ_ASSERT_NONNULL(dirPath).append(kj::mv(kjTempPath))), + *this, kj::mv(*temp), path.clone(), mode); + } else { + // threw, but exceptions are disabled + fail: + return heap>(newInMemoryDirectory(nullClock())); + } + } + + bool trySymlink(PathPtr linkpath, StringPtr content, WriteMode mode) const { + // We can't really create symlinks on Windows. Reasons: + // - We'd need to know whether the target is a file or a directory to pass the correct flags. + // That means we'd need to evaluate the link content and track down the target. What if the + // target doesn't exist? It's unclear if this is even allowed on Windows. + // - Apparently, creating symlinks is a privileged operation on Windows prior to Windows 10. + // The flag SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE is very new. + KJ_UNIMPLEMENTED( + "Creating symbolic links is not supported on Windows due to semantic differences."); + } + + bool tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode, const Directory& self) const { + KJ_REQUIRE(toPath.size() > 0, "can't replace self") { return false; } + + // Try to get the "from" path. + Array rawFromPath; +#if !KJ_NO_RTTI + // Oops, dynamicDowncastIfAvailable() doesn't work since this isn't a downcast, it's a + // side-cast... + if (auto dh = dynamic_cast(&fromDirectory)) { + rawFromPath = dh->nativePath(fromPath); + } else +#endif + KJ_IF_MAYBE(h, fromDirectory.getWin32Handle()) { + // Can't downcast to DiskHandle, but getWin32Handle() returns a handle... maybe RTTI is + // disabled? Or maybe this is some kind of wrapper? + rawFromPath = getPathFromHandle(*h).append(fromPath).forWin32Api(true); + } else { + // Not a disk directory, so fall back to default implementation. + return self.Directory::tryTransfer(toPath, toMode, fromDirectory, fromPath, mode); + } + + if (mode == TransferMode::LINK) { + return tryReplaceNode(toPath, toMode, [&](const wchar_t* candidatePath) { + return CreateHardLinkW(candidatePath, rawFromPath.begin(), NULL); + }); + } else if (mode == TransferMode::MOVE) { + return tryCommitReplacement(toPath, rawFromPath, toMode, toPath); + } else if (mode == TransferMode::COPY) { + // We can accellerate copies on Windows. + + if (!has(toMode, WriteMode::CREATE)) { + // Non-atomically verify that target exists. There's no way to make this atomic. + if (!exists(toPath)) return false; + } + + bool failIfExists = !has(toMode, WriteMode::MODIFY); + KJ_WIN32_HANDLE_ERRORS( + CopyFileW(rawFromPath.begin(), nativePath(toPath).begin(), failIfExists)) { + case ERROR_ALREADY_EXISTS: + case ERROR_FILE_EXISTS: + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return false; + case ERROR_ACCESS_DENIED: + // This usually means that fromPath was a directory or toPath was a directory. Fall back + // to default implementation. + break; + default: + KJ_FAIL_WIN32("CopyFile", error, fromPath, toPath) { return false; } + } else { + // Copy succeeded. + return true; + } + } + + // OK, we can't do anything efficient using the OS. Fall back to default implementation. + return self.Directory::tryTransfer(toPath, toMode, fromDirectory, fromPath, mode); + } + + bool tryRemove(PathPtr path) const { + return rmrf(nativePath(path)); + } +}; + +#define FSNODE_METHODS \ + Maybe getWin32Handle() const override { return DiskHandle::getWin32Handle(); } \ + \ + Metadata stat() const override { return DiskHandle::stat(); } \ + void sync() const override { DiskHandle::sync(); } \ + void datasync() const override { DiskHandle::datasync(); } + +class DiskReadableFile final: public ReadableFile, public DiskHandle { +public: + DiskReadableFile(AutoCloseHandle&& handle): DiskHandle(kj::mv(handle), nullptr) {} + + Own cloneFsNode() const override { + return heap(DiskHandle::clone()); + } + + FSNODE_METHODS + + size_t read(uint64_t offset, ArrayPtr buffer) const override { + return DiskHandle::read(offset, buffer); + } + Array mmap(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmap(offset, size); + } + Array mmapPrivate(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmapPrivate(offset, size); + } +}; + +class DiskAppendableFile final: public AppendableFile, public DiskHandle { +public: + DiskAppendableFile(AutoCloseHandle&& handle) + : DiskHandle(kj::mv(handle), nullptr), + stream(DiskHandle::handle.get()) {} + + Own cloneFsNode() const override { + return heap(DiskHandle::clone()); + } + + FSNODE_METHODS + + void write(const void* buffer, size_t size) override { stream.write(buffer, size); } + void write(ArrayPtr> pieces) override { + implicitCast(stream).write(pieces); + } + +private: + HandleOutputStream stream; +}; + +class DiskFile final: public File, public DiskHandle { +public: + DiskFile(AutoCloseHandle&& handle): DiskHandle(kj::mv(handle), nullptr) {} + + Own cloneFsNode() const override { + return heap(DiskHandle::clone()); + } + + FSNODE_METHODS + + size_t read(uint64_t offset, ArrayPtr buffer) const override { + return DiskHandle::read(offset, buffer); + } + Array mmap(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmap(offset, size); + } + Array mmapPrivate(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmapPrivate(offset, size); + } + + void write(uint64_t offset, ArrayPtr data) const override { + DiskHandle::write(offset, data); + } + void zero(uint64_t offset, uint64_t size) const override { + DiskHandle::zero(offset, size); + } + void truncate(uint64_t size) const override { + DiskHandle::truncate(size); + } + Own mmapWritable(uint64_t offset, uint64_t size) const override { + return DiskHandle::mmapWritable(offset, size); + } + // copy() is not optimized on Windows. +}; + +class DiskReadableDirectory final: public ReadableDirectory, public DiskHandle { +public: + DiskReadableDirectory(AutoCloseHandle&& handle, Path&& path) + : DiskHandle(kj::mv(handle), kj::mv(path)) {} + + Own cloneFsNode() const override { + return heap(DiskHandle::clone(), KJ_ASSERT_NONNULL(dirPath).clone()); + } + + FSNODE_METHODS + + Array listNames() const override { return DiskHandle::listNames(); } + Array listEntries() const override { return DiskHandle::listEntries(); } + bool exists(PathPtr path) const override { return DiskHandle::exists(path); } + Maybe tryLstat(PathPtr path) const override { + return DiskHandle::tryLstat(path); + } + Maybe> tryOpenFile(PathPtr path) const override { + return DiskHandle::tryOpenFile(path); + } + Maybe> tryOpenSubdir(PathPtr path) const override { + return DiskHandle::tryOpenSubdir(path); + } + Maybe tryReadlink(PathPtr path) const override { return DiskHandle::tryReadlink(path); } +}; + +class DiskDirectoryBase: public Directory, public DiskHandle { +public: + DiskDirectoryBase(AutoCloseHandle&& handle, Path&& path) + : DiskHandle(kj::mv(handle), kj::mv(path)) {} + + bool exists(PathPtr path) const override { return DiskHandle::exists(path); } + Maybe tryLstat(PathPtr path) const override { return DiskHandle::tryLstat(path); } + Maybe> tryOpenFile(PathPtr path) const override { + return DiskHandle::tryOpenFile(path); + } + Maybe> tryOpenSubdir(PathPtr path) const override { + return DiskHandle::tryOpenSubdir(path); + } + Maybe tryReadlink(PathPtr path) const override { return DiskHandle::tryReadlink(path); } + + Maybe> tryOpenFile(PathPtr path, WriteMode mode) const override { + return DiskHandle::tryOpenFile(path, mode); + } + Own> replaceFile(PathPtr path, WriteMode mode) const override { + return DiskHandle::replaceFile(path, mode); + } + Maybe> tryAppendFile(PathPtr path, WriteMode mode) const override { + return DiskHandle::tryAppendFile(path, mode); + } + Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const override { + return DiskHandle::tryOpenSubdir(path, mode); + } + Own> replaceSubdir(PathPtr path, WriteMode mode) const override { + return DiskHandle::replaceSubdir(path, mode); + } + bool trySymlink(PathPtr linkpath, StringPtr content, WriteMode mode) const override { + return DiskHandle::trySymlink(linkpath, content, mode); + } + bool tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const override { + return DiskHandle::tryTransfer(toPath, toMode, fromDirectory, fromPath, mode, *this); + } + // tryTransferTo() not implemented because we have nothing special we can do. + bool tryRemove(PathPtr path) const override { + return DiskHandle::tryRemove(path); + } +}; + +class DiskDirectory final: public DiskDirectoryBase { +public: + DiskDirectory(AutoCloseHandle&& handle, Path&& path) + : DiskDirectoryBase(kj::mv(handle), kj::mv(path)) {} + + Own cloneFsNode() const override { + return heap(DiskHandle::clone(), KJ_ASSERT_NONNULL(dirPath).clone()); + } + + FSNODE_METHODS + + Array listNames() const override { return DiskHandle::listNames(); } + Array listEntries() const override { return DiskHandle::listEntries(); } + Own createTemporary() const override { + return DiskHandle::createTemporary(); + } +}; + +class RootDiskDirectory final: public DiskDirectoryBase { + // On Windows, the root directory is special. + // + // HACK: We only override a few functions of DiskDirectory, and we rely on the fact that + // Path::forWin32Api(true) throws an exception complaining about missing drive letter if the + // path is totally empty. + +public: + RootDiskDirectory(): DiskDirectoryBase(nullptr, Path(nullptr)) {} + + Own cloneFsNode() const override { + return heap(); + } + + Metadata stat() const override { + return { Type::DIRECTORY, 0, 0, UNIX_EPOCH, 1, 0 }; + } + void sync() const override {} + void datasync() const override {} + + Array listNames() const override { + return KJ_MAP(e, listEntries()) { return kj::mv(e.name); }; + } + Array listEntries() const override { + DWORD drives = GetLogicalDrives(); + if (drives == 0) { + KJ_FAIL_WIN32("GetLogicalDrives()", GetLastError()) { return nullptr; } + } + + Vector results; + for (uint i = 0; i < 26; i++) { + if (drives & (1 << i)) { + char name[2] = { static_cast('A' + i), ':' }; + results.add(Entry { FsNode::Type::DIRECTORY, kj::heapString(name, 2) }); + } + } + + return results.releaseAsArray(); + } + + Own createTemporary() const override { + KJ_FAIL_REQUIRE("can't create temporaries in Windows pseudo-root directory (the drive list)"); + } +}; + +class DiskFilesystem final: public Filesystem { +public: + DiskFilesystem() + : DiskFilesystem(computeCurrentPath()) {} + DiskFilesystem(Path currentPath) + : current(KJ_ASSERT_NONNULL(root.tryOpenSubdirInternal(currentPath), + "path returned by GetCurrentDirectory() doesn't exist?"), + kj::mv(currentPath)) {} + + const Directory& getRoot() const override { + return root; + } + + const Directory& getCurrent() const override { + return current; + } + + PathPtr getCurrentPath() const override { + return KJ_ASSERT_NONNULL(current.dirPath); + } + +private: + RootDiskDirectory root; + DiskDirectory current; + + static Path computeCurrentPath() { + DWORD tryLen = MAX_PATH; + for (;;) { + auto temp = kj::heapArray(tryLen + 1); + DWORD len = GetCurrentDirectoryW(temp.size(), temp.begin()); + if (len == 0) { + KJ_FAIL_WIN32("GetCurrentDirectory", GetLastError()) { break; } + return Path("."); + } + if (len < temp.size()) { + return Path::parseWin32Api(temp.slice(0, len)); + } + // Try again with new length. + tryLen = len; + } + } +}; + +} // namespace + +Own newDiskReadableFile(AutoCloseHandle fd) { + return heap(kj::mv(fd)); +} +Own newDiskAppendableFile(AutoCloseHandle fd) { + return heap(kj::mv(fd)); +} +Own newDiskFile(AutoCloseHandle fd) { + return heap(kj::mv(fd)); +} +Own newDiskReadableDirectory(AutoCloseHandle fd) { + return heap(kj::mv(fd), getPathFromHandle(fd)); +} +static Own newDiskReadableDirectory(AutoCloseHandle fd, Path&& path) { + return heap(kj::mv(fd), kj::mv(path)); +} +Own newDiskDirectory(AutoCloseHandle fd) { + return heap(kj::mv(fd), getPathFromHandle(fd)); +} +static Own newDiskDirectory(AutoCloseHandle fd, Path&& path) { + return heap(kj::mv(fd), kj::mv(path)); +} + +Own newDiskFilesystem() { + return heap(); +} + +static AutoCloseHandle* getHandlePointerHack(Directory& dir) { + return &static_cast(dir).handle; +} +static Path* getPathPointerHack(Directory& dir) { + return &KJ_ASSERT_NONNULL(static_cast(dir).dirPath); +} + +} // namespace kj + +#endif // _WIN32 diff --git a/c++/src/kj/filesystem-test.c++ b/c++/src/kj/filesystem-test.c++ new file mode 100644 index 0000000000..f3eae2fe79 --- /dev/null +++ b/c++/src/kj/filesystem-test.c++ @@ -0,0 +1,761 @@ +// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "filesystem.h" +#include "test.h" +#include + +namespace kj { +namespace { + +KJ_TEST("Path") { + KJ_EXPECT(Path(nullptr).toString() == "."); + KJ_EXPECT(Path(nullptr).toString(true) == "/"); + KJ_EXPECT(Path("foo").toString() == "foo"); + KJ_EXPECT(Path("foo").toString(true) == "/foo"); + + KJ_EXPECT(Path({"foo", "bar"}).toString() == "foo/bar"); + KJ_EXPECT(Path({"foo", "bar"}).toString(true) == "/foo/bar"); + + KJ_EXPECT(Path::parse("foo/bar").toString() == "foo/bar"); + KJ_EXPECT(Path::parse("foo//bar").toString() == "foo/bar"); + KJ_EXPECT(Path::parse("foo/./bar").toString() == "foo/bar"); + KJ_EXPECT(Path::parse("foo/../bar").toString() == "bar"); + KJ_EXPECT(Path::parse("foo/bar/..").toString() == "foo"); + KJ_EXPECT(Path::parse("foo/bar/../..").toString() == "."); + + KJ_EXPECT(Path({"foo", "bar"}).eval("baz").toString() == "foo/bar/baz"); + KJ_EXPECT(Path({"foo", "bar"}).eval("./baz").toString() == "foo/bar/baz"); + KJ_EXPECT(Path({"foo", "bar"}).eval("baz/qux").toString() == "foo/bar/baz/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("baz//qux").toString() == "foo/bar/baz/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("baz/./qux").toString() == "foo/bar/baz/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("baz/../qux").toString() == "foo/bar/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("baz/qux/..").toString() == "foo/bar/baz"); + KJ_EXPECT(Path({"foo", "bar"}).eval("../baz").toString() == "foo/baz"); + KJ_EXPECT(Path({"foo", "bar"}).eval("baz/../../qux/").toString() == "foo/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("/baz/qux").toString() == "baz/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("//baz/qux").toString() == "baz/qux"); + KJ_EXPECT(Path({"foo", "bar"}).eval("/baz/../qux").toString() == "qux"); + + KJ_EXPECT(Path({"foo", "bar"}).basename()[0] == "bar"); + KJ_EXPECT(Path({"foo", "bar", "baz"}).parent().toString() == "foo/bar"); + + KJ_EXPECT(Path({"foo", "bar"}).append("baz").toString() == "foo/bar/baz"); + KJ_EXPECT(Path({"foo", "bar"}).append(Path({"baz", "qux"})).toString() == "foo/bar/baz/qux"); + + { + // Test methods which are overloaded for && on a non-rvalue path. + Path path({"foo", "bar"}); + KJ_EXPECT(path.eval("baz").toString() == "foo/bar/baz"); + KJ_EXPECT(path.eval("./baz").toString() == "foo/bar/baz"); + KJ_EXPECT(path.eval("baz/qux").toString() == "foo/bar/baz/qux"); + KJ_EXPECT(path.eval("baz//qux").toString() == "foo/bar/baz/qux"); + KJ_EXPECT(path.eval("baz/./qux").toString() == "foo/bar/baz/qux"); + KJ_EXPECT(path.eval("baz/../qux").toString() == "foo/bar/qux"); + KJ_EXPECT(path.eval("baz/qux/..").toString() == "foo/bar/baz"); + KJ_EXPECT(path.eval("../baz").toString() == "foo/baz"); + KJ_EXPECT(path.eval("baz/../../qux/").toString() == "foo/qux"); + KJ_EXPECT(path.eval("/baz/qux").toString() == "baz/qux"); + KJ_EXPECT(path.eval("/baz/../qux").toString() == "qux"); + + KJ_EXPECT(path.basename()[0] == "bar"); + KJ_EXPECT(path.parent().toString() == "foo"); + + KJ_EXPECT(path.append("baz").toString() == "foo/bar/baz"); + KJ_EXPECT(path.append(Path({"baz", "qux"})).toString() == "foo/bar/baz/qux"); + } + + KJ_EXPECT(kj::str(Path({"foo", "bar"})) == "foo/bar"); +} + +KJ_TEST("Path comparisons") { + KJ_EXPECT(Path({"foo", "bar"}) == Path({"foo", "bar"})); + KJ_EXPECT(!(Path({"foo", "bar"}) != Path({"foo", "bar"}))); + KJ_EXPECT(Path({"foo", "bar"}) != Path({"foo", "baz"})); + KJ_EXPECT(!(Path({"foo", "bar"}) == Path({"foo", "baz"}))); + + KJ_EXPECT(Path({"foo", "bar"}) != Path({"fob", "bar"})); + KJ_EXPECT(Path({"foo", "bar"}) != Path({"foo", "bar", "baz"})); + KJ_EXPECT(Path({"foo", "bar", "baz"}) != Path({"foo", "bar"})); + + KJ_EXPECT(Path({"foo", "bar"}) <= Path({"foo", "bar"})); + KJ_EXPECT(Path({"foo", "bar"}) >= Path({"foo", "bar"})); + KJ_EXPECT(!(Path({"foo", "bar"}) < Path({"foo", "bar"}))); + KJ_EXPECT(!(Path({"foo", "bar"}) > Path({"foo", "bar"}))); + + KJ_EXPECT(Path({"foo", "bar"}) < Path({"foo", "bar", "baz"})); + KJ_EXPECT(!(Path({"foo", "bar"}) > Path({"foo", "bar", "baz"}))); + KJ_EXPECT(Path({"foo", "bar", "baz"}) > Path({"foo", "bar"})); + KJ_EXPECT(!(Path({"foo", "bar", "baz"}) < Path({"foo", "bar"}))); + + KJ_EXPECT(Path({"foo", "bar"}) < Path({"foo", "baz"})); + KJ_EXPECT(Path({"foo", "bar"}) > Path({"foo", "baa"})); + KJ_EXPECT(Path({"foo", "bar"}) > Path({"foo"})); + + KJ_EXPECT(Path({"foo", "bar"}).startsWith(Path({}))); + KJ_EXPECT(Path({"foo", "bar"}).startsWith(Path({"foo"}))); + KJ_EXPECT(Path({"foo", "bar"}).startsWith(Path({"foo", "bar"}))); + KJ_EXPECT(!Path({"foo", "bar"}).startsWith(Path({"foo", "bar", "baz"}))); + KJ_EXPECT(!Path({"foo", "bar"}).startsWith(Path({"foo", "baz"}))); + KJ_EXPECT(!Path({"foo", "bar"}).startsWith(Path({"baz", "foo", "bar"}))); + KJ_EXPECT(!Path({"foo", "bar"}).startsWith(Path({"baz"}))); + + KJ_EXPECT(Path({"foo", "bar"}).endsWith(Path({}))); + KJ_EXPECT(Path({"foo", "bar"}).endsWith(Path({"bar"}))); + KJ_EXPECT(Path({"foo", "bar"}).endsWith(Path({"foo", "bar"}))); + KJ_EXPECT(!Path({"foo", "bar"}).endsWith(Path({"baz", "foo", "bar"}))); + KJ_EXPECT(!Path({"foo", "bar"}).endsWith(Path({"fob", "bar"}))); + KJ_EXPECT(!Path({"foo", "bar"}).endsWith(Path({"foo", "bar", "baz"}))); + KJ_EXPECT(!Path({"foo", "bar"}).endsWith(Path({"baz"}))); +} + +KJ_TEST("Path exceptions") { + KJ_EXPECT_THROW_MESSAGE("invalid path component", Path("")); + KJ_EXPECT_THROW_MESSAGE("invalid path component", Path(".")); + KJ_EXPECT_THROW_MESSAGE("invalid path component", Path("..")); + KJ_EXPECT_THROW_MESSAGE("NUL character", Path(StringPtr("foo\0bar", 7))); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", Path::parse("..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", Path::parse("../foo")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", Path::parse("foo/../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("expected a relative path", Path::parse("/foo")); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("NUL character", Path::parse(kj::StringPtr("foo\0bar", 7))); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).eval("../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).eval("../baz/../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).eval("baz/../../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).eval("/..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).eval("/baz/../..")); + + KJ_EXPECT_THROW_MESSAGE("root path has no basename", Path(nullptr).basename()); + KJ_EXPECT_THROW_MESSAGE("root path has no parent", Path(nullptr).parent()); +} + +constexpr kj::ArrayPtr operator "" _a(const wchar_t* str, size_t n) { + return { str, n }; +} + +KJ_TEST("Win32 Path") { + KJ_EXPECT(Path({"foo", "bar"}).toWin32String() == "foo\\bar"); + KJ_EXPECT(Path({"foo", "bar"}).toWin32String(true) == "\\\\foo\\bar"); + KJ_EXPECT(Path({"c:", "foo", "bar"}).toWin32String(true) == "c:\\foo\\bar"); + KJ_EXPECT(Path({"A:", "foo", "bar"}).toWin32String(true) == "A:\\foo\\bar"); + + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz").toWin32String() == "foo\\bar\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("./baz").toWin32String() == "foo\\bar\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz/qux").toWin32String() == "foo\\bar\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz//qux").toWin32String() == "foo\\bar\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz/./qux").toWin32String() == "foo\\bar\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz/../qux").toWin32String() == "foo\\bar\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz/qux/..").toWin32String() == "foo\\bar\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("../baz").toWin32String() == "foo\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz/../../qux/").toWin32String() == "foo\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32(".\\baz").toWin32String() == "foo\\bar\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\qux").toWin32String() == "foo\\bar\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\\\qux").toWin32String() == "foo\\bar\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\.\\qux").toWin32String() == "foo\\bar\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\..\\qux").toWin32String() == "foo\\bar\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\qux\\..").toWin32String() == "foo\\bar\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("..\\baz").toWin32String() == "foo\\baz"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\..\\..\\qux\\").toWin32String() == "foo\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("baz\\../..\\qux/").toWin32String() == "foo\\qux"); + + KJ_EXPECT(Path({"c:", "foo", "bar"}).evalWin32("/baz/qux") + .toWin32String(true) == "c:\\baz\\qux"); + KJ_EXPECT(Path({"c:", "foo", "bar"}).evalWin32("\\baz\\qux") + .toWin32String(true) == "c:\\baz\\qux"); + KJ_EXPECT(Path({"c:", "foo", "bar"}).evalWin32("d:\\baz\\qux") + .toWin32String(true) == "d:\\baz\\qux"); + KJ_EXPECT(Path({"c:", "foo", "bar"}).evalWin32("d:\\baz\\..\\qux") + .toWin32String(true) == "d:\\qux"); + KJ_EXPECT(Path({"c:", "foo", "bar"}).evalWin32("\\\\baz\\qux") + .toWin32String(true) == "\\\\baz\\qux"); + KJ_EXPECT(Path({"foo", "bar"}).evalWin32("d:\\baz\\..\\qux") + .toWin32String(true) == "d:\\qux"); + KJ_EXPECT(Path({"foo", "bar", "baz"}).evalWin32("\\qux") + .toWin32String(true) == "\\\\foo\\bar\\qux"); + + KJ_EXPECT(Path({"foo", "bar"}).forWin32Api(false) == L"foo\\bar"); + KJ_EXPECT(Path({"foo", "bar"}).forWin32Api(true) == L"\\\\?\\UNC\\foo\\bar"); + KJ_EXPECT(Path({"c:", "foo", "bar"}).forWin32Api(true) == L"\\\\?\\c:\\foo\\bar"); + KJ_EXPECT(Path({"A:", "foo", "bar"}).forWin32Api(true) == L"\\\\?\\A:\\foo\\bar"); + + KJ_EXPECT(Path::parseWin32Api(L"\\\\?\\c:\\foo\\bar"_a).toString() == "c:/foo/bar"); + KJ_EXPECT(Path::parseWin32Api(L"\\\\?\\UNC\\foo\\bar"_a).toString() == "foo/bar"); + KJ_EXPECT(Path::parseWin32Api(L"c:\\foo\\bar"_a).toString() == "c:/foo/bar"); + KJ_EXPECT(Path::parseWin32Api(L"\\\\foo\\bar"_a).toString() == "foo/bar"); +} + +KJ_TEST("Win32 Path exceptions") { + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("colons are prohibited", + Path({"c:", "foo", "bar"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("colons are prohibited", + Path({"c:", "foo:bar"}).toWin32String(true)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"con"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"CON", "bar"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"foo", "cOn"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"prn"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"aux"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"NUL"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"nul.txt"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"com3"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"lpt9"}).toWin32String()); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("DOS reserved name", Path({"com1.hello"}).toWin32String()); + + KJ_EXPECT_THROW_MESSAGE("drive letter or netbios", Path({"?", "foo"}).toWin32String(true)); + + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).evalWin32("../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).evalWin32("../baz/../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).evalWin32("baz/../../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"foo", "bar"}).evalWin32("c:\\..\\..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("break out of starting", + Path({"c:", "foo", "bar"}).evalWin32("/baz/../../..")); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("must specify drive letter", + Path({"foo"}).evalWin32("\\baz\\qux")); +} + +KJ_TEST("WriteMode operators") { + WriteMode createOrModify = WriteMode::CREATE | WriteMode::MODIFY; + + KJ_EXPECT(has(createOrModify, WriteMode::MODIFY)); + KJ_EXPECT(has(createOrModify, WriteMode::CREATE)); + KJ_EXPECT(!has(createOrModify, WriteMode::CREATE_PARENT)); + KJ_EXPECT(has(createOrModify, createOrModify)); + KJ_EXPECT(!has(createOrModify, createOrModify | WriteMode::CREATE_PARENT)); + KJ_EXPECT(!has(createOrModify, WriteMode::CREATE | WriteMode::CREATE_PARENT)); + KJ_EXPECT(!has(WriteMode::CREATE, createOrModify)); + + KJ_EXPECT(createOrModify != WriteMode::MODIFY); + KJ_EXPECT(createOrModify != WriteMode::CREATE); + + KJ_EXPECT(createOrModify - WriteMode::CREATE == WriteMode::MODIFY); + KJ_EXPECT(WriteMode::CREATE + WriteMode::MODIFY == createOrModify); + + // Adding existing bit / subtracting non-existing bit are no-ops. + KJ_EXPECT(createOrModify + WriteMode::MODIFY == createOrModify); + KJ_EXPECT(createOrModify - WriteMode::CREATE_PARENT == createOrModify); +} + +// ====================================================================================== + +class TestClock final: public Clock { +public: + void tick() { + time += 1 * SECONDS; + } + + Date now() const override { return time; } + + void expectChanged(const FsNode& file) { + KJ_EXPECT(file.stat().lastModified == time); + time += 1 * SECONDS; + } + void expectUnchanged(const FsNode& file) { + KJ_EXPECT(file.stat().lastModified != time); + } + +private: + Date time = UNIX_EPOCH + 1 * SECONDS; +}; + +KJ_TEST("InMemoryFile") { + TestClock clock; + + auto file = newInMemoryFile(clock); + clock.expectChanged(*file); + + KJ_EXPECT(file->readAllText() == ""); + clock.expectUnchanged(*file); + + file->writeAll("foo"); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == "foo"); + + file->write(3, StringPtr("bar").asBytes()); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == "foobar"); + + file->write(3, StringPtr("baz").asBytes()); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == "foobaz"); + + file->write(9, StringPtr("qux").asBytes()); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == kj::StringPtr("foobaz\0\0\0qux", 12)); + + file->truncate(6); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == "foobaz"); + + file->truncate(18); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == kj::StringPtr("foobaz\0\0\0\0\0\0\0\0\0\0\0\0", 18)); + + { + auto mapping = file->mmap(0, 18); + auto privateMapping = file->mmapPrivate(0, 18); + auto writableMapping = file->mmapWritable(0, 18); + clock.expectUnchanged(*file); + + KJ_EXPECT(mapping.size() == 18); + KJ_EXPECT(privateMapping.size() == 18); + KJ_EXPECT(writableMapping->get().size() == 18); + clock.expectUnchanged(*file); + + KJ_EXPECT(writableMapping->get().begin() == mapping.begin()); + KJ_EXPECT(privateMapping.begin() != mapping.begin()); + + KJ_EXPECT(kj::str(mapping.slice(0, 6).asChars()) == "foobaz"); + KJ_EXPECT(kj::str(privateMapping.slice(0, 6).asChars()) == "foobaz"); + clock.expectUnchanged(*file); + + file->write(0, StringPtr("qux").asBytes()); + clock.expectChanged(*file); + KJ_EXPECT(kj::str(mapping.slice(0, 6).asChars()) == "quxbaz"); + KJ_EXPECT(kj::str(privateMapping.slice(0, 6).asChars()) == "foobaz"); + + file->write(12, StringPtr("corge").asBytes()); + KJ_EXPECT(kj::str(mapping.slice(12, 17).asChars()) == "corge"); + + // Can shrink. + file->truncate(6); + KJ_EXPECT(kj::str(mapping.slice(12, 17).asChars()) == kj::StringPtr("\0\0\0\0\0", 5)); + + // Can regrow. + file->truncate(18); + KJ_EXPECT(kj::str(mapping.slice(12, 17).asChars()) == kj::StringPtr("\0\0\0\0\0", 5)); + + // Can't grow past previoous capacity. + KJ_EXPECT_THROW_MESSAGE("cannot resize the file backing store", file->truncate(100)); + + clock.expectChanged(*file); + writableMapping->changed(writableMapping->get().slice(0, 3)); + clock.expectChanged(*file); + writableMapping->sync(writableMapping->get().slice(0, 3)); + clock.expectChanged(*file); + } + + // But now we can since the mapping is gone. + file->truncate(100); + + file->truncate(6); + clock.expectChanged(*file); + + KJ_EXPECT(file->readAllText() == "quxbaz"); + file->zero(3, 3); + clock.expectChanged(*file); + KJ_EXPECT(file->readAllText() == StringPtr("qux\0\0\0", 6)); +} + +KJ_TEST("InMemoryFile::copy()") { + TestClock clock; + + auto source = newInMemoryFile(clock); + source->writeAll("foobarbaz"); + + auto dest = newInMemoryFile(clock); + dest->writeAll("quxcorge"); + clock.expectChanged(*dest); + + KJ_EXPECT(dest->copy(3, *source, 6, kj::maxValue) == 3); + clock.expectChanged(*dest); + KJ_EXPECT(dest->readAllText() == "quxbazge"); + + KJ_EXPECT(dest->copy(0, *source, 3, 4) == 4); + clock.expectChanged(*dest); + KJ_EXPECT(dest->readAllText() == "barbazge"); + + KJ_EXPECT(dest->copy(0, *source, 128, kj::maxValue) == 0); + clock.expectUnchanged(*dest); + + KJ_EXPECT(dest->copy(4, *source, 3, 0) == 0); + clock.expectUnchanged(*dest); + + String bigString = strArray(repeat("foobar", 10000), ""); + source->truncate(bigString.size() + 1000); + source->write(123, bigString.asBytes()); + + dest->copy(321, *source, 123, bigString.size()); + KJ_EXPECT(dest->readAllText().slice(321) == bigString); +} + +KJ_TEST("File::copy()") { + TestClock clock; + + auto source = newInMemoryFile(clock); + source->writeAll("foobarbaz"); + + auto dest = newInMemoryFile(clock); + dest->writeAll("quxcorge"); + clock.expectChanged(*dest); + + KJ_EXPECT(dest->File::copy(3, *source, 6, kj::maxValue) == 3); + clock.expectChanged(*dest); + KJ_EXPECT(dest->readAllText() == "quxbazge"); + + KJ_EXPECT(dest->File::copy(0, *source, 3, 4) == 4); + clock.expectChanged(*dest); + KJ_EXPECT(dest->readAllText() == "barbazge"); + + KJ_EXPECT(dest->File::copy(0, *source, 128, kj::maxValue) == 0); + clock.expectUnchanged(*dest); + + KJ_EXPECT(dest->File::copy(4, *source, 3, 0) == 0); + clock.expectUnchanged(*dest); + + String bigString = strArray(repeat("foobar", 10000), ""); + source->truncate(bigString.size() + 1000); + source->write(123, bigString.asBytes()); + + dest->File::copy(321, *source, 123, bigString.size()); + KJ_EXPECT(dest->readAllText().slice(321) == bigString); +} + +KJ_TEST("InMemoryDirectory") { + TestClock clock; + + auto dir = newInMemoryDirectory(clock); + clock.expectChanged(*dir); + + KJ_EXPECT(dir->listNames() == nullptr); + KJ_EXPECT(dir->listEntries() == nullptr); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::MODIFY) == nullptr); + clock.expectUnchanged(*dir); + + { + auto file = dir->openFile(Path("foo"), WriteMode::CREATE); + clock.expectChanged(*dir); + file->writeAll("foobar"); + clock.expectUnchanged(*dir); + } + clock.expectUnchanged(*dir); + + KJ_EXPECT(dir->exists(Path("foo"))); + clock.expectUnchanged(*dir); + + { + auto stats = dir->lstat(Path("foo")); + clock.expectUnchanged(*dir); + KJ_EXPECT(stats.type == FsNode::Type::FILE); + KJ_EXPECT(stats.size == 6); + } + + { + auto list = dir->listNames(); + clock.expectUnchanged(*dir); + KJ_ASSERT(list.size() == 1); + KJ_EXPECT(list[0] == "foo"); + } + + { + auto list = dir->listEntries(); + clock.expectUnchanged(*dir); + KJ_ASSERT(list.size() == 1); + KJ_EXPECT(list[0].name == "foo"); + KJ_EXPECT(list[0].type == FsNode::Type::FILE); + } + + KJ_EXPECT(dir->openFile(Path("foo"))->readAllText() == "foobar"); + clock.expectUnchanged(*dir); + + KJ_EXPECT(dir->tryOpenFile(Path({"foo", "bar"}), WriteMode::MODIFY) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path({"bar", "baz"}), WriteMode::MODIFY) == nullptr); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("parent is not a directory", + dir->tryOpenFile(Path({"bar", "baz"}), WriteMode::CREATE)); + clock.expectUnchanged(*dir); + + { + auto file = dir->openFile(Path({"bar", "baz"}), WriteMode::CREATE | WriteMode::CREATE_PARENT); + clock.expectChanged(*dir); + file->writeAll("bazqux"); + clock.expectUnchanged(*dir); + } + clock.expectUnchanged(*dir); + + KJ_EXPECT(dir->openFile(Path({"bar", "baz"}))->readAllText() == "bazqux"); + clock.expectUnchanged(*dir); + + { + auto stats = dir->lstat(Path("bar")); + clock.expectUnchanged(*dir); + KJ_EXPECT(stats.type == FsNode::Type::DIRECTORY); + } + + { + auto list = dir->listNames(); + clock.expectUnchanged(*dir); + KJ_ASSERT(list.size() == 2); + KJ_EXPECT(list[0] == "bar"); + KJ_EXPECT(list[1] == "foo"); + } + + { + auto list = dir->listEntries(); + clock.expectUnchanged(*dir); + KJ_ASSERT(list.size() == 2); + KJ_EXPECT(list[0].name == "bar"); + KJ_EXPECT(list[0].type == FsNode::Type::DIRECTORY); + KJ_EXPECT(list[1].name == "foo"); + KJ_EXPECT(list[1].type == FsNode::Type::FILE); + } + + { + auto subdir = dir->openSubdir(Path("bar")); + clock.expectUnchanged(*dir); + clock.expectUnchanged(*subdir); + + KJ_EXPECT(subdir->openFile(Path("baz"))->readAllText() == "bazqux"); + clock.expectUnchanged(*subdir); + } + + auto subdir = dir->openSubdir(Path("corge"), WriteMode::CREATE); + clock.expectChanged(*dir); + + subdir->openFile(Path("grault"), WriteMode::CREATE)->writeAll("garply"); + clock.expectUnchanged(*dir); + clock.expectChanged(*subdir); + + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "garply"); + + dir->openFile(Path({"corge", "grault"}), WriteMode::CREATE | WriteMode::MODIFY) + ->write(0, StringPtr("rag").asBytes()); + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "ragply"); + clock.expectUnchanged(*dir); + + { + auto replacer = + dir->replaceFile(Path({"corge", "grault"}), WriteMode::CREATE | WriteMode::MODIFY); + clock.expectUnchanged(*subdir); + replacer->get().writeAll("rag"); + clock.expectUnchanged(*subdir); + // Don't commit. + } + clock.expectUnchanged(*subdir); + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "ragply"); + + { + auto replacer = + dir->replaceFile(Path({"corge", "grault"}), WriteMode::CREATE | WriteMode::MODIFY); + clock.expectUnchanged(*subdir); + replacer->get().writeAll("rag"); + clock.expectUnchanged(*subdir); + replacer->commit(); + clock.expectChanged(*subdir); + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "rag"); + } + + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "rag"); + + { + auto appender = dir->appendFile(Path({"corge", "grault"}), WriteMode::MODIFY); + appender->write("waldo", 5); + appender->write("fred", 4); + } + + KJ_EXPECT(dir->openFile(Path({"corge", "grault"}))->readAllText() == "ragwaldofred"); + + KJ_EXPECT(dir->exists(Path("foo"))); + clock.expectUnchanged(*dir); + dir->remove(Path("foo")); + clock.expectChanged(*dir); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(!dir->tryRemove(Path("foo"))); + clock.expectUnchanged(*dir); + + KJ_EXPECT(dir->exists(Path({"bar", "baz"}))); + clock.expectUnchanged(*dir); + dir->remove(Path({"bar", "baz"})); + clock.expectUnchanged(*dir); + KJ_EXPECT(!dir->exists(Path({"bar", "baz"}))); + KJ_EXPECT(dir->exists(Path("bar"))); + KJ_EXPECT(!dir->tryRemove(Path({"bar", "baz"}))); + clock.expectUnchanged(*dir); + + KJ_EXPECT(dir->exists(Path("corge"))); + KJ_EXPECT(dir->exists(Path({"corge", "grault"}))); + clock.expectUnchanged(*dir); + dir->remove(Path("corge")); + clock.expectChanged(*dir); + KJ_EXPECT(!dir->exists(Path("corge"))); + KJ_EXPECT(!dir->exists(Path({"corge", "grault"}))); + KJ_EXPECT(!dir->tryRemove(Path("corge"))); + clock.expectUnchanged(*dir); +} + +KJ_TEST("InMemoryDirectory symlinks") { + TestClock clock; + + auto dir = newInMemoryDirectory(clock); + clock.expectChanged(*dir); + + dir->symlink(Path("foo"), "bar/qux/../baz", WriteMode::CREATE); + clock.expectChanged(*dir); + + KJ_EXPECT(!dir->trySymlink(Path("foo"), "bar/qux/../baz", WriteMode::CREATE)); + clock.expectUnchanged(*dir); + + { + auto stats = dir->lstat(Path("foo")); + clock.expectUnchanged(*dir); + KJ_EXPECT(stats.type == FsNode::Type::SYMLINK); + } + + KJ_EXPECT(dir->readlink(Path("foo")) == "bar/qux/../baz"); + + // Broken link into non-existing directory cannot be opened in any mode. + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::CREATE) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::MODIFY) == nullptr); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("parent is not a directory", + dir->tryOpenFile(Path("foo"), WriteMode::CREATE | WriteMode::MODIFY)); + KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("parent is not a directory", + dir->tryOpenFile(Path("foo"), + WriteMode::CREATE | WriteMode::MODIFY | WriteMode::CREATE_PARENT)); + + // Create the directory. + auto subdir = dir->openSubdir(Path("bar"), WriteMode::CREATE); + clock.expectChanged(*dir); + + // Link still points to non-existing file so cannot be open in most modes. + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::CREATE) == nullptr); + KJ_EXPECT(dir->tryOpenFile(Path("foo"), WriteMode::MODIFY) == nullptr); + clock.expectUnchanged(*dir); + + // But... CREATE | MODIFY works. + dir->openFile(Path("foo"), WriteMode::CREATE | WriteMode::MODIFY) + ->writeAll("foobar"); + clock.expectUnchanged(*dir); // Change is only to subdir! + + KJ_EXPECT(dir->openFile(Path({"bar", "baz"}))->readAllText() == "foobar"); + KJ_EXPECT(dir->openFile(Path("foo"))->readAllText() == "foobar"); + KJ_EXPECT(dir->openFile(Path("foo"), WriteMode::MODIFY)->readAllText() == "foobar"); + + // operations that modify the symlink + dir->symlink(Path("foo"), "corge", WriteMode::MODIFY); + KJ_EXPECT(dir->openFile(Path({"bar", "baz"}))->readAllText() == "foobar"); + KJ_EXPECT(dir->readlink(Path("foo")) == "corge"); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(dir->lstat(Path("foo")).type == FsNode::Type::SYMLINK); + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); + + dir->remove(Path("foo")); + KJ_EXPECT(!dir->exists(Path("foo"))); + KJ_EXPECT(dir->tryOpenFile(Path("foo")) == nullptr); +} + +KJ_TEST("InMemoryDirectory link") { + TestClock clock; + + auto src = newInMemoryDirectory(clock); + auto dst = newInMemoryDirectory(clock); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + clock.expectChanged(*src); + clock.expectUnchanged(*dst); + + dst->transfer(Path("link"), WriteMode::CREATE, *src, Path("foo"), TransferMode::LINK); + clock.expectUnchanged(*src); + clock.expectChanged(*dst); + + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); + + KJ_EXPECT(dst->exists(Path({"link", "bar"}))); + src->remove(Path({"foo", "bar"})); + KJ_EXPECT(!dst->exists(Path({"link", "bar"}))); +} + +KJ_TEST("InMemoryDirectory copy") { + TestClock clock; + + auto src = newInMemoryDirectory(clock); + auto dst = newInMemoryDirectory(clock); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + clock.expectChanged(*src); + clock.expectUnchanged(*dst); + + dst->transfer(Path("link"), WriteMode::CREATE, *src, Path("foo"), TransferMode::COPY); + clock.expectUnchanged(*src); + clock.expectChanged(*dst); + + KJ_EXPECT(src->openFile(Path({"foo", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(src->openFile(Path({"foo", "baz", "qux"}))->readAllText() == "bazqux"); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); + + KJ_EXPECT(dst->exists(Path({"link", "bar"}))); + src->remove(Path({"foo", "bar"})); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); +} + +KJ_TEST("InMemoryDirectory move") { + TestClock clock; + + auto src = newInMemoryDirectory(clock); + auto dst = newInMemoryDirectory(clock); + + src->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("foobar"); + src->openFile(Path({"foo", "baz", "qux"}), WriteMode::CREATE | WriteMode::CREATE_PARENT) + ->writeAll("bazqux"); + clock.expectChanged(*src); + clock.expectUnchanged(*dst); + + dst->transfer(Path("link"), WriteMode::CREATE, *src, Path("foo"), TransferMode::MOVE); + clock.expectChanged(*src); + + KJ_EXPECT(!src->exists(Path({"foo"}))); + KJ_EXPECT(dst->openFile(Path({"link", "bar"}))->readAllText() == "foobar"); + KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); +} + +KJ_TEST("InMemoryDirectory createTemporary") { + TestClock clock; + + auto dir = newInMemoryDirectory(clock); + auto file = dir->createTemporary(); + file->writeAll("foobar"); + KJ_EXPECT(file->readAllText() == "foobar"); + KJ_EXPECT(dir->listNames() == nullptr); +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/filesystem.c++ b/c++/src/kj/filesystem.c++ new file mode 100644 index 0000000000..1dff22ba21 --- /dev/null +++ b/c++/src/kj/filesystem.c++ @@ -0,0 +1,1743 @@ +// Copyright (c) 2015 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "filesystem.h" +#include "vector.h" +#include "debug.h" +#include "one-of.h" +#include "encoding.h" +#include "refcount.h" +#include "mutex.h" +#include + +namespace kj { + +Path::Path(StringPtr name): Path(heapString(name)) {} +Path::Path(String&& name): parts(heapArray(1)) { + parts[0] = kj::mv(name); + validatePart(parts[0]); +} + +Path::Path(ArrayPtr parts) + : Path(KJ_MAP(p, parts) { return heapString(p); }) {} +Path::Path(Array partsParam) + : Path(kj::mv(partsParam), ALREADY_CHECKED) { + for (auto& p: parts) { + validatePart(p); + } +} + +Path PathPtr::clone() { + return Path(KJ_MAP(p, parts) { return heapString(p); }, Path::ALREADY_CHECKED); +} + +Path Path::parse(StringPtr path) { + KJ_REQUIRE(!path.startsWith("/"), "expected a relative path, got absolute", path) { + // When exceptions are disabled, go on -- the leading '/' will end up ignored. + break; + } + return evalImpl(Vector(countParts(path)), path); +} + +Path Path::parseWin32Api(ArrayPtr text) { + auto utf8 = decodeWideString(text); + return evalWin32Impl(Vector(countPartsWin32(utf8)), utf8, true); +} + +Path PathPtr::append(Path&& suffix) const { + auto newParts = kj::heapArrayBuilder(parts.size() + suffix.parts.size()); + for (auto& p: parts) newParts.add(heapString(p)); + for (auto& p: suffix.parts) newParts.add(kj::mv(p)); + return Path(newParts.finish(), Path::ALREADY_CHECKED); +} +Path Path::append(Path&& suffix) && { + auto newParts = kj::heapArrayBuilder(parts.size() + suffix.parts.size()); + for (auto& p: parts) newParts.add(kj::mv(p)); + for (auto& p: suffix.parts) newParts.add(kj::mv(p)); + return Path(newParts.finish(), ALREADY_CHECKED); +} +Path PathPtr::append(PathPtr suffix) const { + auto newParts = kj::heapArrayBuilder(parts.size() + suffix.parts.size()); + for (auto& p: parts) newParts.add(heapString(p)); + for (auto& p: suffix.parts) newParts.add(heapString(p)); + return Path(newParts.finish(), Path::ALREADY_CHECKED); +} +Path Path::append(PathPtr suffix) && { + auto newParts = kj::heapArrayBuilder(parts.size() + suffix.parts.size()); + for (auto& p: parts) newParts.add(kj::mv(p)); + for (auto& p: suffix.parts) newParts.add(heapString(p)); + return Path(newParts.finish(), ALREADY_CHECKED); +} + +Path PathPtr::eval(StringPtr pathText) const { + if (pathText.startsWith("/")) { + // Optimization: avoid copying parts that will just be dropped. + return Path::evalImpl(Vector(Path::countParts(pathText)), pathText); + } else { + Vector newParts(parts.size() + Path::countParts(pathText)); + for (auto& p: parts) newParts.add(heapString(p)); + return Path::evalImpl(kj::mv(newParts), pathText); + } +} +Path Path::eval(StringPtr pathText) && { + if (pathText.startsWith("/")) { + // Optimization: avoid copying parts that will just be dropped. + return evalImpl(Vector(countParts(pathText)), pathText); + } else { + Vector newParts(parts.size() + countParts(pathText)); + for (auto& p: parts) newParts.add(kj::mv(p)); + return evalImpl(kj::mv(newParts), pathText); + } +} + +PathPtr PathPtr::basename() const { + KJ_REQUIRE(parts.size() > 0, "root path has no basename"); + return PathPtr(parts.slice(parts.size() - 1, parts.size())); +} +Path Path::basename() && { + KJ_REQUIRE(parts.size() > 0, "root path has no basename"); + auto newParts = kj::heapArrayBuilder(1); + newParts.add(kj::mv(parts[parts.size() - 1])); + return Path(newParts.finish(), ALREADY_CHECKED); +} + +PathPtr PathPtr::parent() const { + KJ_REQUIRE(parts.size() > 0, "root path has no parent"); + return PathPtr(parts.slice(0, parts.size() - 1)); +} +Path Path::parent() && { + KJ_REQUIRE(parts.size() > 0, "root path has no parent"); + return Path(KJ_MAP(p, parts.slice(0, parts.size() - 1)) { return kj::mv(p); }, ALREADY_CHECKED); +} + +String PathPtr::toString(bool absolute) const { + if (parts.size() == 0) { + // Special-case empty path. + return absolute ? kj::str("/") : kj::str("."); + } + + size_t size = absolute + (parts.size() - 1); + for (auto& p: parts) size += p.size(); + + String result = kj::heapString(size); + + char* ptr = result.begin(); + bool leadingSlash = absolute; + for (auto& p: parts) { + if (leadingSlash) *ptr++ = '/'; + leadingSlash = true; + memcpy(ptr, p.begin(), p.size()); + ptr += p.size(); + } + KJ_ASSERT(ptr == result.end()); + + return result; +} + +Path Path::slice(size_t start, size_t end) && { + return Path(KJ_MAP(p, parts.slice(start, end)) { return kj::mv(p); }); +} + +bool PathPtr::operator==(PathPtr other) const { + return parts == other.parts; +} +bool PathPtr::operator< (PathPtr other) const { + for (size_t i = 0; i < kj::min(parts.size(), other.parts.size()); i++) { + int comp = strcmp(parts[i].cStr(), other.parts[i].cStr()); + if (comp < 0) return true; + if (comp > 0) return false; + } + + return parts.size() < other.parts.size(); +} + +bool PathPtr::startsWith(PathPtr prefix) const { + return parts.size() >= prefix.parts.size() && + parts.slice(0, prefix.parts.size()) == prefix.parts; +} + +bool PathPtr::endsWith(PathPtr suffix) const { + return parts.size() >= suffix.parts.size() && + parts.slice(parts.size() - suffix.parts.size(), parts.size()) == suffix.parts; +} + +Path PathPtr::evalWin32(StringPtr pathText) const { + Vector newParts(parts.size() + Path::countPartsWin32(pathText)); + for (auto& p: parts) newParts.add(heapString(p)); + return Path::evalWin32Impl(kj::mv(newParts), pathText); +} +Path Path::evalWin32(StringPtr pathText) && { + Vector newParts(parts.size() + countPartsWin32(pathText)); + for (auto& p: parts) newParts.add(kj::mv(p)); + return evalWin32Impl(kj::mv(newParts), pathText); +} + +String PathPtr::toWin32StringImpl(bool absolute, bool forApi) const { + if (parts.size() == 0) { + // Special-case empty path. + KJ_REQUIRE(!absolute, "absolute path is missing disk designator") { + break; + } + return absolute ? kj::str("\\\\") : kj::str("."); + } + + bool isUncPath = false; + if (absolute) { + if (Path::isWin32Drive(parts[0])) { + // It's a win32 drive + } else if (Path::isNetbiosName(parts[0])) { + isUncPath = true; + } else { + KJ_FAIL_REQUIRE("absolute win32 path must start with drive letter or netbios host name", + parts[0]); + } + } else { + // Currently we do nothing differently in the forApi case for relative paths. + forApi = false; + } + + size_t size = forApi + ? (isUncPath ? 8 : 4) + (parts.size() - 1) + : (isUncPath ? 2 : 0) + (parts.size() - 1); + for (auto& p: parts) size += p.size(); + + String result = heapString(size); + + char* ptr = result.begin(); + + if (forApi) { + *ptr++ = '\\'; + *ptr++ = '\\'; + *ptr++ = '?'; + *ptr++ = '\\'; + if (isUncPath) { + *ptr++ = 'U'; + *ptr++ = 'N'; + *ptr++ = 'C'; + *ptr++ = '\\'; + } + } else { + if (isUncPath) { + *ptr++ = '\\'; + *ptr++ = '\\'; + } + } + + bool leadingSlash = false; + for (auto& p: parts) { + if (leadingSlash) *ptr++ = '\\'; + leadingSlash = true; + + KJ_REQUIRE(!Path::isWin32Special(p), "path cannot contain DOS reserved name", p) { + // Recover by blotting out the name with invalid characters which Win32 syscalls will reject. + for (size_t i = 0; i < p.size(); i++) { + *ptr++ = '|'; + } + goto skip; + } + + memcpy(ptr, p.begin(), p.size()); + ptr += p.size(); + skip:; + } + + KJ_ASSERT(ptr == result.end()); + + // Check for colons (other than in drive letter), which on NTFS would be interpreted as an + // "alternate data stream", which can lead to surprising results. If we want to support ADS, we + // should do so using an explicit API. Note that this check also prevents a relative path from + // appearing to start with a drive letter. + for (size_t i: kj::indices(result)) { + if (result[i] == ':') { + if (absolute && i == (forApi ? 5 : 1)) { + // False alarm: this is the drive letter. + } else { + KJ_FAIL_REQUIRE( + "colons are prohibited in win32 paths to avoid triggering alterante data streams", + result) { + // Recover by using a different character which we know Win32 syscalls will reject. + result[i] = '|'; + break; + } + } + } + } + + return result; +} + +Array PathPtr::forWin32Api(bool absolute) const { + return encodeWideString(toWin32StringImpl(absolute, true), true); +} + +// ----------------------------------------------------------------------------- + +String Path::stripNul(String input) { + kj::Vector output(input.size()); + for (char c: input) { + if (c != '\0') output.add(c); + } + output.add('\0'); + return String(output.releaseAsArray()); +} + +void Path::validatePart(StringPtr part) { + KJ_REQUIRE(part != "" && part != "." && part != "..", "invalid path component", part); + KJ_REQUIRE(strlen(part.begin()) == part.size(), "NUL character in path component", part); + KJ_REQUIRE(part.findFirst('/') == nullptr, + "'/' character in path component; did you mean to use Path::parse()?", part); +} + +void Path::evalPart(Vector& parts, ArrayPtr part) { + if (part.size() == 0) { + // Ignore consecutive or trailing '/'s. + } else if (part.size() == 1 && part[0] == '.') { + // Refers to current directory; ignore. + } else if (part.size() == 2 && part[0] == '.' && part [1] == '.') { + KJ_REQUIRE(parts.size() > 0, "can't use \"..\" to break out of starting directory") { + // When exceptions are disabled, ignore. + return; + } + parts.removeLast(); + } else { + auto str = heapString(part); + KJ_REQUIRE(strlen(str.begin()) == str.size(), "NUL character in path component", str) { + // When exceptions are disabled, strip out '\0' chars. + str = stripNul(kj::mv(str)); + break; + } + parts.add(kj::mv(str)); + } +} + +Path Path::evalImpl(Vector&& parts, StringPtr path) { + if (path.startsWith("/")) { + parts.clear(); + } + + size_t partStart = 0; + for (auto i: kj::indices(path)) { + if (path[i] == '/') { + evalPart(parts, path.slice(partStart, i)); + partStart = i + 1; + } + } + evalPart(parts, path.slice(partStart)); + + return Path(parts.releaseAsArray(), Path::ALREADY_CHECKED); +} + +Path Path::evalWin32Impl(Vector&& parts, StringPtr path, bool fromApi) { + // Convert all forward slashes to backslashes. + String ownPath; + if (!fromApi && path.findFirst('/') != nullptr) { + ownPath = heapString(path); + for (char& c: ownPath) { + if (c == '/') c = '\\'; + } + path = ownPath; + } + + // Interpret various forms of absolute paths. + if (fromApi && path.startsWith("\\\\?\\")) { + path = path.slice(4); + if (path.startsWith("UNC\\")) { + path = path.slice(4); + } + + // The path is absolute. + parts.clear(); + } else if (path.startsWith("\\\\")) { + // UNC path. + path = path.slice(2); + + // This path is absolute. The first component is a server name. + parts.clear(); + } else if (path.startsWith("\\")) { + KJ_REQUIRE(!fromApi, "parseWin32Api() requires absolute path"); + + // Path is relative to the current drive / network share. + if (parts.size() >= 1 && isWin32Drive(parts[0])) { + // Leading \ interpreted as root of current drive. + parts.truncate(1); + } else if (parts.size() >= 2) { + // Leading \ interpreted as root of current network share (which is indicated by the first + // *two* components of the path). + parts.truncate(2); + } else { + KJ_FAIL_REQUIRE("must specify drive letter", path) { + // Recover by assuming C drive. + parts.clear(); + parts.add(kj::str("c:")); + break; + } + } + } else if ((path.size() == 2 || (path.size() > 2 && path[2] == '\\')) && + isWin32Drive(path.slice(0, 2))) { + // Starts with a drive letter. + parts.clear(); + } else { + KJ_REQUIRE(!fromApi, "parseWin32Api() requires absolute path"); + } + + size_t partStart = 0; + for (auto i: kj::indices(path)) { + if (path[i] == '\\') { + evalPart(parts, path.slice(partStart, i)); + partStart = i + 1; + } + } + evalPart(parts, path.slice(partStart)); + + return Path(parts.releaseAsArray(), Path::ALREADY_CHECKED); +} + +size_t Path::countParts(StringPtr path) { + size_t result = 1; + for (char c: path) { + result += (c == '/'); + } + return result; +} + +size_t Path::countPartsWin32(StringPtr path) { + size_t result = 1; + for (char c: path) { + result += (c == '/' || c == '\\'); + } + return result; +} + +bool Path::isWin32Drive(ArrayPtr part) { + return part.size() == 2 && part[1] == ':' && + (('a' <= part[0] && part[0] <= 'z') || ('A' <= part[0] && part[0] <= 'Z')); +} + +bool Path::isNetbiosName(ArrayPtr part) { + // Characters must be alphanumeric or '.' or '-'. + for (char c: part) { + if (c != '.' && c != '-' && + (c < 'a' || 'z' < c) && + (c < 'A' || 'Z' < c) && + (c < '0' || '9' < c)) { + return false; + } + } + + // Can't be empty nor start or end with a '.' or a '-'. + return part.size() > 0 && + part[0] != '.' && part[0] != '-' && + part[part.size() - 1] != '.' && part[part.size() - 1] != '-'; +} + +bool Path::isWin32Special(StringPtr part) { + bool isNumbered; + if (part.size() == 3 || (part.size() > 3 && part[3] == '.')) { + // Filename is three characters or three characters followed by an extension. + isNumbered = false; + } else if ((part.size() == 4 || (part.size() > 4 && part[4] == '.')) && + '1' <= part[3] && part[3] <= '9') { + // Filename is four characters or four characters followed by an extension, and the fourth + // character is a nonzero digit. + isNumbered = true; + } else { + return false; + } + + // OK, this could be a Win32 special filename. We need to match the first three letters against + // the list of specials, case-insensitively. + char tmp[4]; + memcpy(tmp, part.begin(), 3); + tmp[3] = '\0'; + for (char& c: tmp) { + if ('A' <= c && c <= 'Z') { + c += 'a' - 'A'; + } + } + + StringPtr str(tmp, 3); + if (isNumbered) { + // Specials that are followed by a digit. + return str == "com" || str == "lpt"; + } else { + // Specials that are not followed by a digit. + return str == "con" || str == "prn" || str == "aux" || str == "nul"; + } +} + +// ======================================================================================= + +String ReadableFile::readAllText() const { + String result = heapString(stat().size); + size_t n = read(0, result.asBytes()); + if (n < result.size()) { + // Apparently file was truncated concurrently. Reduce to new size to match. + result = heapString(result.slice(0, n)); + } + return result; +} + +Array ReadableFile::readAllBytes() const { + Array result = heapArray(stat().size); + size_t n = read(0, result.asBytes()); + if (n < result.size()) { + // Apparently file was truncated concurrently. Reduce to new size to match. + result = heapArray(result.slice(0, n)); + } + return result; +} + +void File::writeAll(ArrayPtr bytes) const { + truncate(0); + write(0, bytes); +} + +void File::writeAll(StringPtr text) const { + writeAll(text.asBytes()); +} + +size_t File::copy(uint64_t offset, const ReadableFile& from, + uint64_t fromOffset, uint64_t size) const { + byte buffer[8192]; + + size_t result = 0; + while (size > 0) { + size_t n = from.read(fromOffset, kj::arrayPtr(buffer, kj::min(sizeof(buffer), size))); + write(offset, arrayPtr(buffer, n)); + result += n; + if (n < sizeof(buffer)) { + // Either we copied the amount requested or we hit EOF. + break; + } + fromOffset += n; + offset += n; + size -= n; + } + + return result; +} + +FsNode::Metadata ReadableDirectory::lstat(PathPtr path) const { + KJ_IF_MAYBE(meta, tryLstat(path)) { + return *meta; + } else { + KJ_FAIL_REQUIRE("no such file or directory", path) { break; } + return FsNode::Metadata(); + } +} + +Own ReadableDirectory::openFile(PathPtr path) const { + KJ_IF_MAYBE(file, tryOpenFile(path)) { + return kj::mv(*file); + } else { + KJ_FAIL_REQUIRE("no such file", path) { break; } + return newInMemoryFile(nullClock()); + } +} + +Own ReadableDirectory::openSubdir(PathPtr path) const { + KJ_IF_MAYBE(dir, tryOpenSubdir(path)) { + return kj::mv(*dir); + } else { + KJ_FAIL_REQUIRE("no such directory", path) { break; } + return newInMemoryDirectory(nullClock()); + } +} + +String ReadableDirectory::readlink(PathPtr path) const { + KJ_IF_MAYBE(p, tryReadlink(path)) { + return kj::mv(*p); + } else { + KJ_FAIL_REQUIRE("not a symlink", path) { break; } + return kj::str("."); + } +} + +Own Directory::openFile(PathPtr path, WriteMode mode) const { + KJ_IF_MAYBE(f, tryOpenFile(path, mode)) { + return kj::mv(*f); + } else if (has(mode, WriteMode::CREATE) && !has(mode, WriteMode::MODIFY)) { + KJ_FAIL_REQUIRE("file already exists", path) { break; } + } else if (has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_REQUIRE("file does not exist", path) { break; } + } else if (!has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_ASSERT("neither WriteMode::CREATE nor WriteMode::MODIFY was given", path) { break; } + } else { + // Shouldn't happen. + KJ_FAIL_ASSERT("tryOpenFile() returned null despite no preconditions", path) { break; } + } + return newInMemoryFile(nullClock()); +} + +Own Directory::appendFile(PathPtr path, WriteMode mode) const { + KJ_IF_MAYBE(f, tryAppendFile(path, mode)) { + return kj::mv(*f); + } else if (has(mode, WriteMode::CREATE) && !has(mode, WriteMode::MODIFY)) { + KJ_FAIL_REQUIRE("file already exists", path) { break; } + } else if (has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_REQUIRE("file does not exist", path) { break; } + } else if (!has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_ASSERT("neither WriteMode::CREATE nor WriteMode::MODIFY was given", path) { break; } + } else { + // Shouldn't happen. + KJ_FAIL_ASSERT("tryAppendFile() returned null despite no preconditions", path) { break; } + } + return newFileAppender(newInMemoryFile(nullClock())); +} + +Own Directory::openSubdir(PathPtr path, WriteMode mode) const { + KJ_IF_MAYBE(f, tryOpenSubdir(path, mode)) { + return kj::mv(*f); + } else if (has(mode, WriteMode::CREATE) && !has(mode, WriteMode::MODIFY)) { + KJ_FAIL_REQUIRE("directory already exists", path) { break; } + } else if (has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_REQUIRE("directory does not exist", path) { break; } + } else if (!has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_ASSERT("neither WriteMode::CREATE nor WriteMode::MODIFY was given", path) { break; } + } else { + // Shouldn't happen. + KJ_FAIL_ASSERT("tryOpenSubdir() returned null despite no preconditions", path) { break; } + } + return newInMemoryDirectory(nullClock()); +} + +void Directory::symlink(PathPtr linkpath, StringPtr content, WriteMode mode) const { + if (!trySymlink(linkpath, content, mode)) { + if (has(mode, WriteMode::CREATE)) { + KJ_FAIL_REQUIRE("path already exists", linkpath) { break; } + } else { + // Shouldn't happen. + KJ_FAIL_ASSERT("symlink() returned null despite no preconditions", linkpath) { break; } + } + } +} + +void Directory::transfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const { + if (!tryTransfer(toPath, toMode, fromDirectory, fromPath, mode)) { + if (has(toMode, WriteMode::CREATE)) { + KJ_FAIL_REQUIRE("toPath already exists or fromPath doesn't exist", toPath, fromPath) { + break; + } + } else { + KJ_FAIL_ASSERT("fromPath doesn't exist", fromPath) { break; } + } + } +} + +static void copyContents(const Directory& to, const ReadableDirectory& from); + +static bool tryCopyDirectoryEntry(const Directory& to, PathPtr toPath, WriteMode toMode, + const ReadableDirectory& from, PathPtr fromPath, + FsNode::Type type, bool atomic) { + // TODO(cleanup): Make this reusable? + + switch (type) { + case FsNode::Type::FILE: { + KJ_IF_MAYBE(fromFile, from.tryOpenFile(fromPath)) { + if (atomic) { + auto replacer = to.replaceFile(toPath, toMode); + replacer->get().copy(0, **fromFile, 0, kj::maxValue); + return replacer->tryCommit(); + } else KJ_IF_MAYBE(toFile, to.tryOpenFile(toPath, toMode)) { + toFile->get()->copy(0, **fromFile, 0, kj::maxValue); + return true; + } else { + return false; + } + } else { + // Apparently disappeared. Treat as source-doesn't-exist. + return false; + } + } + case FsNode::Type::DIRECTORY: + KJ_IF_MAYBE(fromSubdir, from.tryOpenSubdir(fromPath)) { + if (atomic) { + auto replacer = to.replaceSubdir(toPath, toMode); + copyContents(replacer->get(), **fromSubdir); + return replacer->tryCommit(); + } else KJ_IF_MAYBE(toSubdir, to.tryOpenSubdir(toPath, toMode)) { + copyContents(**toSubdir, **fromSubdir); + return true; + } else { + return false; + } + } else { + // Apparently disappeared. Treat as source-doesn't-exist. + return false; + } + case FsNode::Type::SYMLINK: + KJ_IF_MAYBE(content, from.tryReadlink(fromPath)) { + return to.trySymlink(toPath, *content, toMode); + } else { + // Apparently disappeared. Treat as source-doesn't-exist. + return false; + } + break; + + default: + // Note: Unclear whether it's better to throw an error here or just ignore it / log a + // warning. Can reconsider when we see an actual use case. + KJ_FAIL_REQUIRE("can only copy files, directories, and symlinks", fromPath) { + return false; + } + } +} + +static void copyContents(const Directory& to, const ReadableDirectory& from) { + for (auto& entry: from.listEntries()) { + Path subPath(kj::mv(entry.name)); + tryCopyDirectoryEntry(to, subPath, WriteMode::CREATE, from, subPath, entry.type, false); + } +} + +bool Directory::tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const { + KJ_REQUIRE(toPath.size() > 0, "can't replace self") { return false; } + + // First try reversing. + KJ_IF_MAYBE(result, fromDirectory.tryTransferTo(*this, toPath, toMode, fromPath, mode)) { + return *result; + } + + switch (mode) { + case TransferMode::COPY: + KJ_IF_MAYBE(meta, fromDirectory.tryLstat(fromPath)) { + return tryCopyDirectoryEntry(*this, toPath, toMode, fromDirectory, + fromPath, meta->type, true); + } else { + // Source doesn't exist. + return false; + } + case TransferMode::MOVE: + // Implement move as copy-then-delete. + if (!tryTransfer(toPath, toMode, fromDirectory, fromPath, TransferMode::COPY)) { + return false; + } + fromDirectory.remove(fromPath); + return true; + case TransferMode::LINK: + KJ_FAIL_REQUIRE("can't link across different Directory implementations") { return false; } + } + + KJ_UNREACHABLE; +} + +Maybe Directory::tryTransferTo(const Directory& toDirectory, PathPtr toPath, WriteMode toMode, + PathPtr fromPath, TransferMode mode) const { + return nullptr; +} + +void Directory::remove(PathPtr path) const { + if (!tryRemove(path)) { + KJ_FAIL_REQUIRE("path to remove doesn't exist", path) { break; } + } +} + +void Directory::commitFailed(WriteMode mode) { + if (has(mode, WriteMode::CREATE) && !has(mode, WriteMode::MODIFY)) { + KJ_FAIL_REQUIRE("replace target already exists") { break; } + } else if (has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_REQUIRE("replace target does not exist") { break; } + } else if (!has(mode, WriteMode::MODIFY) && !has(mode, WriteMode::CREATE)) { + KJ_FAIL_ASSERT("neither WriteMode::CREATE nor WriteMode::MODIFY was given") { break; } + } else { + KJ_FAIL_ASSERT("tryCommit() returned null despite no preconditions") { break; } + } +} + +// ======================================================================================= + +namespace { + +class InMemoryFile final: public File, public AtomicRefcounted { +public: + InMemoryFile(const Clock& clock): impl(clock) {} + + Own cloneFsNode() const override { + return atomicAddRef(*this); + } + + Maybe getFd() const override { + return nullptr; + } + + Metadata stat() const override { + auto lock = impl.lockShared(); + uint64_t hash = reinterpret_cast(this); + return Metadata { Type::FILE, lock->size, lock->size, lock->lastModified, 1, hash }; + } + + void sync() const override {} + void datasync() const override {} + // no-ops + + size_t read(uint64_t offset, ArrayPtr buffer) const override { + auto lock = impl.lockShared(); + if (offset >= lock->size) { + // Entirely out-of-range. + return 0; + } + + size_t readSize = kj::min(buffer.size(), lock->size - offset); + memcpy(buffer.begin(), lock->bytes.begin() + offset, readSize); + return readSize; + } + + Array mmap(uint64_t offset, uint64_t size) const override { + KJ_REQUIRE(offset + size >= offset, "mmap() request overflows uint64"); + auto lock = impl.lockExclusive(); + lock->ensureCapacity(offset + size); + + ArrayDisposer* disposer = new MmapDisposer(atomicAddRef(*this)); + return Array(lock->bytes.begin() + offset, size, *disposer); + } + + Array mmapPrivate(uint64_t offset, uint64_t size) const override { + // Return a copy. + + // Allocate exactly the size requested. + auto result = heapArray(size); + + // Use read() to fill it. + size_t actual = read(offset, result); + + // Ignore the rest. + if (actual < size) { + memset(result.begin() + actual, 0, size - actual); + } + + return result; + } + + void write(uint64_t offset, ArrayPtr data) const override { + if (data.size() == 0) return; + auto lock = impl.lockExclusive(); + lock->modified(); + uint64_t end = offset + data.size(); + KJ_REQUIRE(end >= offset, "write() request overflows uint64"); + lock->ensureCapacity(end); + lock->size = kj::max(lock->size, end); + memcpy(lock->bytes.begin() + offset, data.begin(), data.size()); + } + + void zero(uint64_t offset, uint64_t zeroSize) const override { + if (zeroSize == 0) return; + auto lock = impl.lockExclusive(); + lock->modified(); + uint64_t end = offset + zeroSize; + KJ_REQUIRE(end >= offset, "zero() request overflows uint64"); + lock->ensureCapacity(end); + lock->size = kj::max(lock->size, end); + memset(lock->bytes.begin() + offset, 0, zeroSize); + } + + void truncate(uint64_t newSize) const override { + auto lock = impl.lockExclusive(); + if (newSize < lock->size) { + lock->modified(); + memset(lock->bytes.begin() + newSize, 0, lock->size - newSize); + lock->size = newSize; + } else if (newSize > lock->size) { + lock->modified(); + lock->ensureCapacity(newSize); + lock->size = newSize; + } + } + + Own mmapWritable(uint64_t offset, uint64_t size) const override { + uint64_t end = offset + size; + KJ_REQUIRE(end >= offset, "mmapWritable() request overflows uint64"); + auto lock = impl.lockExclusive(); + lock->ensureCapacity(end); + return heap(atomicAddRef(*this), lock->bytes.slice(offset, end)); + } + + size_t copy(uint64_t offset, const ReadableFile& from, + uint64_t fromOffset, uint64_t copySize) const override { + size_t fromFileSize = from.stat().size; + if (fromFileSize <= fromOffset) return 0; + + // Clamp size to EOF. + copySize = kj::min(copySize, fromFileSize - fromOffset); + if (copySize == 0) return 0; + + auto lock = impl.lockExclusive(); + + // Allocate space for the copy. + uint64_t end = offset + copySize; + lock->ensureCapacity(end); + + // Read directly into our backing store. + size_t n = from.read(fromOffset, lock->bytes.slice(offset, end)); + lock->size = kj::max(lock->size, offset + n); + + lock->modified(); + return n; + } + +private: + struct Impl { + const Clock& clock; + Array bytes; + size_t size = 0; // bytes may be larger than this to accommodate mmaps + Date lastModified; + uint mmapCount = 0; // number of mappings outstanding + + Impl(const Clock& clock): clock(clock), lastModified(clock.now()) {} + + void ensureCapacity(size_t capacity) { + if (bytes.size() < capacity) { + KJ_ASSERT(mmapCount == 0, + "InMemoryFile cannot resize the file backing store while memory mappings exist."); + + auto newBytes = heapArray(kj::max(capacity, bytes.size() * 2)); + if (size > 0) { // placate ubsan; bytes.begin() might be null + memcpy(newBytes.begin(), bytes.begin(), size); + } + memset(newBytes.begin() + size, 0, newBytes.size() - size); + bytes = kj::mv(newBytes); + } + } + + void modified() { + lastModified = clock.now(); + } + }; + kj::MutexGuarded impl; + + class MmapDisposer final: public ArrayDisposer { + public: + MmapDisposer(Own&& refParam): ref(kj::mv(refParam)) { + ++ref->impl.getAlreadyLockedExclusive().mmapCount; + } + ~MmapDisposer() noexcept(false) { + --ref->impl.lockExclusive()->mmapCount; + } + + void disposeImpl(void* firstElement, size_t elementSize, size_t elementCount, + size_t capacity, void (*destroyElement)(void*)) const override { + delete this; + } + + private: + Own ref; + }; + + class WritableFileMappingImpl final: public WritableFileMapping { + public: + WritableFileMappingImpl(Own&& refParam, ArrayPtr range) + : ref(kj::mv(refParam)), range(range) { + ++ref->impl.getAlreadyLockedExclusive().mmapCount; + } + ~WritableFileMappingImpl() noexcept(false) { + --ref->impl.lockExclusive()->mmapCount; + } + + ArrayPtr get() const override { + // const_cast OK because WritableFileMapping does indeed provide a writable view despite + // being const itself. + return arrayPtr(const_cast(range.begin()), range.size()); + } + + void changed(ArrayPtr slice) const override { + ref->impl.lockExclusive()->modified(); + } + + void sync(ArrayPtr slice) const override { + ref->impl.lockExclusive()->modified(); + } + + private: + Own ref; + ArrayPtr range; + }; +}; + +// ----------------------------------------------------------------------------- + +class InMemoryDirectory final: public Directory, public AtomicRefcounted { +public: + InMemoryDirectory(const Clock& clock): impl(clock) {} + + Own cloneFsNode() const override { + return atomicAddRef(*this); + } + + Maybe getFd() const override { + return nullptr; + } + + Metadata stat() const override { + auto lock = impl.lockShared(); + uint64_t hash = reinterpret_cast(this); + return Metadata { Type::DIRECTORY, 0, 0, lock->lastModified, 1, hash }; + } + + void sync() const override {} + void datasync() const override {} + // no-ops + + Array listNames() const override { + auto lock = impl.lockShared(); + return KJ_MAP(e, lock->entries) { return heapString(e.first); }; + } + + Array listEntries() const override { + auto lock = impl.lockShared(); + return KJ_MAP(e, lock->entries) { + FsNode::Type type; + if (e.second.node.is()) { + type = FsNode::Type::SYMLINK; + } else if (e.second.node.is()) { + type = FsNode::Type::FILE; + } else { + KJ_ASSERT(e.second.node.is()); + type = FsNode::Type::DIRECTORY; + } + + return Entry { type, heapString(e.first) }; + }; + } + + bool exists(PathPtr path) const override { + if (path.size() == 0) { + return true; + } else if (path.size() == 1) { + auto lock = impl.lockShared(); + KJ_IF_MAYBE(entry, lock->tryGetEntry(path[0])) { + return exists(lock, *entry); + } else { + return false; + } + } else { + KJ_IF_MAYBE(subdir, tryGetParent(path[0])) { + return subdir->get()->exists(path.slice(1, path.size())); + } else { + return false; + } + } + } + + Maybe tryLstat(PathPtr path) const override { + if (path.size() == 0) { + return stat(); + } else if (path.size() == 1) { + auto lock = impl.lockShared(); + KJ_IF_MAYBE(entry, lock->tryGetEntry(path[0])) { + if (entry->node.is()) { + return entry->node.get().file->stat(); + } else if (entry->node.is()) { + return entry->node.get().directory->stat(); + } else if (entry->node.is()) { + auto& link = entry->node.get(); + uint64_t hash = reinterpret_cast(link.content.begin()); + return FsNode::Metadata { FsNode::Type::SYMLINK, 0, 0, link.lastModified, 1, hash }; + } else { + KJ_FAIL_ASSERT("unknown node type") { return nullptr; } + } + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(subdir, tryGetParent(path[0])) { + return subdir->get()->tryLstat(path.slice(1, path.size())); + } else { + return nullptr; + } + } + } + + Maybe> tryOpenFile(PathPtr path) const override { + if (path.size() == 0) { + KJ_FAIL_REQUIRE("not a file") { return nullptr; } + } else if (path.size() == 1) { + auto lock = impl.lockShared(); + KJ_IF_MAYBE(entry, lock->tryGetEntry(path[0])) { + return asFile(lock, *entry); + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(subdir, tryGetParent(path[0])) { + return subdir->get()->tryOpenFile(path.slice(1, path.size())); + } else { + return nullptr; + } + } + } + + Maybe> tryOpenSubdir(PathPtr path) const override { + if (path.size() == 0) { + return clone(); + } else if (path.size() == 1) { + auto lock = impl.lockShared(); + KJ_IF_MAYBE(entry, lock->tryGetEntry(path[0])) { + return asDirectory(lock, *entry); + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(subdir, tryGetParent(path[0])) { + return subdir->get()->tryOpenSubdir(path.slice(1, path.size())); + } else { + return nullptr; + } + } + } + + Maybe tryReadlink(PathPtr path) const override { + if (path.size() == 0) { + KJ_FAIL_REQUIRE("not a symlink") { return nullptr; } + } else if (path.size() == 1) { + auto lock = impl.lockShared(); + KJ_IF_MAYBE(entry, lock->tryGetEntry(path[0])) { + return asSymlink(lock, *entry); + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(subdir, tryGetParent(path[0])) { + return subdir->get()->tryReadlink(path.slice(1, path.size())); + } else { + return nullptr; + } + } + } + + Maybe> tryOpenFile(PathPtr path, WriteMode mode) const override { + if (path.size() == 0) { + if (has(mode, WriteMode::MODIFY)) { + KJ_FAIL_REQUIRE("not a file") { return nullptr; } + } else if (has(mode, WriteMode::CREATE)) { + return nullptr; // already exists (as a directory) + } else { + KJ_FAIL_REQUIRE("can't replace self") { return nullptr; } + } + } else if (path.size() == 1) { + auto lock = impl.lockExclusive(); + KJ_IF_MAYBE(entry, lock->openEntry(path[0], mode)) { + return asFile(lock, *entry, mode); + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], mode)) { + return child->get()->tryOpenFile(path.slice(1, path.size()), mode); + } else { + return nullptr; + } + } + } + + Own> replaceFile(PathPtr path, WriteMode mode) const override { + if (path.size() == 0) { + KJ_FAIL_REQUIRE("can't replace self") { break; } + } else if (path.size() == 1) { + // don't need lock just to read the clock ref + return heap>(*this, path[0], + newInMemoryFile(impl.getWithoutLock().clock), mode); + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], mode)) { + return child->get()->replaceFile(path.slice(1, path.size()), mode); + } + } + return heap>(newInMemoryFile(impl.getWithoutLock().clock)); + } + + Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const override { + if (path.size() == 0) { + if (has(mode, WriteMode::MODIFY)) { + return atomicAddRef(*this); + } else if (has(mode, WriteMode::CREATE)) { + return nullptr; // already exists + } else { + KJ_FAIL_REQUIRE("can't replace self") { return nullptr; } + } + } else if (path.size() == 1) { + auto lock = impl.lockExclusive(); + KJ_IF_MAYBE(entry, lock->openEntry(path[0], mode)) { + return asDirectory(lock, *entry, mode); + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], mode)) { + return child->get()->tryOpenSubdir(path.slice(1, path.size()), mode); + } else { + return nullptr; + } + } + } + + Own> replaceSubdir(PathPtr path, WriteMode mode) const override { + if (path.size() == 0) { + KJ_FAIL_REQUIRE("can't replace self") { break; } + } else if (path.size() == 1) { + // don't need lock just to read the clock ref + return heap>(*this, path[0], + newInMemoryDirectory(impl.getWithoutLock().clock), mode); + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], mode)) { + return child->get()->replaceSubdir(path.slice(1, path.size()), mode); + } + } + return heap>(newInMemoryDirectory(impl.getWithoutLock().clock)); + } + + Maybe> tryAppendFile(PathPtr path, WriteMode mode) const override { + if (path.size() == 0) { + if (has(mode, WriteMode::MODIFY)) { + KJ_FAIL_REQUIRE("not a file") { return nullptr; } + } else if (has(mode, WriteMode::CREATE)) { + return nullptr; // already exists (as a directory) + } else { + KJ_FAIL_REQUIRE("can't replace self") { return nullptr; } + } + } else if (path.size() == 1) { + auto lock = impl.lockExclusive(); + KJ_IF_MAYBE(entry, lock->openEntry(path[0], mode)) { + return asFile(lock, *entry, mode).map(newFileAppender); + } else { + return nullptr; + } + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], mode)) { + return child->get()->tryAppendFile(path.slice(1, path.size()), mode); + } else { + return nullptr; + } + } + } + + bool trySymlink(PathPtr path, StringPtr content, WriteMode mode) const override { + if (path.size() == 0) { + if (has(mode, WriteMode::CREATE)) { + return false; + } else { + KJ_FAIL_REQUIRE("can't replace self") { return false; } + } + } else if (path.size() == 1) { + auto lock = impl.lockExclusive(); + KJ_IF_MAYBE(entry, lock->openEntry(path[0], mode)) { + entry->init(SymlinkNode { lock->clock.now(), heapString(content) }); + lock->modified(); + return true; + } else { + return false; + } + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], mode)) { + return child->get()->trySymlink(path.slice(1, path.size()), content, mode); + } else { + KJ_FAIL_REQUIRE("couldn't create parent directory") { return false; } + } + } + } + + Own createTemporary() const override { + // Don't need lock just to read the clock ref. + return newInMemoryFile(impl.getWithoutLock().clock); + } + + bool tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const override { + if (toPath.size() == 0) { + if (has(toMode, WriteMode::CREATE)) { + return false; + } else { + KJ_FAIL_REQUIRE("can't replace self") { return false; } + } + } else if (toPath.size() == 1) { + // tryTransferChild() needs to at least know the node type, so do an lstat. + KJ_IF_MAYBE(meta, fromDirectory.tryLstat(fromPath)) { + auto lock = impl.lockExclusive(); + KJ_IF_MAYBE(entry, lock->openEntry(toPath[0], toMode)) { + // Make sure if we just cerated a new entry, and we don't successfully transfer to it, we + // remove the entry before returning. + bool needRollback = entry->node == nullptr; + KJ_DEFER(if (needRollback) { lock->entries.erase(toPath[0]); }); + + if (lock->tryTransferChild(*entry, meta->type, meta->lastModified, meta->size, + fromDirectory, fromPath, mode)) { + lock->modified(); + needRollback = false; + return true; + } else { + KJ_FAIL_REQUIRE("InMemoryDirectory can't link an inode of this type", fromPath) { + return false; + } + } + } else { + return false; + } + } else { + return false; + } + } else { + // TODO(someday): Ideally we wouldn't create parent directories if fromPath doesn't exist. + // This requires a different approach to the code here, though. + KJ_IF_MAYBE(child, tryGetParent(toPath[0], toMode)) { + return child->get()->tryTransfer( + toPath.slice(1, toPath.size()), toMode, fromDirectory, fromPath, mode); + } else { + return false; + } + } + } + + Maybe tryTransferTo(const Directory& toDirectory, PathPtr toPath, WriteMode toMode, + PathPtr fromPath, TransferMode mode) const override { + if (fromPath.size() <= 1) { + // If `fromPath` is in this directory (or *is* this directory) then we don't have any + // optimizations. + return nullptr; + } + + // `fromPath` is in a subdirectory. It could turn out that that subdirectory is not an + // InMemoryDirectory and is instead something `toDirectory` is friendly with. So let's follow + // the path. + + KJ_IF_MAYBE(child, tryGetParent(fromPath[0], WriteMode::MODIFY)) { + // OK, switch back to tryTransfer() but use the subdirectory. + return toDirectory.tryTransfer(toPath, toMode, + **child, fromPath.slice(1, fromPath.size()), mode); + } else { + // Hmm, doesn't exist. Fall back to standard path. + return nullptr; + } + } + + bool tryRemove(PathPtr path) const override { + if (path.size() == 0) { + KJ_FAIL_REQUIRE("can't remove self from self") { return false; } + } else if (path.size() == 1) { + auto lock = impl.lockExclusive(); + auto iter = lock->entries.find(path[0]); + if (iter == lock->entries.end()) { + return false; + } else { + lock->entries.erase(iter); + lock->modified(); + return true; + } + } else { + KJ_IF_MAYBE(child, tryGetParent(path[0], WriteMode::MODIFY)) { + return child->get()->tryRemove(path.slice(1, path.size())); + } else { + return false; + } + } + } + +private: + struct FileNode { + Own file; + }; + struct DirectoryNode { + Own directory; + }; + struct SymlinkNode { + Date lastModified; + String content; + + Path parse() const { + KJ_CONTEXT("parsing symlink", content); + return Path::parse(content); + } + }; + + struct EntryImpl { + String name; + OneOf node; + + EntryImpl(String&& name): name(kj::mv(name)) {} + + Own init(FileNode&& value) { + return node.init(kj::mv(value)).file->clone(); + } + Own init(DirectoryNode&& value) { + return node.init(kj::mv(value)).directory->clone(); + } + void init(SymlinkNode&& value) { + node.init(kj::mv(value)); + } + bool init(OneOf&& value) { + node = kj::mv(value); + return node != nullptr; + } + + void set(Own&& value) { + node.init(FileNode { kj::mv(value) }); + } + void set(Own&& value) { + node.init(DirectoryNode { kj::mv(value) }); + } + }; + + template + class ReplacerImpl final: public Replacer { + public: + ReplacerImpl(const InMemoryDirectory& directory, kj::StringPtr name, + Own inner, WriteMode mode) + : Replacer(mode), directory(atomicAddRef(directory)), name(heapString(name)), + inner(kj::mv(inner)) {} + + const T& get() override { return *inner; } + + bool tryCommit() override { + KJ_REQUIRE(!committed, "commit() already called") { return true; } + + auto lock = directory->impl.lockExclusive(); + KJ_IF_MAYBE(entry, lock->openEntry(name, Replacer::mode)) { + entry->set(inner->clone()); + lock->modified(); + return true; + } else { + return false; + } + } + + private: + bool committed = false; + Own directory; + kj::String name; + Own inner; + }; + + template + class BrokenReplacer final: public Replacer { + // For recovery path when exceptions are disabled. + + public: + BrokenReplacer(Own inner) + : Replacer(WriteMode::CREATE | WriteMode::MODIFY), + inner(kj::mv(inner)) {} + + const T& get() override { return *inner; } + bool tryCommit() override { return false; } + + private: + Own inner; + }; + + struct Impl { + const Clock& clock; + + std::map entries; + // Note: If this changes to a non-sorted map, listNames() and listEntries() must be updated to + // sort their results. + + Date lastModified; + + Impl(const Clock& clock): clock(clock), lastModified(clock.now()) {} + + Maybe openEntry(kj::StringPtr name, WriteMode mode) { + // TODO(perf): We could avoid a copy if the entry exists, at the expense of a double-lookup + // if it doesn't. Maybe a better map implementation will solve everything? + return openEntry(heapString(name), mode); + } + + Maybe openEntry(String&& name, WriteMode mode) { + if (has(mode, WriteMode::CREATE)) { + EntryImpl entry(kj::mv(name)); + StringPtr nameRef = entry.name; + auto insertResult = entries.insert(std::make_pair(nameRef, kj::mv(entry))); + + if (!insertResult.second && !has(mode, WriteMode::MODIFY)) { + // Entry already existed and MODIFY not specified. + return nullptr; + } + + return insertResult.first->second; + } else if (has(mode, WriteMode::MODIFY)) { + return tryGetEntry(name); + } else { + // Neither CREATE nor MODIFY specified: precondition always fails. + return nullptr; + } + } + + kj::Maybe tryGetEntry(kj::StringPtr name) const { + auto iter = entries.find(name); + if (iter == entries.end()) { + return nullptr; + } else { + return iter->second; + } + } + + kj::Maybe tryGetEntry(kj::StringPtr name) { + auto iter = entries.find(name); + if (iter == entries.end()) { + return nullptr; + } else { + return iter->second; + } + } + + void modified() { + lastModified = clock.now(); + } + + bool tryTransferChild(EntryImpl& entry, const FsNode::Type type, kj::Maybe lastModified, + kj::Maybe size, const Directory& fromDirectory, + PathPtr fromPath, TransferMode mode) { + switch (type) { + case FsNode::Type::FILE: + KJ_IF_MAYBE(file, fromDirectory.tryOpenFile(fromPath, WriteMode::MODIFY)) { + if (mode == TransferMode::COPY) { + auto copy = newInMemoryFile(clock); + copy->copy(0, **file, 0, size.orDefault(kj::maxValue)); + entry.set(kj::mv(copy)); + } else { + if (mode == TransferMode::MOVE) { + KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { + return false; + } + } + entry.set(kj::mv(*file)); + } + return true; + } else { + KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { + return false; + } + } + case FsNode::Type::DIRECTORY: + KJ_IF_MAYBE(subdir, fromDirectory.tryOpenSubdir(fromPath, WriteMode::MODIFY)) { + if (mode == TransferMode::COPY) { + auto copy = atomicRefcounted(clock); + auto& cpim = copy->impl.getWithoutLock(); // safe because just-created + for (auto& subEntry: subdir->get()->listEntries()) { + EntryImpl newEntry(kj::mv(subEntry.name)); + Path filename(newEntry.name); + if (!cpim.tryTransferChild(newEntry, subEntry.type, nullptr, nullptr, **subdir, + filename, TransferMode::COPY)) { + KJ_LOG(ERROR, "couldn't copy node of type not supported by InMemoryDirectory", + filename); + } else { + StringPtr nameRef = newEntry.name; + cpim.entries.insert(std::make_pair(nameRef, kj::mv(newEntry))); + } + } + entry.set(kj::mv(copy)); + } else { + if (mode == TransferMode::MOVE) { + KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { + return false; + } + } + entry.set(kj::mv(*subdir)); + } + return true; + } else { + KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { + return false; + } + } + case FsNode::Type::SYMLINK: + KJ_IF_MAYBE(content, fromDirectory.tryReadlink(fromPath)) { + // Since symlinks are immutable, we can implement LINK the same as COPY. + entry.init(SymlinkNode { lastModified.orDefault(clock.now()), kj::mv(*content) }); + if (mode == TransferMode::MOVE) { + KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { + return false; + } + } + return true; + } else { + KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { + return false; + } + } + default: + return false; + } + } + }; + + kj::MutexGuarded impl; + + bool exists(kj::Locked& lock, const EntryImpl& entry) const { + if (entry.node.is()) { + auto newPath = entry.node.get().parse(); + lock.release(); + return exists(newPath); + } else { + return true; + } + } + Maybe> asFile( + kj::Locked& lock, const EntryImpl& entry) const { + if (entry.node.is()) { + return entry.node.get().file->clone(); + } else if (entry.node.is()) { + auto newPath = entry.node.get().parse(); + lock.release(); + return tryOpenFile(newPath); + } else { + KJ_FAIL_REQUIRE("not a file") { return nullptr; } + } + } + Maybe> asDirectory( + kj::Locked& lock, const EntryImpl& entry) const { + if (entry.node.is()) { + return entry.node.get().directory->clone(); + } else if (entry.node.is()) { + auto newPath = entry.node.get().parse(); + lock.release(); + return tryOpenSubdir(newPath); + } else { + KJ_FAIL_REQUIRE("not a directory") { return nullptr; } + } + } + Maybe asSymlink(kj::Locked& lock, const EntryImpl& entry) const { + if (entry.node.is()) { + return heapString(entry.node.get().content); + } else { + KJ_FAIL_REQUIRE("not a symlink") { return nullptr; } + } + } + + Maybe> asFile(kj::Locked& lock, EntryImpl& entry, WriteMode mode) const { + if (entry.node.is()) { + return entry.node.get().file->clone(); + } else if (entry.node.is()) { + // CREATE_PARENT doesn't apply to creating the parents of a symlink target. However, the + // target itself can still be created. + auto newPath = entry.node.get().parse(); + lock.release(); + return tryOpenFile(newPath, mode - WriteMode::CREATE_PARENT); + } else if (entry.node == nullptr) { + KJ_ASSERT(has(mode, WriteMode::CREATE)); + lock->modified(); + return entry.init(FileNode { newInMemoryFile(lock->clock) }); + } else { + KJ_FAIL_REQUIRE("not a file") { return nullptr; } + } + } + Maybe> asDirectory( + kj::Locked& lock, EntryImpl& entry, WriteMode mode) const { + if (entry.node.is()) { + return entry.node.get().directory->clone(); + } else if (entry.node.is()) { + // CREATE_PARENT doesn't apply to creating the parents of a symlink target. However, the + // target itself can still be created. + auto newPath = entry.node.get().parse(); + lock.release(); + return tryOpenSubdir(newPath, mode - WriteMode::CREATE_PARENT); + } else if (entry.node == nullptr) { + KJ_ASSERT(has(mode, WriteMode::CREATE)); + lock->modified(); + return entry.init(DirectoryNode { newInMemoryDirectory(lock->clock) }); + } else { + KJ_FAIL_REQUIRE("not a directory") { return nullptr; } + } + } + + kj::Maybe> tryGetParent(kj::StringPtr name) const { + auto lock = impl.lockShared(); + KJ_IF_MAYBE(entry, impl.lockShared()->tryGetEntry(name)) { + return asDirectory(lock, *entry); + } else { + return nullptr; + } + } + + kj::Maybe> tryGetParent(kj::StringPtr name, WriteMode mode) const { + // Get a directory which is a parent of the eventual target. If `mode` includes + // WriteMode::CREATE_PARENTS, possibly create the parent directory. + + auto lock = impl.lockExclusive(); + + WriteMode parentMode = has(mode, WriteMode::CREATE) && has(mode, WriteMode::CREATE_PARENT) + ? WriteMode::CREATE | WriteMode::MODIFY // create parent + : WriteMode::MODIFY; // don't create parent + + // Possibly create parent. + KJ_IF_MAYBE(entry, lock->openEntry(name, parentMode)) { + if (entry->node.is()) { + return entry->node.get().directory->clone(); + } else if (entry->node == nullptr) { + lock->modified(); + return entry->init(DirectoryNode { newInMemoryDirectory(lock->clock) }); + } + // Continue on. + } + + if (has(mode, WriteMode::CREATE)) { + // CREATE is documented as returning null when the file already exists. In this case, the + // file does NOT exist because the parent directory does not exist or is not a directory. + KJ_FAIL_REQUIRE("parent is not a directory") { return nullptr; } + } else { + return nullptr; + } + } +}; + +// ----------------------------------------------------------------------------- + +class AppendableFileImpl final: public AppendableFile { +public: + AppendableFileImpl(Own&& fileParam): file(kj::mv(fileParam)) {} + + Own cloneFsNode() const override { + return heap(file->clone()); + } + + Maybe getFd() const override { + return nullptr; + } + + Metadata stat() const override { + return file->stat(); + } + + void sync() const override { file->sync(); } + void datasync() const override { file->datasync(); } + + void write(const void* buffer, size_t size) override { + file->write(file->stat().size, arrayPtr(reinterpret_cast(buffer), size)); + } + +private: + Own file; +}; + +} // namespace + +// ----------------------------------------------------------------------------- + +Own newInMemoryFile(const Clock& clock) { + return atomicRefcounted(clock); +} +Own newInMemoryDirectory(const Clock& clock) { + return atomicRefcounted(clock); +} +Own newFileAppender(Own inner) { + return heap(kj::mv(inner)); +} + +} // namespace kj diff --git a/c++/src/kj/filesystem.h b/c++/src/kj/filesystem.h new file mode 100644 index 0000000000..323420a442 --- /dev/null +++ b/c++/src/kj/filesystem.h @@ -0,0 +1,1123 @@ +// Copyright (c) 2015 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "memory.h" +#include "io.h" +#include +#include "time.h" +#include "function.h" +#include "hash.h" + +KJ_BEGIN_HEADER + +namespace kj { + +template +class Vector; + +class PathPtr; + +class Path { + // A Path identifies a file in a directory tree. + // + // In KJ, we avoid representing paths as plain strings because this can lead to path injection + // bugs as well as numerous kinds of bugs relating to path parsing edge cases. The Path class's + // interface is designed to "make it hard to screw up". + // + // A "Path" is in fact a list of strings, each string being one component of the path (as would + // normally be separated by '/'s). Path components are not allowed to contain '/' nor '\0', nor + // are they allowed to be the special names "", ".", nor "..". + // + // If you explicitly want to parse a path that contains '/'s, ".", and "..", you must use + // parse() and/or eval(). However, users of this interface are encouraged to avoid parsing + // paths at all, and instead express paths as string arrays. + // + // Note that when using the Path class, ".." is always canonicalized in path space without + // consulting the actual filesystem. This means that "foo/some-symlink/../bar" is exactly + // equivalent to "foo/bar". This differs from the kernel's behavior when resolving paths passed + // to system calls: the kernel would have resolved "some-symlink" to its target physical path, + // and then would have interpreted ".." relative to that. In practice, the kernel's behavior is + // rarely what the user or programmer intended, hence canonicalizing in path space produces a + // better result. + // + // Path objects are "immutable": functions that "modify" the path return a new path. However, + // if the path being operated on is an rvalue, copying can be avoided. Hence it makes sense to + // write code like: + // + // Path p = ...; + // p = kj::mv(p).append("bar"); // in-place update, avoids string copying + +public: + Path(decltype(nullptr)); // empty path + + explicit Path(StringPtr name); + explicit Path(String&& name); + // Create a Path containing only one component. `name` is a single filename; it cannot contain + // '/' nor '\0' nor can it be exactly "" nor "." nor "..". + // + // If you want to allow '/'s and such, you must call Path::parse(). We force you to do this to + // prevent path injection bugs where you didn't consider what would happen if the path contained + // a '/'. + + explicit Path(std::initializer_list parts); + explicit Path(ArrayPtr parts); + explicit Path(Array parts); + // Construct a path from an array. Note that this means you can do: + // + // Path{"foo", "bar", "baz"} // equivalent to Path::parse("foo/bar/baz") + + KJ_DISALLOW_COPY(Path); + Path(Path&&) = default; + Path& operator=(Path&&) = default; + + Path clone() const; + + static Path parse(StringPtr path); + // Parses a path in traditional format. Components are separated by '/'. Any use of "." or + // ".." will be canonicalized (if they can't be canonicalized, e.g. because the path starts with + // "..", an exception is thrown). Multiple consecutive '/'s will be collapsed. A leading '/' + // is NOT accepted -- if that is a problem, you probably want `eval()`. Trailing '/'s are + // ignored. + + Path append(Path&& suffix) const&; + Path append(Path&& suffix) &&; + Path append(PathPtr suffix) const&; + Path append(PathPtr suffix) &&; + Path append(StringPtr suffix) const&; + Path append(StringPtr suffix) &&; + Path append(String&& suffix) const&; + Path append(String&& suffix) &&; + // Create a new path by appending the given path to this path. + // + // `suffix` cannot contain '/' characters. Instead, you can append an array: + // + // path.append({"foo", "bar"}) + // + // Or, use Path::parse(): + // + // path.append(Path::parse("foo//baz/../bar")) + + Path eval(StringPtr pathText) const&; + Path eval(StringPtr pathText) &&; + // Evaluates a traditional path relative to this one. `pathText` is parsed like `parse()` would, + // except that: + // - It can contain leading ".." components that traverse up the tree. + // - It can have a leading '/' which completely replaces the current path. + // + // THE NAME OF THIS METHOD WAS CHOSEN TO INSPIRE FEAR. + // + // Instead of using `path.eval(str)`, always consider whether you really want + // `path.append(Path::parse(str))`. The former is much riskier than the latter in terms of path + // injection vulnerabilities. + + PathPtr basename() const&; + Path basename() &&; + // Get the last component of the path. (Use `basename()[0]` to get just the string.) + + PathPtr parent() const&; + Path parent() &&; + // Get the parent path. + + String toString(bool absolute = false) const; + // Converts the path to a traditional path string, appropriate to pass to a unix system call. + // Never throws. + + const String& operator[](size_t i) const&; + String operator[](size_t i) &&; + size_t size() const; + const String* begin() const; + const String* end() const; + PathPtr slice(size_t start, size_t end) const&; + Path slice(size_t start, size_t end) &&; + // A Path can be accessed as an array of strings. + + bool operator==(PathPtr other) const; + bool operator!=(PathPtr other) const; + bool operator< (PathPtr other) const; + bool operator> (PathPtr other) const; + bool operator<=(PathPtr other) const; + bool operator>=(PathPtr other) const; + // Compare path components lexically. + + bool operator==(const Path& other) const; + bool operator!=(const Path& other) const; + bool operator< (const Path& other) const; + bool operator> (const Path& other) const; + bool operator<=(const Path& other) const; + bool operator>=(const Path& other) const; + + uint hashCode() const; + // Can use in HashMap. + + bool startsWith(PathPtr prefix) const; + bool endsWith(PathPtr suffix) const; + // Compare prefix / suffix. + + Path evalWin32(StringPtr pathText) const&; + Path evalWin32(StringPtr pathText) &&; + // Evaluates a Win32-style path, as might be written by a user. Differences from `eval()` + // include: + // + // - Backslashes can be used as path separators. + // - Absolute paths begin with a drive letter followed by a colon. The drive letter, including + // the colon, will become the first component of the path, e.g. "c:\foo" becomes {"c:", "foo"}. + // - A network path like "\\host\share\path" is parsed as {"host", "share", "path"}. + + Path evalNative(StringPtr pathText) const&; + Path evalNative(StringPtr pathText) &&; + // Alias for either eval() or evalWin32() depending on the target platform. Use this when you are + // parsing a path provided by a user and you want the user to be able to use the "natural" format + // for their platform. + + String toWin32String(bool absolute = false) const; + // Converts the path to a Win32 path string, as you might display to a user. + // + // This is meant for display. For making Win32 system calls, consider `toWin32Api()` instead. + // + // If `absolute` is true, the path is expected to be an absolute path, meaning the first + // component is a drive letter, namespace, or network host name. These are converted to their + // regular Win32 format -- i.e. this method does the reverse of `evalWin32()`. + // + // This throws if the path would have unexpected special meaning or is otherwise invalid on + // Windows, such as if it contains backslashes (within a path component), colons, or special + // names like "con". + + String toNativeString(bool absolute = false) const; + // Alias for either toString() or toWin32String() depending on the target platform. Use this when + // you are formatting a path to display to a user and you want to present it in the "natural" + // format for the user's platform. + + Array forWin32Api(bool absolute) const; + // Like toWin32String, but additionally: + // - Converts the path to UTF-16, with a NUL terminator included. + // - For absolute paths, adds the "\\?\" prefix which opts into permitting paths longer than + // MAX_PATH, and turns off relative path processing (which KJ paths already handle in userspace + // anyway). + // + // This method is good to use when making a Win32 API call, e.g.: + // + // DeleteFileW(path.forWin32Api(true).begin()); + + static Path parseWin32Api(ArrayPtr text); + // Parses an absolute path as returned by a Win32 API call like GetFinalPathNameByHandle() or + // GetCurrentDirectory(). A "\\?\" prefix is optional but understood if present. + // + // Since such Win32 API calls generally return a length, this function inputs an array slice. + // The slice should not include any NUL terminator. + +private: + Array parts; + + // TODO(perf): Consider unrolling one element from `parts`, so that a one-element path doesn't + // require allocation of an array. + + enum { ALREADY_CHECKED }; + Path(Array parts, decltype(ALREADY_CHECKED)); + + friend class PathPtr; + + static String stripNul(String input); + static void validatePart(StringPtr part); + static void evalPart(Vector& parts, ArrayPtr part); + static Path evalImpl(Vector&& parts, StringPtr path); + static Path evalWin32Impl(Vector&& parts, StringPtr path, bool fromApi = false); + static size_t countParts(StringPtr path); + static size_t countPartsWin32(StringPtr path); + static bool isWin32Drive(ArrayPtr part); + static bool isNetbiosName(ArrayPtr part); + static bool isWin32Special(StringPtr part); +}; + +class PathPtr { + // Points to a Path or a slice of a Path, but doesn't own it. + // + // PathPtr is to Path as ArrayPtr is to Array and StringPtr is to String. + +public: + PathPtr(decltype(nullptr)); + PathPtr(const Path& path); + + Path clone(); + Path append(Path&& suffix) const; + Path append(PathPtr suffix) const; + Path append(StringPtr suffix) const; + Path append(String&& suffix) const; + Path eval(StringPtr pathText) const; + PathPtr basename() const; + PathPtr parent() const; + String toString(bool absolute = false) const; + const String& operator[](size_t i) const; + size_t size() const; + const String* begin() const; + const String* end() const; + PathPtr slice(size_t start, size_t end) const; + bool operator==(PathPtr other) const; + bool operator!=(PathPtr other) const; + bool operator< (PathPtr other) const; + bool operator> (PathPtr other) const; + bool operator<=(PathPtr other) const; + bool operator>=(PathPtr other) const; + uint hashCode() const; + bool startsWith(PathPtr prefix) const; + bool endsWith(PathPtr suffix) const; + Path evalWin32(StringPtr pathText) const; + Path evalNative(StringPtr pathText) const; + String toWin32String(bool absolute = false) const; + String toNativeString(bool absolute = false) const; + Array forWin32Api(bool absolute) const; + // Equivalent to the corresponding methods of `Path`. + +private: + ArrayPtr parts; + + explicit PathPtr(ArrayPtr parts); + + String toWin32StringImpl(bool absolute, bool forApi) const; + + friend class Path; +}; + +// ======================================================================================= +// The filesystem API +// +// This API is strictly synchronous because, unfortunately, there's no such thing as asynchronous +// filesystem access in practice. The filesystem drivers on Linux are written to assume they can +// block. The AIO API is only actually asynchronous for reading/writing the raw file blocks, but if +// the filesystem needs to be involved (to allocate blocks, update metadata, etc.) that will block. +// It's best to imagine that the filesystem is just another tier of memory that happens to be +// slower than RAM (which is slower than L3 cache, which is slower than L2, which is slower than +// L1). You can't do asynchronous RAM access so why asynchronous filesystem? The only way to +// parallelize these is using threads. +// +// All KJ filesystem objects are thread-safe, and so all methods are marked "const" (even write +// methods). Of course, if you concurrently write the same bytes of a file from multiple threads, +// it's unspecified which write will "win". + +class FsNode { + // Base class for filesystem node types. + +public: + Own clone() const; + // Creates a new object of exactly the same type as this one, pointing at exactly the same + // external object. + // + // Under the hood, this will call dup(), so the FD number will not be the same. + + virtual Maybe getFd() const { return nullptr; } + // Get the underlying Unix file descriptor, if any. Returns nullptr if this object actually isn't + // wrapping a file descriptor. + + virtual Maybe getWin32Handle() const { return nullptr; } + // Get the underlying Win32 HANDLE, if any. Returns nullptr if this object actually isn't + // wrapping a handle. + + enum class Type { + FILE, + DIRECTORY, + SYMLINK, + BLOCK_DEVICE, + CHARACTER_DEVICE, + NAMED_PIPE, + SOCKET, + OTHER, + }; + + struct Metadata { + Type type = Type::FILE; + + uint64_t size = 0; + // Logical size of the file. + + uint64_t spaceUsed = 0; + // Physical size of the file on disk. May be smaller for sparse files, or larger for + // pre-allocated files. + + Date lastModified = UNIX_EPOCH; + // Last modification time of the file. + + uint linkCount = 1; + // Number of hard links pointing to this node. + + uint64_t hashCode = 0; + // Hint which can be used to determine if two FsNode instances point to the same underlying + // file object. If two FsNodes report different hashCodes, then they are not the same object. + // If they report the same hashCode, then they may or may not be the same object. + // + // The Unix filesystem implementation builds the hashCode based on st_dev and st_ino of + // `struct stat`. However, note that some filesystems -- especially FUSE-based -- may not fill + // in st_ino. + // + // The Windows filesystem implementation builds the hashCode based on dwVolumeSerialNumber and + // dwFileIndex{Low,High} of the BY_HANDLE_FILE_INFORMATION structure. However, these are again + // not guaranteed to be unique on all filesystems. In particular the documentation says that + // ReFS uses 128-bit identifiers which can't be represented here, and again virtual filesystems + // may often not report real identifiers. + // + // Of course, the process of hashing values into a single hash code can also cause collisions + // even if the filesystem reports reliable information. + // + // Additionally note that this value is not reliable when returned by `lstat()`. You should + // actually open the object, then call `stat()` on the opened object. + + // Not currently included: + // - Access control info: Differs wildly across platforms, and KJ prefers capabilities anyway. + // - Other timestamps: Differs across platforms. + // - Device number: If you care, you're probably doing platform-specific stuff anyway. + + Metadata() = default; + Metadata(Type type, uint64_t size, uint64_t spaceUsed, Date lastModified, uint linkCount, + uint64_t hashCode) + : type(type), size(size), spaceUsed(spaceUsed), lastModified(lastModified), + linkCount(linkCount), hashCode(hashCode) {} + // TODO(cleanup): This constructor is redundant in C++14, but needed in C++11. + }; + + virtual Metadata stat() const = 0; + + virtual void sync() const = 0; + virtual void datasync() const = 0; + // Maps to fsync() and fdatasync() system calls. + // + // Also, when creating or overwriting a file, the first call to sync() atomically links the file + // into the filesystem (*after* syncing the data), so than incomplete data is never visible to + // other processes. (In practice this works by writing into a temporary file and then rename()ing + // it.) + +protected: + virtual Own cloneFsNode() const = 0; + // Implements clone(). Required to return an object with exactly the same type as this one. + // Hence, every subclass must implement this. +}; + +class ReadableFile: public FsNode { +public: + Own clone() const; + + String readAllText() const; + // Read all text in the file and return as a big string. + + Array readAllBytes() const; + // Read all bytes in the file and return as a big byte array. + // + // This differs from mmap() in that the read is performed all at once. Future changes to the file + // do not affect the returned copy. Consider using mmap() instead, particularly for large files. + + virtual size_t read(uint64_t offset, ArrayPtr buffer) const = 0; + // Fills `buffer` with data starting at `offset`. Returns the number of bytes actually read -- + // the only time this is less than `buffer.size()` is when EOF occurs mid-buffer. + + virtual Array mmap(uint64_t offset, uint64_t size) const = 0; + // Maps the file to memory read-only. The returned array always has exactly the requested size. + // Depending on the capabilities of the OS and filesystem, the mapping may or may not reflect + // changes that happen to the file after mmap() returns. + // + // Multiple calls to mmap() on the same file may or may not return the same mapping (it is + // immutable, so there's no possibility of interference). + // + // If the file cannot be mmap()ed, an implementation may choose to allocate a buffer on the heap, + // read into it, and return that. This should only happen if a real mmap() is impossible. + // + // The returned array is always exactly the size requested. However, accessing bytes beyond the + // current end of the file may raise SIGBUS, or may simply return zero. + + virtual Array mmapPrivate(uint64_t offset, uint64_t size) const = 0; + // Like mmap() but returns a view that the caller can modify. Modifications will not be written + // to the underlying file. Every call to this method returns a unique mapping. Changes made to + // the underlying file by other clients may or may not be reflected in the mapping -- in fact, + // some changes may be reflected while others aren't, even within the same mapping. + // + // In practice this is often implemented using copy-on-write pages. When you first write to a + // page, a copy is made. Hence, changes to the underlying file within that page stop being + // reflected in the mapping. +}; + +class AppendableFile: public FsNode, public OutputStream { +public: + Own clone() const; + + // All methods are inherited. +}; + +class WritableFileMapping { +public: + virtual ArrayPtr get() const = 0; + // Gets the mapped bytes. The returned array can be modified, and those changes may be written to + // the underlying file, but there is no guarantee that they are written unless you subsequently + // call changed(). + + virtual void changed(ArrayPtr slice) const = 0; + // Notifies the implementation that the given bytes have changed. For some implementations this + // may be a no-op while for others it may be necessary in order for the changes to be written + // back at all. + // + // `slice` must be a slice of `bytes()`. + + virtual void sync(ArrayPtr slice) const = 0; + // Implies `changed()`, and then waits until the range has actually been written to disk before + // returning. + // + // `slice` must be a slice of `bytes()`. + // + // On Windows, this calls FlushViewOfFile(). The documentation for this function implies that in + // some circumstances, to fully sync to physical disk, you may need to call FlushFileBuffers() on + // the file HANDLE as well. The documentation is not very clear on when and why this is needed. + // If you believe your program needs this, you can accomplish it by calling `.sync()` on the File + // object after calling `.sync()` on the WritableFileMapping. +}; + +class File: public ReadableFile { +public: + Own clone() const; + + void writeAll(ArrayPtr bytes) const; + void writeAll(StringPtr text) const; + // Completely replace the file with the given bytes or text. + + virtual void write(uint64_t offset, ArrayPtr data) const = 0; + // Write the given data starting at the given offset in the file. + + virtual void zero(uint64_t offset, uint64_t size) const = 0; + // Write zeros to the file, starting at `offset` and continuing for `size` bytes. If the platform + // supports it, this will "punch a hole" in the file, such that blocks that are entirely zeros + // do not take space on disk. + + virtual void truncate(uint64_t size) const = 0; + // Set the file end pointer to `size`. If `size` is less than the current size, data past the end + // is truncated. If `size` is larger than the current size, zeros are added to the end of the + // file. If the platform supports it, blocks containing all-zeros will not be stored to disk. + + virtual Own mmapWritable(uint64_t offset, uint64_t size) const = 0; + // Like ReadableFile::mmap() but returns a mapping for which any changes will be immediately + // visible in other mappings of the file on the same system and will eventually be written back + // to the file. + + virtual size_t copy(uint64_t offset, const ReadableFile& from, uint64_t fromOffset, + uint64_t size) const; + // Copies bytes from one file to another. + // + // Copies `size` bytes or to EOF, whichever comes first. Returns the number of bytes actually + // copied. Hint: Pass kj::maxValue for `size` to always copy to EOF. + // + // The copy is not atomic. Concurrent writes may lead to garbage results. + // + // The default implementation performs a series of reads and writes. Subclasses can often provide + // superior implementations that offload the work to the OS or even implement copy-on-write. +}; + +class ReadableDirectory: public FsNode { + // Read-only subset of `Directory`. + +public: + Own clone() const; + + virtual Array listNames() const = 0; + // List the contents of this directory. Does NOT include "." nor "..". + + struct Entry { + FsNode::Type type; + String name; + + inline bool operator< (const Entry& other) const { return name < other.name; } + inline bool operator> (const Entry& other) const { return name > other.name; } + inline bool operator<=(const Entry& other) const { return name <= other.name; } + inline bool operator>=(const Entry& other) const { return name >= other.name; } + // Convenience comparison operators to sort entries by name. + }; + + virtual Array listEntries() const = 0; + // List the contents of the directory including the type of each file. On some platforms and + // filesystems, this is just as fast as listNames(), but on others it may require stat()ing each + // file. + + virtual bool exists(PathPtr path) const = 0; + // Does the specified path exist? + // + // If the path is a symlink, the symlink is followed and the return value indicates if the target + // exists. If you want to know if the symlink exists, use lstat(). (This implies that listNames() + // may return names for which exists() reports false.) + + FsNode::Metadata lstat(PathPtr path) const; + virtual Maybe tryLstat(PathPtr path) const = 0; + // Gets metadata about the path. If the path is a symlink, it is not followed -- the metadata + // describes the symlink itself. `tryLstat()` returns null if the path doesn't exist. + + Own openFile(PathPtr path) const; + virtual Maybe> tryOpenFile(PathPtr path) const = 0; + // Open a file for reading. + // + // `tryOpenFile()` returns null if the path doesn't exist. Other errors still throw exceptions. + + Own openSubdir(PathPtr path) const; + virtual Maybe> tryOpenSubdir(PathPtr path) const = 0; + // Opens a subdirectory. + // + // `tryOpenSubdir()` returns null if the path doesn't exist. Other errors still throw exceptions. + + String readlink(PathPtr path) const; + virtual Maybe tryReadlink(PathPtr path) const = 0; + // If `path` is a symlink, reads and returns the link contents. + // + // Note that tryReadlink() differs subtly from tryOpen*(). For example, tryOpenFile() throws if + // the path is not a file (e.g. if it's a directory); it only returns null if the path doesn't + // exist at all. tryReadlink() returns null if either the path doesn't exist, or if it does exist + // but isn't a symlink. This is because if it were to throw instead, then almost every real-world + // use case of tryReadlink() would be forced to perform an lstat() first for the sole purpose of + // checking if it is a link, wasting a syscall and a path traversal. + // + // See Directory::symlink() for warnings about symlinks. +}; + +enum class WriteMode { + // Mode for opening a file (or directory) for write. + // + // (To open a file or directory read-only, do not specify a mode.) + // + // WriteMode is a bitfield. Hence, it overloads the bitwise logic operators. To check if a + // particular bit is set in a bitfield, use kj::has(), like: + // + // if (kj::has(mode, WriteMode::MUST_EXIST)) { + // requireExists(path); + // } + // + // (`if (mode & WriteMode::MUST_EXIST)` doesn't work because WriteMode is an enum class, which + // cannot be converted to bool. Alas, C++ does not allow you to define a conversion operator + // on an enum type, so we can't define a conversion to bool.) + + // ----------------------------------------- + // Core flags + // + // At least one of CREATE or MODIFY must be specified. Optionally, the two flags can be combined + // with a bitwise-OR. + + CREATE = 1, + // Create a new empty file. + // + // When not combined with MODIFY, if the file already exists (including as a broken symlink), + // tryOpenFile() returns null (and openFile() throws). + // + // When combined with MODIFY, if the path already exists, it will be opened as if CREATE hadn't + // been specified at all. If the path refers to a broken symlink, the file at the target of the + // link will be created (if its parent directory exists). + + MODIFY = 2, + // Modify an existing file. + // + // When not combined with CREATE, if the file doesn't exist (including if it is a broken symlink), + // tryOpenFile() returns null (and openFile() throws). + // + // When combined with CREATE, if the path doesn't exist, it will be created as if MODIFY hadn't + // been specified at all. If the path refers to a broken symlink, the file at the target of the + // link will be created (if its parent directory exists). + + // ----------------------------------------- + // Additional flags + // + // Any number of these may be OR'd with the core flags. + + CREATE_PARENT = 4, + // Indicates that if the target node's parent directory doesn't exist, it should be created + // automatically, along with its parent, and so on. This creation is NOT atomic. + // + // This bit only makes sense with CREATE or REPLACE. + + EXECUTABLE = 8, + // Mark this file executable, if this is a meaningful designation on the host platform. + + PRIVATE = 16, + // Indicates that this file is sensitive and should have permissions masked so that it is only + // accessible by the current user. + // + // When this is not used, the platform's default access control settings are used. On Unix, + // that usually means the umask is applied. On Windows, it means permissions are inherited from + // the parent. +}; + +inline constexpr WriteMode operator|(WriteMode a, WriteMode b) { + return static_cast(static_cast(a) | static_cast(b)); +} +inline constexpr WriteMode operator&(WriteMode a, WriteMode b) { + return static_cast(static_cast(a) & static_cast(b)); +} +inline constexpr WriteMode operator+(WriteMode a, WriteMode b) { + return static_cast(static_cast(a) | static_cast(b)); +} +inline constexpr WriteMode operator-(WriteMode a, WriteMode b) { + return static_cast(static_cast(a) & ~static_cast(b)); +} +template > +bool has(T haystack, T needle) { + return (static_cast<__underlying_type(T)>(haystack) & + static_cast<__underlying_type(T)>(needle)) == + static_cast<__underlying_type(T)>(needle); +} + +enum class TransferMode { + // Specifies desired behavior for Directory::transfer(). + + MOVE, + // The node is moved to the new location, i.e. the old location is deleted. If possible, this + // move is performed without copying, otherwise it is performed as a copy followed by a delete. + + LINK, + // The new location becomes a synonym for the old location (a "hard link"). Filesystems have + // varying support for this -- typically, it is not supported on directories. + + COPY + // The new location becomes a copy of the old. + // + // Some filesystems may implement this in terms of copy-on-write. + // + // If the filesystem supports sparse files, COPY takes sparseness into account -- it will punch + // holes in the target file where holes exist in the source file. +}; + +class Directory: public ReadableDirectory { + // Refers to a specific directory on disk. + // + // A `Directory` object *only* provides access to children of the directory, not parents. That + // is, you cannot open the file "..", nor jump to the root directory with "/". + // + // On OSs that support it, a `Directory` is backed by an open handle to the directory node. This + // means: + // - If the directory is renamed on-disk, the `Directory` object still points at it. + // - Opening files in the directory only requires the OS to traverse the path from the directory + // to the file; it doesn't have to re-traverse all the way from the filesystem root. + // + // On Windows, a `Directory` object holds a lock on the underlying directory such that it cannot + // be renamed nor deleted while the object exists. This is necessary because Windows does not + // fully support traversing paths relative to file handles (it does for some operations but not + // all), so the KJ filesystem implementation is forced to remember the full path and needs to + // ensure that the path is not invalidated. If, in the future, Windows fully supports + // handle-relative paths, KJ may stop locking directories in this way, so do not rely on this + // behavior. + +public: + Own clone() const; + + template + class Replacer { + // Implements an atomic replacement of a file or directory, allowing changes to be made to + // storage in a way that avoids losing data in a power outage and prevents other processes + // from observing content in an inconsistent state. + // + // `T` may be `File` or `Directory`. For readability, the text below describes replacing a + // file, but the logic is the same for directories. + // + // When you call `Directory::replaceFile()`, a temporary file is created, but the specified + // path is not yet touched. You may call `get()` to obtain the temporary file object, through + // which you may initialize its content, knowing that no other process can see it yet. The file + // is atomically moved to its final path when you call `commit()`. If you destroy the Replacer + // without calling commit(), the temporary file is deleted. + // + // Note that most operating systems sadly do not support creating a truly unnamed temporary file + // and then linking it in later. Moreover, the file cannot necessarily be created in the system + // temporary directory because it might not be on the same filesystem as the target. Therefore, + // the replacement file may initially be created in the same directory as its eventual target. + // The implementation of Directory will choose a name that is unique and "hidden" according to + // the conventions of the filesystem. Additionally, the implementation of Directory will avoid + // returning these temporary files from its list*() methods, in order to avoid observable + // inconsistencies across platforms. + public: + explicit Replacer(WriteMode mode); + + virtual const T& get() = 0; + // Gets the File or Directory representing the replacement data. Fill in this object before + // calling commit(). + + void commit(); + virtual bool tryCommit() = 0; + // Commit the replacement. + // + // `tryCommit()` may return false based on the CREATE/MODIFY bits passed as the WriteMode when + // the replacement was initiated. (If CREATE but not MODIFY was used, tryCommit() returns + // false to indicate that the target file already existed. If MODIFY but not CREATE was used, + // tryCommit() returns false to indicate that the file didn't exist.) + // + // `commit()` is atomic, meaning that there is no point in time at which other processes + // observing the file will see it in an intermediate state -- they will either see the old + // content or the complete new content. This includes in the case of a power outage or machine + // failure: on recovery, the file will either be in the old state or the new state, but not in + // some intermediate state. + // + // It's important to note that a power failure *after commit() returns* can still revert the + // file to its previous state. That is, `commit()` does NOT guarantee that, upon return, the + // new content is durable. In order to guarantee this, you must call `sync()` on the immediate + // parent directory of the replaced file. + // + // Note that, sadly, not all filesystems / platforms are capable of supporting all of the + // guarantees documented above. In such cases, commit() will make a best-effort attempt to do + // what it claims. Some examples of possible problems include: + // - Any guarantees about durability through a power outage probably require a journaling + // filesystem. + // - Many platforms do not support atomically replacing a non-empty directory. Linux does as + // of kernel 3.15 (via the renameat2() syscall using RENAME_EXCHANGE). Where not supported, + // the old directory will be moved away just before the replacement is moved into place. + // - Many platforms do not support atomically requiring the existence or non-existence of a + // file before replacing it. In these cases, commit() may have to perform the check as a + // separate step, with a small window for a race condition. + // - Many platforms do not support "unlinking" a non-empty directory, meaning that a replaced + // directory will need to be deconstructed by deleting all contents. If another process has + // the directory open when it is replaced, that process will observe the contents + // disappearing after the replacement (actually, a swap) has taken place. This differs from + // files, where a process that has opened a file before it is replaced will continue see the + // file's old content unchanged after the replacement. + // - On Windows, there are multiple ways to replace one file with another in a single system + // call, but none are documented as being atomic. KJ always uses `MoveFileEx()` with + // MOVEFILE_REPLACE_EXISTING. While the alternative `ReplaceFile()` is attractive for many + // reasons, it has the critical problem that it cannot be used when the source file has open + // file handles, which is generally the case when using Replacer. + + protected: + const WriteMode mode; + }; + + using ReadableDirectory::openFile; + using ReadableDirectory::openSubdir; + using ReadableDirectory::tryOpenFile; + using ReadableDirectory::tryOpenSubdir; + + Own openFile(PathPtr path, WriteMode mode) const; + virtual Maybe> tryOpenFile(PathPtr path, WriteMode mode) const = 0; + // Open a file for writing. + // + // `tryOpenFile()` returns null if the path is required to exist but doesn't (MODIFY or REPLACE) + // or if the path is required not to exist but does (CREATE or RACE). These are the only cases + // where it returns null -- all other types of errors (like "access denied") throw exceptions. + + virtual Own> replaceFile(PathPtr path, WriteMode mode) const = 0; + // Construct a file which, when ready, will be atomically moved to `path`, replacing whatever + // is there already. See `Replacer` for detalis. + // + // The `CREATE` and `MODIFY` bits of `mode` are not enforced until commit time, hence + // `replaceFile()` has no "try" variant. + + virtual Own createTemporary() const = 0; + // Create a temporary file backed by this directory's filesystem, but which isn't linked into + // the directory tree. The file is deleted from disk when all references to it have been dropped. + + Own appendFile(PathPtr path, WriteMode mode) const; + virtual Maybe> tryAppendFile(PathPtr path, WriteMode mode) const = 0; + // Opens the file for appending only. Useful for log files. + // + // If the underlying filesystem supports it, writes to the file will always be appended even if + // other writers are writing to the same file at the same time -- however, some implementations + // may instead assume that no other process is changing the file size between writes. + + Own openSubdir(PathPtr path, WriteMode mode) const; + virtual Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const = 0; + // Opens a subdirectory for writing. + + virtual Own> replaceSubdir(PathPtr path, WriteMode mode) const = 0; + // Construct a directory which, when ready, will be atomically moved to `path`, replacing + // whatever is there already. See `Replacer` for detalis. + // + // The `CREATE` and `MODIFY` bits of `mode` are not enforced until commit time, hence + // `replaceSubdir()` has no "try" variant. + + void symlink(PathPtr linkpath, StringPtr content, WriteMode mode) const; + virtual bool trySymlink(PathPtr linkpath, StringPtr content, WriteMode mode) const = 0; + // Create a symlink. `content` is the raw text which will be written into the symlink node. + // How this text is interpreted is entirely dependent on the filesystem. Note in particular that: + // - Windows will require a path that uses backslashes as the separator. + // - InMemoryDirectory does not support symlinks containing "..". + // + // Unfortunately under many implementations symlink() can be used to break out of the directory + // by writing an absolute path or utilizing "..". Do not call this method with a value for + // `target` that you don't trust. + // + // `mode` must be CREATE or REPLACE, not MODIFY. CREATE_PARENT is honored but EXECUTABLE and + // PRIVATE have no effect. `trySymlink()` returns false in CREATE mode when the target already + // exists. + + void transfer(PathPtr toPath, WriteMode toMode, + PathPtr fromPath, TransferMode mode) const; + void transfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const; + virtual bool tryTransfer(PathPtr toPath, WriteMode toMode, + const Directory& fromDirectory, PathPtr fromPath, + TransferMode mode) const; + virtual Maybe tryTransferTo(const Directory& toDirectory, PathPtr toPath, WriteMode toMode, + PathPtr fromPath, TransferMode mode) const; + // Move, link, or copy a file/directory tree from one location to another. + // + // Filesystems vary in what kinds of transfers are allowed, especially for TransferMode::LINK, + // and whether TransferMode::MOVE is implemented as an actual move vs. copy+delete. + // + // tryTransfer() returns false if the source location didn't exist, or when `toMode` is CREATE + // and the target already exists. The default implementation implements only TransferMode::COPY. + // + // tryTransferTo() exists to implement double-dispatch. It should be called as a fallback by + // implementations of tryTransfer() in cases where the target directory would otherwise fail or + // perform a pessimal transfer. The default implementation returns nullptr, which the caller + // should interpret as: "I don't have any special optimizations; do the obvious thing." + // + // `toMode` controls how the target path is created. CREATE_PARENT is honored but EXECUTABLE and + // PRIVATE have no effect. + + void remove(PathPtr path) const; + virtual bool tryRemove(PathPtr path) const = 0; + // Deletes/unlinks the given path. If the path names a directory, it is recursively deleted. + // + // tryRemove() returns false in the specific case that the path doesn't exist. remove() would + // throw in this case. In all other error cases (like "access denied"), tryRemove() still throws; + // it is only "does not exist" that produces a false return. + // + // WARNING: The Windows implementation of recursive deletion is currently not safe to call from a + // privileged process to delete directories writable by unprivileged users, due to a race + // condition in which the user could trick the algorithm into following a symlink and deleting + // everything at the destination. This race condition is not present in the Unix + // implementation. Fixing it for Windows would require rewriting a lot of code to use different + // APIs. If you're interested, see the TODO(security) in filesystem-disk-win32.c++. + + // TODO(someday): + // - Support sockets? There's no openat()-like interface for sockets, so it's hard to support + // them currently. Also you'd probably want to use them with the async library. + // - Support named pipes? Unclear if there's a use case that isn't better-served by sockets. + // Then again, they can be openat()ed. + // - Support watching for changes (inotify). Probably also requires the async library. Also + // lacks openat()-like semantics. + // - xattrs -- linux-specific + // - chown/chmod/etc. -- unix-specific, ACLs, eww + // - set timestamps -- only needed by archiving programs/ + // - advisory locks + // - sendfile? + // - fadvise and such + +private: + static void commitFailed(WriteMode mode); +}; + +class Filesystem { +public: + virtual const Directory& getRoot() const = 0; + // Get the filesystem's root directory, as of the time the Filesystem object was created. + + virtual const Directory& getCurrent() const = 0; + // Get the filesystem's current directory, as of the time the Filesystem object was created. + + virtual PathPtr getCurrentPath() const = 0; + // Get the path from the root to the current directory, as of the time the Filesystem object was + // created. Note that because a `Directory` does not provide access to its parent, if you want to + // follow `..` from the current directory, you must use `getCurrentPath().eval("..")` or + // `getCurrentPath().parent()`. + // + // This function attempts to determine the path as it appeared in the user's shell before this + // program was started. That means, if the user had `cd`ed into a symlink, the path through that + // symlink is returned, *not* the canonical path. + // + // Because of this, there is an important difference between how the operating system interprets + // "../foo" and what you get when you write `getCurrentPath().eval("../foo")`: The former + // will interpret ".." relative to the directory's canonical path, whereas the latter will + // interpret it relative to the path shown in the user's shell. In practice, the latter is + // almost always what the user wants! But the former behavior is what almost all commands do + // in practice, and it leads to confusion. KJ commands should implement the behavior the user + // expects. +}; + +// ======================================================================================= + +Own newInMemoryFile(const Clock& clock); +Own newInMemoryDirectory(const Clock& clock); +// Construct file and directory objects which reside in-memory. +// +// InMemoryFile has the following special properties: +// - The backing store is not sparse and never gets smaller even if you truncate the file. +// - While a non-private memory mapping exists, the backing store cannot get larger. Any operation +// which would expand it will throw. +// +// InMemoryDirectory has the following special properties: +// - Symlinks are processed using Path::parse(). This implies that a symlink cannot point to a +// parent directory -- InMemoryDirectory does not know its parent. +// - link() can link directory nodes in addition to files. +// - link() and rename() accept any kind of Directory as `fromDirectory` -- it doesn't need to be +// another InMemoryDirectory. However, for rename(), the from path must be a directory. + +Own newFileAppender(Own inner); +// Creates an AppendableFile by wrapping a File. Note that this implementation assumes it is the +// only writer. A correct implementation should always append to the file even if other writes +// are happening simultaneously, as is achieved with the O_APPEND flag to open(2), but that +// behavior is not possible to emulate on top of `File`. + +#if _WIN32 +typedef AutoCloseHandle OsFileHandle; +#else +typedef AutoCloseFd OsFileHandle; +#endif + +Own newDiskReadableFile(OsFileHandle fd); +Own newDiskAppendableFile(OsFileHandle fd); +Own newDiskFile(OsFileHandle fd); +Own newDiskReadableDirectory(OsFileHandle fd); +Own newDiskDirectory(OsFileHandle fd); +// Wrap a file descriptor (or Windows HANDLE) as various filesystem types. + +Own newDiskFilesystem(); +// Get at implementation of `Filesystem` representing the real filesystem. +// +// DO NOT CALL THIS except at the top level of your program, e.g. in main(). Anywhere else, you +// should instead have your caller pass in a Filesystem object, or a specific Directory object, +// or whatever it is that your code needs. This ensures that your code supports dependency +// injection, which makes it more reusable and testable. +// +// newDiskFilesystem() reads the current working directory at the time it is called. The returned +// object is not affected by subsequent calls to chdir(). + +// ======================================================================================= +// inline implementation details + +inline Path::Path(decltype(nullptr)): parts(nullptr) {} +inline Path::Path(std::initializer_list parts) + : Path(arrayPtr(parts.begin(), parts.end())) {} +inline Path::Path(Array parts, decltype(ALREADY_CHECKED)) + : parts(kj::mv(parts)) {} +inline Path Path::clone() const { return PathPtr(*this).clone(); } +inline Path Path::append(Path&& suffix) const& { return PathPtr(*this).append(kj::mv(suffix)); } +inline Path Path::append(PathPtr suffix) const& { return PathPtr(*this).append(suffix); } +inline Path Path::append(StringPtr suffix) const& { return append(Path(suffix)); } +inline Path Path::append(StringPtr suffix) && { return kj::mv(*this).append(Path(suffix)); } +inline Path Path::append(String&& suffix) const& { return append(Path(kj::mv(suffix))); } +inline Path Path::append(String&& suffix) && { return kj::mv(*this).append(Path(kj::mv(suffix))); } +inline Path Path::eval(StringPtr pathText) const& { return PathPtr(*this).eval(pathText); } +inline PathPtr Path::basename() const& { return PathPtr(*this).basename(); } +inline PathPtr Path::parent() const& { return PathPtr(*this).parent(); } +inline const String& Path::operator[](size_t i) const& { return parts[i]; } +inline String Path::operator[](size_t i) && { return kj::mv(parts[i]); } +inline size_t Path::size() const { return parts.size(); } +inline const String* Path::begin() const { return parts.begin(); } +inline const String* Path::end() const { return parts.end(); } +inline PathPtr Path::slice(size_t start, size_t end) const& { + return PathPtr(*this).slice(start, end); +} +inline bool Path::operator==(PathPtr other) const { return PathPtr(*this) == other; } +inline bool Path::operator!=(PathPtr other) const { return PathPtr(*this) != other; } +inline bool Path::operator< (PathPtr other) const { return PathPtr(*this) < other; } +inline bool Path::operator> (PathPtr other) const { return PathPtr(*this) > other; } +inline bool Path::operator<=(PathPtr other) const { return PathPtr(*this) <= other; } +inline bool Path::operator>=(PathPtr other) const { return PathPtr(*this) >= other; } +inline bool Path::operator==(const Path& other) const { return PathPtr(*this) == PathPtr(other); } +inline bool Path::operator!=(const Path& other) const { return PathPtr(*this) != PathPtr(other); } +inline bool Path::operator< (const Path& other) const { return PathPtr(*this) < PathPtr(other); } +inline bool Path::operator> (const Path& other) const { return PathPtr(*this) > PathPtr(other); } +inline bool Path::operator<=(const Path& other) const { return PathPtr(*this) <= PathPtr(other); } +inline bool Path::operator>=(const Path& other) const { return PathPtr(*this) >= PathPtr(other); } +inline uint Path::hashCode() const { return kj::hashCode(parts); } +inline bool Path::startsWith(PathPtr prefix) const { return PathPtr(*this).startsWith(prefix); } +inline bool Path::endsWith (PathPtr suffix) const { return PathPtr(*this).endsWith (suffix); } +inline String Path::toString(bool absolute) const { return PathPtr(*this).toString(absolute); } +inline Path Path::evalWin32(StringPtr pathText) const& { + return PathPtr(*this).evalWin32(pathText); +} +inline String Path::toWin32String(bool absolute) const { + return PathPtr(*this).toWin32String(absolute); +} +inline Array Path::forWin32Api(bool absolute) const { + return PathPtr(*this).forWin32Api(absolute); +} + +inline PathPtr::PathPtr(decltype(nullptr)): parts(nullptr) {} +inline PathPtr::PathPtr(const Path& path): parts(path.parts) {} +inline PathPtr::PathPtr(ArrayPtr parts): parts(parts) {} +inline Path PathPtr::append(StringPtr suffix) const { return append(Path(suffix)); } +inline Path PathPtr::append(String&& suffix) const { return append(Path(kj::mv(suffix))); } +inline const String& PathPtr::operator[](size_t i) const { return parts[i]; } +inline size_t PathPtr::size() const { return parts.size(); } +inline const String* PathPtr::begin() const { return parts.begin(); } +inline const String* PathPtr::end() const { return parts.end(); } +inline PathPtr PathPtr::slice(size_t start, size_t end) const { + return PathPtr(parts.slice(start, end)); +} +inline bool PathPtr::operator!=(PathPtr other) const { return !(*this == other); } +inline bool PathPtr::operator> (PathPtr other) const { return other < *this; } +inline bool PathPtr::operator<=(PathPtr other) const { return !(other < *this); } +inline bool PathPtr::operator>=(PathPtr other) const { return !(*this < other); } +inline uint PathPtr::hashCode() const { return kj::hashCode(parts); } +inline String PathPtr::toWin32String(bool absolute) const { + return toWin32StringImpl(absolute, false); +} + +#if _WIN32 +inline Path Path::evalNative(StringPtr pathText) const& { + return evalWin32(pathText); +} +inline Path Path::evalNative(StringPtr pathText) && { + return kj::mv(*this).evalWin32(pathText); +} +inline String Path::toNativeString(bool absolute) const { + return toWin32String(absolute); +} +inline Path PathPtr::evalNative(StringPtr pathText) const { + return evalWin32(pathText); +} +inline String PathPtr::toNativeString(bool absolute) const { + return toWin32String(absolute); +} +#else +inline Path Path::evalNative(StringPtr pathText) const& { + return eval(pathText); +} +inline Path Path::evalNative(StringPtr pathText) && { + return kj::mv(*this).eval(pathText); +} +inline String Path::toNativeString(bool absolute) const { + return toString(absolute); +} +inline Path PathPtr::evalNative(StringPtr pathText) const { + return eval(pathText); +} +inline String PathPtr::toNativeString(bool absolute) const { + return toString(absolute); +} +#endif // _WIN32, else + +inline Own FsNode::clone() const { return cloneFsNode(); } +inline Own ReadableFile::clone() const { + return cloneFsNode().downcast(); +} +inline Own AppendableFile::clone() const { + return cloneFsNode().downcast(); +} +inline Own File::clone() const { return cloneFsNode().downcast(); } +inline Own ReadableDirectory::clone() const { + return cloneFsNode().downcast(); +} +inline Own Directory::clone() const { + return cloneFsNode().downcast(); +} + +inline void Directory::transfer( + PathPtr toPath, WriteMode toMode, PathPtr fromPath, TransferMode mode) const { + return transfer(toPath, toMode, *this, fromPath, mode); +} + +template +inline Directory::Replacer::Replacer(WriteMode mode): mode(mode) {} + +template +void Directory::Replacer::commit() { + if (!tryCommit()) commitFailed(mode); +} + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/function-test.c++ b/c++/src/kj/function-test.c++ index 339f2e8380..16bdae7414 100644 --- a/c++/src/kj/function-test.c++ +++ b/c++/src/kj/function-test.c++ @@ -48,6 +48,10 @@ struct TestType { int foo(int a, int b) { return a + b + callCount++; } + + int foo(int c) { + return c * 100; + } }; TEST(Function, Method) { @@ -61,6 +65,9 @@ TEST(Function, Method) { EXPECT_EQ(3, obj.callCount); + Function f3 = KJ_BIND_METHOD(obj, foo); + EXPECT_EQ(12300, f3(123)); + // Bind to a temporary. f = KJ_BIND_METHOD(TestType(10), foo); @@ -117,5 +124,51 @@ TEST(ConstFunction, Method) { EXPECT_EQ(9 + 2 + 5, f(2, 9)); } +int testFunctionParam(FunctionParam func, char c, bool b) { + return func(c, b); +} + +int testFunctionParamRecursive(FunctionParam func, char c, bool b) { + return testFunctionParam(func, c, b); +} + +KJ_TEST("FunctionParam") { + { + int i = 123; + int result = testFunctionParam([i](char c, bool b) { + KJ_EXPECT(c == 'x'); + KJ_EXPECT(b); + KJ_EXPECT(i == 123); + return 456; + }, 'x', true); + + KJ_EXPECT(result == 456); + } + + { + int i = 123; + auto func = [i](char c, bool b) { + KJ_EXPECT(c == 'x'); + KJ_EXPECT(b); + KJ_EXPECT(i == 123); + return 456; + }; + int result = testFunctionParam(func, 'x', true); + KJ_EXPECT(result == 456); + } + + { + int i = 123; + int result = testFunctionParamRecursive([i](char c, bool b) { + KJ_EXPECT(c == 'x'); + KJ_EXPECT(b); + KJ_EXPECT(i == 123); + return 456; + }, 'x', true); + + KJ_EXPECT(result == 456); + } +} + } // namespace } // namespace kj diff --git a/c++/src/kj/function.h b/c++/src/kj/function.h index ba6601b560..59ba5f35ba 100644 --- a/c++/src/kj/function.h +++ b/c++/src/kj/function.h @@ -19,15 +19,12 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_FUNCTION_H_ -#define KJ_FUNCTION_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "memory.h" +KJ_BEGIN_HEADER + namespace kj { template @@ -90,6 +87,15 @@ template class ConstFunction; // Like Function, but wraps a "const" (i.e. thread-safe) call. +template +class FunctionParam; +// Like Function, but used specifically as a call parameter type. Does not do any heap allocation. +// +// This type MUST NOT be used for anything other than a parameter type to a function or method. +// This is because if FunctionParam binds to a temporary, it assumes that the temporary will +// outlive the FunctionParam instance. This is true when FunctionParam is used as a parameter type, +// but not if it is used as a local variable nor a class member variable. + template class Function { public: @@ -196,82 +202,92 @@ class ConstFunction { Own impl; }; -#if 1 - -namespace _ { // private - -template -class BoundMethod; - -template ::*method)(Params...)> -class BoundMethod::*)(Params...), method> { +template +class FunctionParam { public: - BoundMethod(T&& t): t(kj::fwd(t)) {} + template + FunctionParam(Func&& func) { + typedef Wrapper> WrapperType; + + // All instances of Wrapper are two pointers in size: a vtable, and a Func&. So if we + // allocate space for two pointers, we can construct a Wrapper in it! + static_assert(sizeof(WrapperType) == sizeof(space), + "expected WrapperType to be two pointers"); + + // Even if `func` is an rvalue reference, it's OK to use it as an lvalue here, because + // FunctionParam is used strictly for parameters. If we captured a temporary, we know that + // temporary will not be destroyed until after the function call completes. + ctor(*reinterpret_cast(space), func); + } + + FunctionParam(const FunctionParam& other) = default; + FunctionParam(FunctionParam&& other) = default; + // Magically, a plain copy works. - Return operator()(Params&&... params) { - return (t.*method)(kj::fwd(params)...); + inline Return operator()(Params... params) { + return (*reinterpret_cast(space))(kj::fwd(params)...); } private: - T t; + alignas(void*) char space[2 * sizeof(void*)]; + + class WrapperBase { + public: + virtual Return operator()(Params... params) = 0; + }; + + template + class Wrapper: public WrapperBase { + public: + Wrapper(Func& func): func(func) {} + + inline Return operator()(Params... params) override { + return func(kj::fwd(params)...); + } + + private: + Func& func; + }; }; -template ::*method)(Params...) const> -class BoundMethod::*)(Params...) const, method> { +namespace _ { // private + +template +class BoundMethod { public: - BoundMethod(T&& t): t(kj::fwd(t)) {} + BoundMethod(T&& t, Func&& func, ConstFunc&& constFunc) + : t(kj::fwd(t)), func(kj::mv(func)), constFunc(kj::mv(constFunc)) {} - Return operator()(Params&&... params) const { - return (t.*method)(kj::fwd(params)...); + template + auto operator()(Params&&... params) { + return func(t, kj::fwd(params)...); + } + template + auto operator()(Params&&... params) const { + return constFunc(t, kj::fwd(params)...); } private: T t; + Func func; + ConstFunc constFunc; }; -} // namespace _ (private) +template +BoundMethod boundMethod(T&& t, Func&& func, ConstFunc&& constFunc) { + return { kj::fwd(t), kj::fwd(func), kj::fwd(constFunc) }; +} -#define KJ_BIND_METHOD(obj, method) \ - ::kj::_::BoundMethod::method), \ - &::kj::Decay::method>(obj) -// Macro that produces a functor object which forwards to the method `obj.name`. If `obj` is an -// lvalue, the functor will hold a reference to it. If `obj` is an rvalue, the functor will -// contain a copy (by move) of it. -// -// The current implementation requires that the method is not overloaded. -// -// TODO(someday): C++14's generic lambdas may be able to simplify this code considerably, and -// probably make it work with overloaded methods. - -#else -// Here's a better implementation of the above that doesn't work with GCC (but does with Clang) -// because it uses a local class with a template method. Sigh. This implementation supports -// overloaded methods. +} // namespace _ (private) #define KJ_BIND_METHOD(obj, method) \ - ({ \ - typedef KJ_DECLTYPE_REF(obj) T; \ - class F { \ - public: \ - inline F(T&& t): t(::kj::fwd(t)) {} \ - template \ - auto operator()(Params&&... params) \ - -> decltype(::kj::instance().method(::kj::fwd(params)...)) { \ - return t.method(::kj::fwd(params)...); \ - } \ - private: \ - T t; \ - }; \ - (F(obj)); \ - }) + ::kj::_::boundMethod(obj, \ + [](auto& s, auto&&... p) mutable { return s.method(kj::fwd(p)...); }, \ + [](auto& s, auto&&... p) { return s.method(kj::fwd(p)...); }) // Macro that produces a functor object which forwards to the method `obj.name`. If `obj` is an // lvalue, the functor will hold a reference to it. If `obj` is an rvalue, the functor will -// contain a copy (by move) of it. - -#endif +// contain a copy (by move) of it. The method is allowed to be overloaded. } // namespace kj -#endif // KJ_FUNCTION_H_ +KJ_END_HEADER diff --git a/c++/src/kj/hash.c++ b/c++/src/kj/hash.c++ new file mode 100644 index 0000000000..bf80a55565 --- /dev/null +++ b/c++/src/kj/hash.c++ @@ -0,0 +1,65 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "hash.h" + +namespace kj { +namespace _ { // private + +uint HashCoder::operator*(ArrayPtr s) const { + // murmur2 adapted from libc++ source code. + // + // TODO(perf): Use CityHash or FarmHash on 64-bit machines? They seem optimized for x86-64; what + // about ARM? Ask Vlad for advice. + + constexpr uint m = 0x5bd1e995; + constexpr uint r = 24; + uint h = s.size(); + const byte* data = s.begin(); + uint len = s.size(); + for (; len >= 4; data += 4, len -= 4) { + uint k; + memcpy(&k, data, sizeof(k)); + k *= m; + k ^= k >> r; + k *= m; + h *= m; + h ^= k; + } + switch (len) { + case 3: + h ^= data[2] << 16; + KJ_FALLTHROUGH; + case 2: + h ^= data[1] << 8; + KJ_FALLTHROUGH; + case 1: + h ^= data[0]; + h *= m; + } + h ^= h >> 13; + h *= m; + h ^= h >> 15; + return h; +} + +} // namespace _ (private) +} // namespace kj diff --git a/c++/src/kj/hash.h b/c++/src/kj/hash.h new file mode 100644 index 0000000000..d6ff46fd81 --- /dev/null +++ b/c++/src/kj/hash.h @@ -0,0 +1,196 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "string.h" + +KJ_BEGIN_HEADER + +namespace kj { +namespace _ { // private + +struct HashCoder { + // This is a dummy type with only one instance: HASHCODER (below). To make an arbitrary type + // hashable, define `operator*(HashCoder, T)` to return any other type that is already hashable. + // Be sure to declare the operator in the same namespace as `T` **or** in the global scope. + // You can use the KJ_HASHCODE() macro as syntax sugar for this. + // + // A more usual way to accomplish what we're doing here would be to require that you define + // a function like `hashCode(T)` and then rely on argument-dependent lookup. However, this has + // the problem that it pollutes other people's namespaces and even the global namespace. For + // example, some other project may already have functions called `hashCode` which do something + // different. Declaring `operator*` with `HashCoder` as the left operand cannot conflict with + // anything. + + uint operator*(ArrayPtr s) const; + inline uint operator*(ArrayPtr s) const { return operator*(s.asConst()); } + + inline uint operator*(ArrayPtr s) const { return operator*(s.asBytes()); } + inline uint operator*(ArrayPtr s) const { return operator*(s.asBytes()); } + inline uint operator*(const Array& s) const { return operator*(s.asBytes()); } + inline uint operator*(const Array& s) const { return operator*(s.asBytes()); } + inline uint operator*(const String& s) const { return operator*(s.asBytes()); } + inline uint operator*(const StringPtr& s) const { return operator*(s.asBytes()); } + inline uint operator*(const ConstString& s) const { return operator*(s.asBytes()); } + + inline uint operator*(decltype(nullptr)) const { return 0; } + inline uint operator*(bool b) const { return b; } + inline uint operator*(char i) const { return i; } + inline uint operator*(signed char i) const { return i; } + inline uint operator*(unsigned char i) const { return i; } + inline uint operator*(signed short i) const { return i; } + inline uint operator*(unsigned short i) const { return i; } + inline uint operator*(signed int i) const { return i; } + inline uint operator*(unsigned int i) const { return i; } + + inline uint operator*(signed long i) const { + if (sizeof(i) == sizeof(uint)) { + return operator*(static_cast(i)); + } else { + return operator*(static_cast(i)); + } + } + inline uint operator*(unsigned long i) const { + if (sizeof(i) == sizeof(uint)) { + return operator*(static_cast(i)); + } else { + return operator*(static_cast(i)); + } + } + inline uint operator*(signed long long i) const { + return operator*(static_cast(i)); + } + inline uint operator*(unsigned long long i) const { + // Mix 64 bits to 32 bits in such a way that if our input values differ primarily in the upper + // 32 bits, we still get good diffusion. (I.e. we cannot just truncate!) + // + // 49123 is an arbitrarily-chosen prime that is vaguely close to 2^16. + // + // TODO(perf): I just made this up. Is it OK? + return static_cast(i) + static_cast(i >> 32) * 49123; + } + + template + uint operator*(T* ptr) const { + static_assert(!isSameType, char>(), "Wrap in StringPtr if you want to hash string " + "contents. If you want to hash the pointer, cast to void*"); + if (sizeof(ptr) == sizeof(uint)) { + // TODO(cleanup): In C++17, make the if() above be `if constexpr ()`, then change this to + // reinterpret_cast(ptr). + return reinterpret_cast(ptr); + } else { + return operator*(reinterpret_cast(ptr)); + } + } + + template () * instance())> + uint operator*(ArrayPtr arr) const; + template () * instance())> + uint operator*(const Array& arr) const; + template > + inline uint operator*(T e) const; + + template ().hashCode())> + inline Result operator*(T&& value) const { return kj::fwd(value).hashCode(); } +}; +static KJ_CONSTEXPR(const) HashCoder HASHCODER = HashCoder(); + +} // namespace _ (private) + +#define KJ_HASHCODE(...) operator*(::kj::_::HashCoder, __VA_ARGS__) +// Defines a hash function for a custom type. Example: +// +// class Foo {...}; +// inline uint KJ_HASHCODE(const Foo& foo) { return kj::hashCode(foo.x, foo.y); } +// +// This allows Foo to be passed to hashCode(). +// +// The function should be declared either in the same namespace as the target type or in the global +// namespace. It can return any type which itself is hashable -- that value will be hashed in turn +// until a `uint` comes out. + +inline uint hashCode(uint value) { return value; } +template +inline uint hashCode(T&& value) { return hashCode(_::HASHCODER * kj::fwd(value)); } +template +inline uint hashCode(T (&arr)[N]) { + static_assert(!isSameType, char>(), "Wrap in StringPtr if you want to hash string " + "contents. If you want to hash the pointer, cast to void*"); + static_assert(isSameType, char>(), "Wrap in ArrayPtr if you want to hash a C array. " + "If you want to hash the pointer, cast to void*"); + return 0; +} +template +inline uint hashCode(T&&... values) { + uint hashes[] = { hashCode(kj::fwd(values))... }; + return hashCode(kj::ArrayPtr(hashes).asBytes()); +} +// kj::hashCode() is a universal hashing function, like kj::str() is a universal stringification +// function. Throw stuff in, get a hash code. +// +// Hash codes may differ between different processes, even running exactly the same code. +// +// NOT SUITABLE FOR CRYPTOGRAPHY. This is for hash tables, not crypto. + +// ======================================================================================= +// inline implementation details + +namespace _ { // private + +template +inline uint HashCoder::operator*(ArrayPtr arr) const { + // Hash each array element to create a string of hashes, then murmur2 over those. + // + // TODO(perf): Choose a more-modern hash. (See hash.c++.) + + constexpr uint m = 0x5bd1e995; + constexpr uint r = 24; + uint h = arr.size() * sizeof(uint); + + for (auto& e: arr) { + uint k = kj::hashCode(e); + k *= m; + k ^= k >> r; + k *= m; + h *= m; + h ^= k; + } + + h ^= h >> 13; + h *= m; + h ^= h >> 15; + return h; +} +template +inline uint HashCoder::operator*(const Array& arr) const { + return operator*(arr.asPtr()); +} + +template +inline uint HashCoder::operator*(T e) const { + return operator*(static_cast<__underlying_type(T)>(e)); +} + +} // namespace _ (private) +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/io-test.c++ b/c++/src/kj/io-test.c++ index 0eee98ea96..ea7bac413b 100644 --- a/c++/src/kj/io-test.c++ +++ b/c++/src/kj/io-test.c++ @@ -19,6 +19,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "io.h" #include "debug.h" #include "miniposix.h" @@ -100,12 +104,103 @@ KJ_TEST("VectorOutputStream") { KJ_ASSERT(kj::str(output.getArray().asChars()) == "abcdefghijklmnopABCD"); output.write(junk + 4, 20); - KJ_ASSERT(output.getArray().begin() != buf.begin()); + // (We can't assert output.getArray().begin() != buf.begin() because the memory allocator could + // legitimately have allocated a new array in the same space.) KJ_ASSERT(output.getArray().end() != buf3.begin() + 24); KJ_ASSERT(kj::str(output.getArray().asChars()) == "abcdefghijklmnopABCDEFGHIJKLMNOPQRSTUVWX"); KJ_ASSERT(output.getWriteBuffer().size() == 24); KJ_ASSERT(output.getWriteBuffer().begin() == output.getArray().begin() + 40); + + output.clear(); + KJ_ASSERT(output.getWriteBuffer().begin() == output.getArray().begin()); + KJ_ASSERT(output.getWriteBuffer().size() == 64); + KJ_ASSERT(output.getArray().size() == 0); +} + +class MockInputStream: public InputStream { +public: + MockInputStream(kj::ArrayPtr bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr bytes; + size_t blockSize; +}; + +KJ_TEST("InputStream::readAllText() / readAllBytes()") { + auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); + size_t inputSizes[] = { 0, 1, 256, 4096, 8191, 8192, 8193, 10000, bigText.size() }; + size_t blockSizes[] = { 1, 4, 256, 4096, 8192, bigText.size() }; + uint64_t limits[] = { + 0, 1, 256, + bigText.size() / 2, + bigText.size() - 1, + bigText.size(), + bigText.size() + 1, + kj::maxValue + }; + + for (size_t inputSize: inputSizes) { + for (size_t blockSize: blockSizes) { + for (uint64_t limit: limits) { + KJ_CONTEXT(inputSize, blockSize, limit); + auto textSlice = bigText.asBytes().slice(0, inputSize); + auto readAllText = [&]() { + MockInputStream input(textSlice, blockSize); + return input.readAllText(limit); + }; + auto readAllBytes = [&]() { + MockInputStream input(textSlice, blockSize); + return input.readAllBytes(limit); + }; + if (limit > inputSize) { + KJ_EXPECT(readAllText().asBytes() == textSlice); + KJ_EXPECT(readAllBytes() == textSlice); + } else { + KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllText()); + KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllBytes()); + } + } + } + } +} + +KJ_TEST("ArrayOutputStream::write() does not assume adjacent write buffer is its own") { + // Previously, if ArrayOutputStream::write(src, size) saw that `src` equaled its fill position, it + // would assume that the write was already in its buffer. This assumption was buggy if the write + // buffer was directly adjacent in memory to the ArrayOutputStream's buffer, and the + // ArrayOutputStream was full (i.e., its fill position was one-past-the-end). + // + // VectorOutputStream also suffered a similar bug, but it is much harder to test, since it + // performs its own allocation. + + kj::byte buffer[10] = { 0 }; + + ArrayOutputStream output(arrayPtr(buffer, buffer + 5)); + + // Succeeds and fills the ArrayOutputStream. + output.write(buffer + 5, 5); + + // Previously this threw an inscrutable "size <= array.end() - fillPos" requirement failure. + KJ_EXPECT_THROW_MESSAGE( + "backing array was not large enough for the data written", + output.write(buffer + 5, 5)); } } // namespace diff --git a/c++/src/kj/io.c++ b/c++/src/kj/io.c++ index 1db4c486ec..59d12e5863 100644 --- a/c++/src/kj/io.c++ +++ b/c++/src/kj/io.c++ @@ -19,17 +19,22 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#if _WIN32 +#include "win32-api-version.h" +#endif + #include "io.h" #include "debug.h" #include "miniposix.h" #include #include +#include "vector.h" #if _WIN32 -#ifndef NOMINMAX -#define NOMINMAX 1 -#endif -#define WIN32_LEAN_AND_MEAN #include #include "windows-sanity.h" #else @@ -62,6 +67,45 @@ void InputStream::skip(size_t bytes) { } } + +namespace { + +Array readAll(InputStream& input, uint64_t limit, bool nulTerminate) { + Vector> parts; + constexpr size_t BLOCK_SIZE = 4096; + + for (;;) { + KJ_REQUIRE(limit > 0, "Reached limit before EOF."); + auto part = heapArray(kj::min(BLOCK_SIZE, limit)); + size_t n = input.tryRead(part.begin(), part.size(), part.size()); + limit -= n; + if (n < part.size()) { + auto result = heapArray(parts.size() * BLOCK_SIZE + n + nulTerminate); + byte* pos = result.begin(); + for (auto& p: parts) { + memcpy(pos, p.begin(), BLOCK_SIZE); + pos += BLOCK_SIZE; + } + memcpy(pos, part.begin(), n); + pos += n; + if (nulTerminate) *pos++ = '\0'; + KJ_ASSERT(pos == result.end()); + return result; + } else { + parts.add(kj::mv(part)); + } + } +} + +} // namespace + +String InputStream::readAllText(uint64_t limit) { + return String(readAll(*this, limit, true).releaseAsChars()); +} +Array InputStream::readAllBytes(uint64_t limit) { + return readAll(*this, limit, false); +} + void OutputStream::write(ArrayPtr> pieces) { for (auto piece: pieces) { write(piece.begin(), piece.size()); @@ -227,9 +271,9 @@ ArrayPtr ArrayOutputStream::getWriteBuffer() { } void ArrayOutputStream::write(const void* src, size_t size) { - if (src == fillPos) { + if (src == fillPos && fillPos != array.end()) { // Oh goody, the caller wrote directly into our buffer. - KJ_REQUIRE(size <= array.end() - fillPos); + KJ_REQUIRE(size <= array.end() - fillPos, size, fillPos, array.end() - fillPos); fillPos += size; } else { KJ_REQUIRE(size <= (size_t)(array.end() - fillPos), @@ -255,9 +299,9 @@ ArrayPtr VectorOutputStream::getWriteBuffer() { } void VectorOutputStream::write(const void* src, size_t size) { - if (src == fillPos) { + if (src == fillPos && fillPos != vector.end()) { // Oh goody, the caller wrote directly into our buffer. - KJ_REQUIRE(size <= vector.end() - fillPos); + KJ_REQUIRE(size <= vector.end() - fillPos, size, fillPos, vector.end() - fillPos); fillPos += size; } else { if (vector.end() - fillPos < size) { @@ -282,14 +326,13 @@ void VectorOutputStream::grow(size_t minSize) { AutoCloseFd::~AutoCloseFd() noexcept(false) { if (fd >= 0) { - unwindDetector.catchExceptionsIfUnwinding([&]() { - // Don't use SYSCALL() here because close() should not be repeated on EINTR. - if (miniposix::close(fd) < 0) { - KJ_FAIL_SYSCALL("close", errno, fd) { - break; - } + // Don't use SYSCALL() here because close() should not be repeated on EINTR. + if (miniposix::close(fd) < 0) { + KJ_FAIL_SYSCALL("close", errno, fd) { + // This ensures we don't throw an exception if unwinding. + break; } - }); + } } } @@ -334,7 +377,7 @@ void FdOutputStream::write(ArrayPtr> pieces) { OutputStream::write(pieces); #else - const size_t iovmax = miniposix::iovMax(pieces.size()); + const size_t iovmax = miniposix::iovMax(); while (pieces.size() > iovmax) { write(pieces.slice(0, iovmax)); pieces = pieces.slice(iovmax, pieces.size()); diff --git a/c++/src/kj/io.h b/c++/src/kj/io.h index f5c03bfe7b..3edc300ca5 100644 --- a/c++/src/kj/io.h +++ b/c++/src/kj/io.h @@ -19,17 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_IO_H_ -#define KJ_IO_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include #include "common.h" #include "array.h" #include "exception.h" +#include + +KJ_BEGIN_HEADER namespace kj { @@ -66,6 +64,14 @@ class InputStream { virtual void skip(size_t bytes); // Skips past the given number of bytes, discarding them. The default implementation read()s // into a scratch buffer. + + String readAllText(uint64_t limit = kj::maxValue); + Array readAllBytes(uint64_t limit = kj::maxValue); + // Read until EOF and return as one big byte array or string. Throw an exception if EOF is not + // seen before reading `limit` bytes. + // + // To prevent runaway memory allocation, consider using a more conservative value for `limit` than + // the default, particularly on untrusted data streams which may never see EOF. }; class OutputStream { @@ -135,7 +141,7 @@ class BufferedInputStreamWrapper: public BufferedInputStream { // If the second parameter is non-null, the stream uses the given buffer instead of allocating // its own. This may improve performance if the buffer can be reused. - KJ_DISALLOW_COPY(BufferedInputStreamWrapper); + KJ_DISALLOW_COPY_AND_MOVE(BufferedInputStreamWrapper); ~BufferedInputStreamWrapper() noexcept(false); // implements BufferedInputStream ---------------------------------- @@ -161,7 +167,7 @@ class BufferedOutputStreamWrapper: public BufferedOutputStream { // If the second parameter is non-null, the stream uses the given buffer instead of allocating // its own. This may improve performance if the buffer can be reused. - KJ_DISALLOW_COPY(BufferedOutputStreamWrapper); + KJ_DISALLOW_COPY_AND_MOVE(BufferedOutputStreamWrapper); ~BufferedOutputStreamWrapper() noexcept(false); void flush(); @@ -187,7 +193,7 @@ class BufferedOutputStreamWrapper: public BufferedOutputStream { class ArrayInputStream: public BufferedInputStream { public: explicit ArrayInputStream(ArrayPtr array); - KJ_DISALLOW_COPY(ArrayInputStream); + KJ_DISALLOW_COPY_AND_MOVE(ArrayInputStream); ~ArrayInputStream() noexcept(false); // implements BufferedInputStream ---------------------------------- @@ -202,7 +208,7 @@ class ArrayInputStream: public BufferedInputStream { class ArrayOutputStream: public BufferedOutputStream { public: explicit ArrayOutputStream(ArrayPtr array); - KJ_DISALLOW_COPY(ArrayOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(ArrayOutputStream); ~ArrayOutputStream() noexcept(false); ArrayPtr getArray() { @@ -222,7 +228,7 @@ class ArrayOutputStream: public BufferedOutputStream { class VectorOutputStream: public BufferedOutputStream { public: explicit VectorOutputStream(size_t initialCapacity = 4096); - KJ_DISALLOW_COPY(VectorOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(VectorOutputStream); ~VectorOutputStream() noexcept(false); ArrayPtr getArray() { @@ -230,6 +236,8 @@ class VectorOutputStream: public BufferedOutputStream { return arrayPtr(vector.begin(), fillPos); } + void clear() { fillPos = vector.begin(); } + // implements BufferedInputStream ---------------------------------- ArrayPtr getWriteBuffer() override; void write(const void* buffer, size_t size) override; @@ -283,9 +291,15 @@ class AutoCloseFd { inline bool operator==(decltype(nullptr)) { return fd < 0; } inline bool operator!=(decltype(nullptr)) { return fd >= 0; } + inline int release() { + // Release ownership of an FD. Not recommended. + int result = fd; + fd = -1; + return result; + } + private: int fd; - UnwindDetector unwindDetector; }; inline auto KJ_STRINGIFY(const AutoCloseFd& fd) @@ -299,7 +313,7 @@ class FdInputStream: public InputStream { public: explicit FdInputStream(int fd): fd(fd) {} explicit FdInputStream(AutoCloseFd fd): fd(fd), autoclose(mv(fd)) {} - KJ_DISALLOW_COPY(FdInputStream); + KJ_DISALLOW_COPY_AND_MOVE(FdInputStream); ~FdInputStream() noexcept(false); size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; @@ -317,7 +331,7 @@ class FdOutputStream: public OutputStream { public: explicit FdOutputStream(int fd): fd(fd) {} explicit FdOutputStream(AutoCloseFd fd): fd(fd), autoclose(mv(fd)) {} - KJ_DISALLOW_COPY(FdOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(FdOutputStream); ~FdOutputStream() noexcept(false); void write(const void* buffer, size_t size) override; @@ -376,6 +390,13 @@ class AutoCloseHandle { inline bool operator==(decltype(nullptr)) { return handle != (void*)-1; } inline bool operator!=(decltype(nullptr)) { return handle == (void*)-1; } + inline void* release() { + // Release ownership of an FD. Not recommended. + void* result = handle; + handle = (void*)-1; + return result; + } + private: void* handle; // -1 (aka INVALID_HANDLE_VALUE) if not valid. }; @@ -386,7 +407,7 @@ class HandleInputStream: public InputStream { public: explicit HandleInputStream(void* handle): handle(handle) {} explicit HandleInputStream(AutoCloseHandle handle): handle(handle), autoclose(mv(handle)) {} - KJ_DISALLOW_COPY(HandleInputStream); + KJ_DISALLOW_COPY_AND_MOVE(HandleInputStream); ~HandleInputStream() noexcept(false); size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; @@ -402,7 +423,7 @@ class HandleOutputStream: public OutputStream { public: explicit HandleOutputStream(void* handle): handle(handle) {} explicit HandleOutputStream(AutoCloseHandle handle): handle(handle), autoclose(mv(handle)) {} - KJ_DISALLOW_COPY(HandleOutputStream); + KJ_DISALLOW_COPY_AND_MOVE(HandleOutputStream); ~HandleOutputStream() noexcept(false); void write(const void* buffer, size_t size) override; @@ -416,4 +437,4 @@ class HandleOutputStream: public OutputStream { } // namespace kj -#endif // KJ_IO_H_ +KJ_END_HEADER diff --git a/c++/src/kj/list-test.c++ b/c++/src/kj/list-test.c++ new file mode 100644 index 0000000000..9286226e5e --- /dev/null +++ b/c++/src/kj/list-test.c++ @@ -0,0 +1,213 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "list.h" +#include + +namespace kj { +namespace { + +struct TestElement { + int i; + ListLink link; + + TestElement(int i): i(i) {} +}; + +KJ_TEST("List") { + List list; + KJ_EXPECT(list.empty()); + KJ_EXPECT(list.size() == 0); + + TestElement foo(123); + TestElement bar(456); + TestElement baz(789); + + { + list.add(foo); + KJ_DEFER(list.remove(foo)); + KJ_EXPECT(!list.empty()); + KJ_EXPECT(list.size() == 1); + KJ_EXPECT(list.front().i == 123); + + { + list.add(bar); + KJ_EXPECT(list.size() == 2); + KJ_DEFER(list.remove(bar)); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 456); + iter->i = 321; + KJ_EXPECT(bar.i == 321); + ++iter; + KJ_ASSERT(iter == list.end()); + } + + const List& clist = list; + + { + auto iter = clist.begin(); + KJ_ASSERT(iter != clist.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + KJ_ASSERT(iter != clist.end()); + KJ_EXPECT(iter->i == 321); + ++iter; + KJ_ASSERT(iter == clist.end()); + } + + { + list.addFront(baz); + KJ_EXPECT(list.size() == 3); + KJ_DEFER(list.remove(baz)); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 789); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 321); + ++iter; + KJ_ASSERT(iter == list.end()); + } + } + } + + KJ_EXPECT(list.size() == 1); + + KJ_EXPECT(!list.empty()); + KJ_EXPECT(list.front().i == 123); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + KJ_ASSERT(iter == list.end()); + } + } + + KJ_EXPECT(list.empty()); + KJ_EXPECT(list.size() == 0); + + { + list.addFront(bar); + KJ_DEFER(list.remove(bar)); + KJ_EXPECT(!list.empty()); + KJ_EXPECT(list.size() == 1); + KJ_EXPECT(list.front().i == 321); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 321); + ++iter; + KJ_ASSERT(iter == list.end()); + } + + { + list.add(baz); + KJ_EXPECT(list.size() == 2); + KJ_DEFER(list.remove(baz)); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 321); + ++iter; + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 789); + ++iter; + KJ_ASSERT(iter == list.end()); + } + } + } + + KJ_EXPECT(list.empty()); + KJ_EXPECT(list.size() == 0); +} + +KJ_TEST("List remove while iterating") { + List list; + KJ_EXPECT(list.empty()); + + TestElement foo(123); + list.add(foo); + KJ_DEFER(list.remove(foo)); + + TestElement bar(456); + list.add(bar); + + TestElement baz(789); + list.add(baz); + KJ_DEFER(list.remove(baz)); + + KJ_EXPECT(foo.link.isLinked()); + KJ_EXPECT(bar.link.isLinked()); + KJ_EXPECT(baz.link.isLinked()); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 456); + list.remove(*iter); + ++iter; + + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 789); + ++iter; + + KJ_EXPECT(iter == list.end()); + } + + KJ_EXPECT(foo.link.isLinked()); + KJ_EXPECT(!bar.link.isLinked()); + KJ_EXPECT(baz.link.isLinked()); + + { + auto iter = list.begin(); + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 123); + ++iter; + + KJ_ASSERT(iter != list.end()); + KJ_EXPECT(iter->i == 789); + ++iter; + + KJ_EXPECT(iter == list.end()); + } +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/list.c++ b/c++/src/kj/list.c++ new file mode 100644 index 0000000000..a7aa006c55 --- /dev/null +++ b/c++/src/kj/list.c++ @@ -0,0 +1,46 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "list.h" +#include "debug.h" + +namespace kj { +namespace _ { + +void throwDoubleAdd() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, + "tried to add element to kj::List but the element is already in a list")); +} +void throwRemovedNotPresent() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, + "tried to remove element from kj::List but the element is not in a list")); +} +void throwRemovedWrongList() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, + "tried to remove element from kj::List but the element is in a different list")); +} +void throwDestroyedWhileInList() { + kj::throwFatalException(KJ_EXCEPTION(FAILED, + "destroyed object that is still in a kj::List")); +} + +} // namespace _ +} // namespace kj diff --git a/c++/src/kj/list.h b/c++/src/kj/list.h new file mode 100644 index 0000000000..4575b0f96e --- /dev/null +++ b/c++/src/kj/list.h @@ -0,0 +1,227 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "common.h" + +KJ_BEGIN_HEADER + +namespace kj { + +template +class ListLink; + +template T::*link> +class ListIterator; + +namespace _ { // (private) + +KJ_NORETURN(void throwDoubleAdd()); +KJ_NORETURN(void throwRemovedNotPresent()); +KJ_NORETURN(void throwRemovedWrongList()); +KJ_NORETURN(void throwDestroyedWhileInList()); + +} // namespace _ (private) + +template T::*link> +class List { + // A linked list that does no memory allocation. + // + // The list contains elements of type T that are allocated elsewhere. An existing object of type + // T can be added to the list and removed again without doing any heap allocation. This is + // achieved by requiring that T contains a field of type ListLink. A pointer-to-member to + // this field is the second parameter to the `List` template. + // + // kj::List is ideally suited to situations where an object wants to be able to "add itself" to + // a list of objects waiting for a notification, with the ability to remove itself early if it + // wants to stop waiting. With traditional STL containers, these operations would require memory + // allocation. + // + // Example: + // + // struct Item { + // ListLink link; + // // ... other members ... + // }; + // + // kj::List itemList; + // + // Item foo; + // itemList.add(foo); + // itemList.remove(foo); + // + // Note that you MUST manually remove an element from the list before destroying it. ListLinks + // do not automatically unlink themselves because this could lead to subtle thread-safety bugs + // if the List is guarded by a mutex, and that mutex is not currently locked. Normally, you should + // have T's destructor remove it from any lists. You can use `link.isLinked()` to check if the + // item is currently in a list. + // + // kj::List is a doubly-linked list in order to allow O(1) removal of any element given only a + // reference to the element. However, it only supports forward iteration. + // + // When iterating over a kj::List, you can safely remove current element which the iterator + // points to without breaking the iteration. However, removing any *other* element could + // invalidate the iterator. + +public: + List() = default; + KJ_DISALLOW_COPY_AND_MOVE(List); + + bool empty() const { + return head == nullptr; + } + + size_t size() const { + return listSize; + } + + void add(T& element) { + if ((element.*link).prev != nullptr) _::throwDoubleAdd(); + *tail = element; + (element.*link).prev = tail; + tail = &((element.*link).next); + ++listSize; + } + + void addFront(T& element) { + if ((element.*link).prev != nullptr) _::throwDoubleAdd(); + (element.*link).next = head; + (element.*link).prev = &head; + KJ_IF_MAYBE(oldHead, head) { + (oldHead->*link).prev = &(element.*link).next; + } else { + tail = &(element.*link).next; + } + head = element; + ++listSize; + } + + void remove(T& element) { + if ((element.*link).prev == nullptr) _::throwRemovedNotPresent(); + *((element.*link).prev) = (element.*link).next; + KJ_IF_MAYBE(n, (element.*link).next) { + (n->*link).prev = (element.*link).prev; + } else { + if (tail != &((element.*link).next)) _::throwRemovedWrongList(); + tail = (element.*link).prev; + } + (element.*link).next = nullptr; + (element.*link).prev = nullptr; + --listSize; + } + + typedef ListIterator Iterator; + typedef ListIterator ConstIterator; + + Iterator begin() { return Iterator(head); } + Iterator end() { return Iterator(nullptr); } + ConstIterator begin() const { return ConstIterator(head); } + ConstIterator end() const { return ConstIterator(nullptr); } + + T& front() { return *begin(); } + const T& front() const { return *begin(); } + +private: + Maybe head; + Maybe* tail = &head; + size_t listSize = 0; +}; + +template +class ListLink { +public: + ListLink(): next(nullptr), prev(nullptr) {} + ~ListLink() noexcept { + // Intentionally `noexcept` because we want to crash if a dangling pointer was left in a list. + if (prev != nullptr) _::throwDestroyedWhileInList(); + } + KJ_DISALLOW_COPY_AND_MOVE(ListLink); + + bool isLinked() const { return prev != nullptr; } + +private: + Maybe next; + Maybe* prev; + + template U::*link> + friend class List; + template U::*link> + friend class ListIterator; +}; + +template T::*link> +class ListIterator { +public: + ListIterator() = default; + + MaybeConstT& operator*() { + KJ_IREQUIRE(current != nullptr, "tried to dereference end of list"); + return *_::readMaybe(current); + } + const T& operator*() const { + KJ_IREQUIRE(current != nullptr, "tried to dereference end of list"); + return *_::readMaybe(current); + } + MaybeConstT* operator->() { + KJ_IREQUIRE(current != nullptr, "tried to dereference end of list"); + return _::readMaybe(current); + } + const T* operator->() const { + KJ_IREQUIRE(current != nullptr, "tried to dereference end of list"); + return _::readMaybe(current); + } + + inline ListIterator& operator++() { + current = next; + next = current.map([](MaybeConstT& obj) -> kj::Maybe { return (obj.*link).next; }) + .orDefault(nullptr); + return *this; + } + inline ListIterator operator++(int) { + ListIterator result = *this; + ++*this; + return result; + } + + inline bool operator==(const ListIterator& other) const { + return _::readMaybe(current) == _::readMaybe(other.current); + } + inline bool operator!=(const ListIterator& other) const { + return _::readMaybe(current) != _::readMaybe(other.current); + } + +private: + Maybe current; + + Maybe next; + // so that the current item can be removed from the list without invalidating the iterator + + explicit ListIterator(Maybe start) + : current(start), + next(start.map([](MaybeConstT& obj) -> kj::Maybe { return (obj.*link).next; }) + .orDefault(nullptr)) {} + friend class List; +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/main.c++ b/c++/src/kj/main.c++ index 4d84294a82..6b980de94e 100644 --- a/c++/src/kj/main.c++ +++ b/c++/src/kj/main.c++ @@ -19,6 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#if _WIN32 +#include "win32-api-version.h" +#endif + #include "main.h" #include "debug.h" #include "arena.h" @@ -30,10 +38,6 @@ #include #if _WIN32 -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -#define NOMINMAX 1 -#endif #include #include "windows-sanity.h" #else @@ -631,7 +635,7 @@ void MainBuilder::MainImpl::usageError(StringPtr programName, StringPtr message) class MainBuilder::Impl::OptionDisplayOrder { public: - bool operator()(const Option* a, const Option* b) { + bool operator()(const Option* a, const Option* b) const { if (a == b) return false; char aShort = '\0'; diff --git a/c++/src/kj/main.h b/c++/src/kj/main.h index 4dcd804fd4..2533000649 100644 --- a/c++/src/kj/main.h +++ b/c++/src/kj/main.h @@ -19,18 +19,15 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_MAIN_H_ -#define KJ_MAIN_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "array.h" #include "string.h" #include "vector.h" #include "function.h" +KJ_BEGIN_HEADER + namespace kj { class ProcessContext { @@ -404,4 +401,4 @@ class MainBuilder { } // namespace kj -#endif // KJ_MAIN_H_ +KJ_END_HEADER diff --git a/c++/src/kj/map-test.c++ b/c++/src/kj/map-test.c++ new file mode 100644 index 0000000000..42b5846e60 --- /dev/null +++ b/c++/src/kj/map-test.c++ @@ -0,0 +1,221 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "map.h" +#include + +namespace kj { +namespace _ { +namespace { + +KJ_TEST("HashMap") { + HashMap map; + + kj::String ownFoo = kj::str("foo"); + const char* origFoo = ownFoo.begin(); + map.insert(kj::mv(ownFoo), 123); + map.insert(kj::str("bar"), 456); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 123); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("bar"_kj)) == 456); + KJ_EXPECT(map.find("baz"_kj) == nullptr); + + map.upsert(kj::str("foo"), 789, [](int& old, uint newValue) { + KJ_EXPECT(old == 123); + KJ_EXPECT(newValue == 789); + old = 4321; + }); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 4321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); + + map.upsert(kj::str("foo"), 321); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); + + KJ_EXPECT( + map.findOrCreate("foo"_kj, + []() -> HashMap::Entry { KJ_FAIL_ASSERT("shouldn't have been called"); }) + == 321); + KJ_EXPECT(map.findOrCreate("baz"_kj, + [](){ return HashMap::Entry { kj::str("baz"), 654 }; }) == 654); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("baz"_kj)) == 654); + + KJ_EXPECT(map.erase("bar"_kj)); + KJ_EXPECT(map.erase("baz"_kj)); + KJ_EXPECT(!map.erase("qux"_kj)); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 321); + KJ_EXPECT(map.size() == 1); + KJ_EXPECT(map.begin()->key == "foo"); + auto iter = map.begin(); + ++iter; + KJ_EXPECT(iter == map.end()); + + map.erase(*map.begin()); + KJ_EXPECT(map.size() == 0); +} + +KJ_TEST("TreeMap") { + TreeMap map; + + kj::String ownFoo = kj::str("foo"); + const char* origFoo = ownFoo.begin(); + map.insert(kj::mv(ownFoo), 123); + map.insert(kj::str("bar"), 456); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 123); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("bar"_kj)) == 456); + KJ_EXPECT(map.find("baz"_kj) == nullptr); + + map.upsert(kj::str("foo"), 789, [](int& old, uint newValue) { + KJ_EXPECT(old == 123); + KJ_EXPECT(newValue == 789); + old = 4321; + }); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 4321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); + + map.upsert(kj::str("foo"), 321); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 321); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.findEntry("foo"_kj)).key.begin() == origFoo); + + KJ_EXPECT( + map.findOrCreate("foo"_kj, + []() -> TreeMap::Entry { KJ_FAIL_ASSERT("shouldn't have been called"); }) + == 321); + KJ_EXPECT(map.findOrCreate("baz"_kj, + [](){ return TreeMap::Entry { kj::str("baz"), 654 }; }) == 654); + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("baz"_kj)) == 654); + + KJ_EXPECT(map.erase("bar"_kj)); + KJ_EXPECT(map.erase("baz"_kj)); + KJ_EXPECT(!map.erase("qux"_kj)); + + KJ_EXPECT(KJ_ASSERT_NONNULL(map.find("foo"_kj)) == 321); + KJ_EXPECT(map.size() == 1); + KJ_EXPECT(map.begin()->key == "foo"); + auto iter = map.begin(); + ++iter; + KJ_EXPECT(iter == map.end()); + + map.erase(*map.begin()); + KJ_EXPECT(map.size() == 0); +} + +KJ_TEST("TreeMap range") { + TreeMap map; + + map.insert(kj::str("foo"), 1); + map.insert(kj::str("bar"), 2); + map.insert(kj::str("baz"), 3); + map.insert(kj::str("qux"), 4); + map.insert(kj::str("corge"), 5); + + { + auto ordered = KJ_MAP(e, map) -> kj::StringPtr { return e.key; }; + KJ_ASSERT(ordered.size() == 5); + KJ_EXPECT(ordered[0] == "bar"); + KJ_EXPECT(ordered[1] == "baz"); + KJ_EXPECT(ordered[2] == "corge"); + KJ_EXPECT(ordered[3] == "foo"); + KJ_EXPECT(ordered[4] == "qux"); + } + + { + auto range = map.range("baz", "foo"); + auto iter = range.begin(); + KJ_EXPECT(iter->key == "baz"); + ++iter; + KJ_EXPECT(iter->key == "corge"); + ++iter; + KJ_EXPECT(iter == range.end()); + } + + map.eraseRange("baz", "foo"); + + { + auto ordered = KJ_MAP(e, map) -> kj::StringPtr { return e.key; }; + KJ_ASSERT(ordered.size() == 3); + KJ_EXPECT(ordered[0] == "bar"); + KJ_EXPECT(ordered[1] == "foo"); + KJ_EXPECT(ordered[2] == "qux"); + } +} + +#if !KJ_NO_EXCEPTIONS +KJ_TEST("HashMap findOrCreate throws") { + HashMap m; + try { + m.findOrCreate(1, []() -> HashMap::Entry { + throw "foo"; + }); + KJ_FAIL_ASSERT("shouldn't get here"); + } catch (const char*) { + // expected + } + + KJ_EXPECT(m.find(1) == nullptr); + m.findOrCreate(1, []() { + return HashMap::Entry { 1, kj::str("ok") }; + }); + + KJ_EXPECT(KJ_ASSERT_NONNULL(m.find(1)) == "ok"); +} +#endif + +template +void testEraseAll(MapType& m) { + m.insert(12, "foo"); + m.insert(83, "bar"); + m.insert(99, "baz"); + m.insert(6, "qux"); + m.insert(55, "corge"); + + auto count = m.eraseAll([](int i, StringPtr s) { + return i == 99 || s == "foo"; + }); + + KJ_EXPECT(count == 2); + KJ_EXPECT(m.size() == 3); + KJ_EXPECT(m.find(12) == nullptr); + KJ_EXPECT(m.find(99) == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(m.find(83)) == "bar"); + KJ_EXPECT(KJ_ASSERT_NONNULL(m.find(6)) == "qux"); + KJ_EXPECT(KJ_ASSERT_NONNULL(m.find(55)) == "corge"); +} + +KJ_TEST("HashMap eraseAll") { + HashMap m; + testEraseAll(m); +} + +KJ_TEST("TreeMap eraseAll") { + TreeMap m; + testEraseAll(m); +} + +} // namespace +} // namespace _ +} // namespace kj diff --git a/c++/src/kj/map.h b/c++/src/kj/map.h new file mode 100644 index 0000000000..4f92a2034e --- /dev/null +++ b/c++/src/kj/map.h @@ -0,0 +1,578 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "table.h" +#include "hash.h" + +KJ_BEGIN_HEADER + +namespace kj { + +template +class HashMap { + // A key/value mapping backed by hashing. + // + // `Key` must be hashable (via a `.hashCode()` method or `KJ_HASHCODE()`; see `hash.h`) and must + // implement `operator==()`. Additionally, when performing lookups, you can use key types other + // than `Key` as long as the other type is also hashable (producing the same hash codes) and + // there is an `operator==` implementation with `Key` on the left and that other type on the + // right. For example, if the key type is `String`, you can pass `StringPtr` to `find()`. + +public: + void reserve(size_t size); + // Pre-allocates space for a map of the given size. + + size_t size() const; + size_t capacity() const; + void clear(); + + struct Entry { + Key key; + Value value; + }; + + Entry* begin(); + Entry* end(); + const Entry* begin() const; + const Entry* end() const; + // Deterministic iteration. If you only ever insert(), iteration order will be insertion order. + // If you erase(), the erased element is swapped with the last element in the ordering. + + Entry& insert(Key key, Value value); + // Inserts a new entry. Throws if the key already exists. + + template + void insertAll(Collection&& collection); + // Given an iterable collection of `Entry`s, inserts all of them into this map. If the + // input is an rvalue, the entries will be moved rather than copied. + + template + Entry& upsert(Key key, Value value, UpdateFunc&& update); + Entry& upsert(Key key, Value value); + // Tries to insert a new entry. However, if a duplicate already exists (according to some index), + // then update(Value& existingValue, Value&& newValue) is called to modify the existing value. + // If no function is provided, the default is to simply replace the value (but not the key). + + template + kj::Maybe find(KeyLike&& key); + template + kj::Maybe find(KeyLike&& key) const; + // Search for a matching key. The input does not have to be of type `Key`; it merely has to + // be something that the Hasher accepts. + // + // Note that the default hasher for String accepts StringPtr. + + template + Value& findOrCreate(KeyLike&& key, Func&& createEntry); + // Like find() but if the key isn't present then call createEntry() to create the corresponding + // entry and insert it. createEntry() must return type `Entry`. + + template + kj::Maybe findEntry(KeyLike&& key); + template + kj::Maybe findEntry(KeyLike&& key) const; + template + Entry& findOrCreateEntry(KeyLike&& key, Func&& createEntry); + // Sometimes you need to see the whole matching Entry, not just the Value. + + template + bool erase(KeyLike&& key); + // Erase the entry with the matching key. + // + // WARNING: This invalidates all pointers and iterators into the map. Use eraseAll() if you need + // to iterate and erase multiple entries. + + void erase(Entry& entry); + // Erase an entry by reference. + + Entry release(Entry& row); + // Erase an entry and return its content by move. + + template ()(instance(), instance()))> + size_t eraseAll(Predicate&& predicate); + // Erase all values for which predicate(key, value) returns true. This scans over the entire map. + +private: + class Callbacks { + public: + inline const Key& keyForRow(const Entry& entry) const { return entry.key; } + inline Key& keyForRow(Entry& entry) const { return entry.key; } + + template + inline bool matches(Entry& e, KeyLike&& key) const { + return e.key == key; + } + template + inline bool matches(const Entry& e, KeyLike&& key) const { + return e.key == key; + } + template + inline auto hashCode(KeyLike&& key) const { + return kj::hashCode(key); + } + }; + + kj::Table> table; +}; + +template +class TreeMap { + // A key/value mapping backed by a B-tree. + // + // `Key` must support `operator<` and `operator==` against other Keys, and against any type + // which you might want to pass to find() (with `Key` always on the left of the comparison). + +public: + void reserve(size_t size); + // Pre-allocates space for a map of the given size. + + size_t size() const; + size_t capacity() const; + void clear(); + + struct Entry { + Key key; + Value value; + }; + + auto begin(); + auto end(); + auto begin() const; + auto end() const; + // Iteration is in sorted order by key. + + Entry& insert(Key key, Value value); + // Inserts a new entry. Throws if the key already exists. + + template + void insertAll(Collection&& collection); + // Given an iterable collection of `Entry`s, inserts all of them into this map. If the + // input is an rvalue, the entries will be moved rather than copied. + + template + Entry& upsert(Key key, Value value, UpdateFunc&& update); + Entry& upsert(Key key, Value value); + // Tries to insert a new entry. However, if a duplicate already exists (according to some index), + // then update(Value& existingValue, Value&& newValue) is called to modify the existing value. + // If no function is provided, the default is to simply replace the value (but not the key). + + template + kj::Maybe find(KeyLike&& key); + template + kj::Maybe find(KeyLike&& key) const; + // Search for a matching key. The input does not have to be of type `Key`; it merely has to + // be something that can be compared against `Key`. + + template + Value& findOrCreate(KeyLike&& key, Func&& createEntry); + // Like find() but if the key isn't present then call createEntry() to create the corresponding + // entry and insert it. createEntry() must return type `Entry`. + + template + kj::Maybe findEntry(KeyLike&& key); + template + kj::Maybe findEntry(KeyLike&& key) const; + template + Entry& findOrCreateEntry(KeyLike&& key, Func&& createEntry); + // Sometimes you need to see the whole matching Entry, not just the Value. + + template + auto range(K1&& k1, K2&& k2); + template + auto range(K1&& k1, K2&& k2) const; + // Returns an iterable range of entries with keys between k1 (inclusive) and k2 (exclusive). + + template + bool erase(KeyLike&& key); + // Erase the entry with the matching key. + // + // WARNING: This invalidates all pointers and iterators into the map. Use eraseAll() if you need + // to iterate and erase multiple entries. + + void erase(Entry& entry); + // Erase an entry by reference. + + Entry release(Entry& row); + // Erase an entry and return its content by move. + + template ()(instance(), instance()))> + size_t eraseAll(Predicate&& predicate); + // Erase all values for which predicate(key, value) returns true. This scans over the entire map. + + template + size_t eraseRange(K1&& k1, K2&& k2); + // Erases all entries with keys between k1 (inclusive) and k2 (exclusive). + +private: + class Callbacks { + public: + inline const Key& keyForRow(const Entry& entry) const { return entry.key; } + inline Key& keyForRow(Entry& entry) const { return entry.key; } + + template + inline bool matches(Entry& e, KeyLike&& key) const { + return e.key == key; + } + template + inline bool matches(const Entry& e, KeyLike&& key) const { + return e.key == key; + } + template + inline bool isBefore(Entry& e, KeyLike&& key) const { + return e.key < key; + } + template + inline bool isBefore(const Entry& e, KeyLike&& key) const { + return e.key < key; + } + }; + + kj::Table> table; +}; + +namespace _ { // private + +class HashSetCallbacks { +public: + template + inline Row& keyForRow(Row& row) const { return row; } + + template + inline bool matches(T& a, U& b) const { return a == b; } + template + inline auto hashCode(KeyLike&& key) const { + return kj::hashCode(key); + } +}; + +class TreeSetCallbacks { +public: + template + inline Row& keyForRow(Row& row) const { return row; } + + template + inline bool matches(T& a, U& b) const { return a == b; } + template + inline bool isBefore(T& a, U& b) const { return a < b; } +}; + +} // namespace _ (private) + +template +class HashSet: public Table> { + // A simple hashtable-based set, using kj::hashCode() and operator==(). + +public: + // Everything is inherited. + + template + inline bool contains(Params&&... params) const { + return this->find(kj::fwd(params)...) != nullptr; + } +}; + +template +class TreeSet: public Table> { + // A simple b-tree-based set, using operator<() and operator==(). + +public: + // Everything is inherited. +}; + +// ======================================================================================= +// inline implementation details + +template +void HashMap::reserve(size_t size) { + table.reserve(size); +} + +template +size_t HashMap::size() const { + return table.size(); +} +template +size_t HashMap::capacity() const { + return table.capacity(); +} +template +void HashMap::clear() { + return table.clear(); +} + +template +typename HashMap::Entry* HashMap::begin() { + return table.begin(); +} +template +typename HashMap::Entry* HashMap::end() { + return table.end(); +} +template +const typename HashMap::Entry* HashMap::begin() const { + return table.begin(); +} +template +const typename HashMap::Entry* HashMap::end() const { + return table.end(); +} + +template +typename HashMap::Entry& HashMap::insert(Key key, Value value) { + return table.insert(Entry { kj::mv(key), kj::mv(value) }); +} + +template +template +void HashMap::insertAll(Collection&& collection) { + return table.insertAll(kj::fwd(collection)); +} + +template +template +typename HashMap::Entry& HashMap::upsert( + Key key, Value value, UpdateFunc&& update) { + return table.upsert(Entry { kj::mv(key), kj::mv(value) }, + [&](Entry& existingEntry, Entry&& newEntry) { + update(existingEntry.value, kj::mv(newEntry.value)); + }); +} + +template +typename HashMap::Entry& HashMap::upsert( + Key key, Value value) { + return table.upsert(Entry { kj::mv(key), kj::mv(value) }, + [&](Entry& existingEntry, Entry&& newEntry) { + existingEntry.value = kj::mv(newEntry.value); + }); +} + +template +template +kj::Maybe HashMap::find(KeyLike&& key) { + return table.find(key).map([](Entry& e) -> Value& { return e.value; }); +} +template +template +kj::Maybe HashMap::find(KeyLike&& key) const { + return table.find(key).map([](const Entry& e) -> const Value& { return e.value; }); +} + +template +template +Value& HashMap::findOrCreate(KeyLike&& key, Func&& createEntry) { + return table.findOrCreate(key, kj::fwd(createEntry)).value; +} + +template +template +kj::Maybe::Entry&> +HashMap::findEntry(KeyLike&& key) { + return table.find(kj::fwd(key)); +} +template +template +kj::Maybe::Entry&> +HashMap::findEntry(KeyLike&& key) const { + return table.find(kj::fwd(key)); +} +template +template +typename HashMap::Entry& +HashMap::findOrCreateEntry(KeyLike&& key, Func&& createEntry) { + return table.findOrCreate(kj::fwd(key), kj::fwd(createEntry)); +} + +template +template +bool HashMap::erase(KeyLike&& key) { + return table.eraseMatch(key); +} + +template +void HashMap::erase(Entry& entry) { + table.erase(entry); +} + +template +typename HashMap::Entry HashMap::release(Entry& entry) { + return table.release(entry); +} + +template +template +size_t HashMap::eraseAll(Predicate&& predicate) { + return table.eraseAll([&](Entry& entry) { + return predicate(entry.key, entry.value); + }); +} + +// ----------------------------------------------------------------------------- + +template +void TreeMap::reserve(size_t size) { + table.reserve(size); +} + +template +size_t TreeMap::size() const { + return table.size(); +} +template +size_t TreeMap::capacity() const { + return table.capacity(); +} +template +void TreeMap::clear() { + return table.clear(); +} + +template +auto TreeMap::begin() { + return table.ordered().begin(); +} +template +auto TreeMap::end() { + return table.ordered().end(); +} +template +auto TreeMap::begin() const { + return table.ordered().begin(); +} +template +auto TreeMap::end() const { + return table.ordered().end(); +} + +template +typename TreeMap::Entry& TreeMap::insert(Key key, Value value) { + return table.insert(Entry { kj::mv(key), kj::mv(value) }); +} + +template +template +void TreeMap::insertAll(Collection&& collection) { + return table.insertAll(kj::fwd(collection)); +} + +template +template +typename TreeMap::Entry& TreeMap::upsert( + Key key, Value value, UpdateFunc&& update) { + return table.upsert(Entry { kj::mv(key), kj::mv(value) }, + [&](Entry& existingEntry, Entry&& newEntry) { + update(existingEntry.value, kj::mv(newEntry.value)); + }); +} + +template +typename TreeMap::Entry& TreeMap::upsert( + Key key, Value value) { + return table.upsert(Entry { kj::mv(key), kj::mv(value) }, + [&](Entry& existingEntry, Entry&& newEntry) { + existingEntry.value = kj::mv(newEntry.value); + }); +} + +template +template +kj::Maybe TreeMap::find(KeyLike&& key) { + return table.find(key).map([](Entry& e) -> Value& { return e.value; }); +} +template +template +kj::Maybe TreeMap::find(KeyLike&& key) const { + return table.find(key).map([](const Entry& e) -> const Value& { return e.value; }); +} + +template +template +Value& TreeMap::findOrCreate(KeyLike&& key, Func&& createEntry) { + return table.findOrCreate(key, kj::fwd(createEntry)).value; +} + +template +template +kj::Maybe::Entry&> +TreeMap::findEntry(KeyLike&& key) { + return table.find(kj::fwd(key)); +} +template +template +kj::Maybe::Entry&> +TreeMap::findEntry(KeyLike&& key) const { + return table.find(kj::fwd(key)); +} +template +template +typename TreeMap::Entry& +TreeMap::findOrCreateEntry(KeyLike&& key, Func&& createEntry) { + return table.findOrCreate(kj::fwd(key), kj::fwd(createEntry)); +} + +template +template +auto TreeMap::range(K1&& k1, K2&& k2) { + return table.range(kj::fwd(k1), kj::fwd(k2)); +} +template +template +auto TreeMap::range(K1&& k1, K2&& k2) const { + return table.range(kj::fwd(k1), kj::fwd(k2)); +} + +template +template +bool TreeMap::erase(KeyLike&& key) { + return table.eraseMatch(key); +} + +template +void TreeMap::erase(Entry& entry) { + table.erase(entry); +} + +template +typename TreeMap::Entry TreeMap::release(Entry& entry) { + return table.release(entry); +} + +template +template +size_t TreeMap::eraseAll(Predicate&& predicate) { + return table.eraseAll([&](Entry& entry) { + return predicate(entry.key, entry.value); + }); +} + +template +template +size_t TreeMap::eraseRange(K1&& k1, K2&& k2) { + return table.eraseRange(kj::fwd(k1), kj::fwd(k2)); +} + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/memory-test.c++ b/c++/src/kj/memory-test.c++ index 2655c2f9d5..9ad3ce9e0a 100644 --- a/c++/src/kj/memory-test.c++ +++ b/c++/src/kj/memory-test.c++ @@ -65,6 +65,441 @@ TEST(Memory, AssignNested) { EXPECT_TRUE(destroyed1 && destroyed2); } +struct DestructionOrderRecorder { + DestructionOrderRecorder(uint& counter, uint& recordTo) + : counter(counter), recordTo(recordTo) {} + ~DestructionOrderRecorder() { + recordTo = ++counter; + } + + uint& counter; + uint& recordTo; +}; + +TEST(Memory, Attach) { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + auto ptr = obj1.get(); + + Own combined = obj1.attach(kj::mv(obj2), kj::mv(obj3)); + + KJ_EXPECT(combined.get() == ptr); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + +TEST(Memory, AttachNested) { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + auto ptr = obj1.get(); + + Own combined = obj1.attach(kj::mv(obj2)).attach(kj::mv(obj3)); + + KJ_EXPECT(combined.get() == ptr); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + +KJ_TEST("attachRef") { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + int i = 123; + + Own combined = attachRef(i, kj::mv(obj1), kj::mv(obj2), kj::mv(obj3)); + + KJ_EXPECT(combined.get() == &i); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + +KJ_TEST("attachVal") { + uint counter = 0; + uint destroyed1 = 0; + uint destroyed2 = 0; + uint destroyed3 = 0; + + auto obj1 = kj::heap(counter, destroyed1); + auto obj2 = kj::heap(counter, destroyed2); + auto obj3 = kj::heap(counter, destroyed3); + + int i = 123; + + Own combined = attachVal(i, kj::mv(obj1), kj::mv(obj2), kj::mv(obj3)); + + int* ptr = combined.get(); + KJ_EXPECT(ptr != &i); + KJ_EXPECT(*ptr == i); + + KJ_EXPECT(obj1.get() == nullptr); + KJ_EXPECT(obj2.get() == nullptr); + KJ_EXPECT(obj3.get() == nullptr); + KJ_EXPECT(destroyed1 == 0); + KJ_EXPECT(destroyed2 == 0); + KJ_EXPECT(destroyed3 == 0); + + combined = nullptr; + + KJ_EXPECT(destroyed1 == 1, destroyed1); + KJ_EXPECT(destroyed2 == 2, destroyed2); + KJ_EXPECT(destroyed3 == 3, destroyed3); +} + +struct StaticType { + int i; +}; + +struct DynamicType1 { + virtual void foo() {} + + int j; + + DynamicType1(int j): j(j) {} +}; + +struct DynamicType2 { + virtual void bar() {} + + int k; + + DynamicType2(int k): k(k) {} +}; + +struct SingularDerivedDynamic final: public DynamicType1 { + SingularDerivedDynamic(int j, bool& destructorCalled) + : DynamicType1(j), destructorCalled(destructorCalled) {} + + ~SingularDerivedDynamic() { + destructorCalled = true; + } + KJ_DISALLOW_COPY_AND_MOVE(SingularDerivedDynamic); + + bool& destructorCalled; +}; + +struct MultipleDerivedDynamic final: public DynamicType1, public DynamicType2 { + MultipleDerivedDynamic(int j, int k, bool& destructorCalled) + : DynamicType1(j), DynamicType2(k), destructorCalled(destructorCalled) {} + + ~MultipleDerivedDynamic() { + destructorCalled = true; + } + + KJ_DISALLOW_COPY_AND_MOVE(MultipleDerivedDynamic); + + bool& destructorCalled; +}; + +TEST(Memory, OwnVoid) { + { + Own ptr = heap({123}); + StaticType* addr = ptr.get(); + Own voidPtr = kj::mv(ptr); + KJ_EXPECT(voidPtr.get() == implicitCast(addr)); + } + + { + bool destructorCalled = false; + Own ptr = heap(123, destructorCalled); + SingularDerivedDynamic* addr = ptr.get(); + Own voidPtr = kj::mv(ptr); + KJ_EXPECT(voidPtr.get() == implicitCast(addr)); + } + + { + bool destructorCalled = false; + Own ptr = heap(123, 456, destructorCalled); + MultipleDerivedDynamic* addr = ptr.get(); + Own voidPtr = kj::mv(ptr); + KJ_EXPECT(voidPtr.get() == implicitCast(addr)); + + KJ_EXPECT(!destructorCalled); + voidPtr = nullptr; + KJ_EXPECT(destructorCalled); + } + + { + bool destructorCalled = false; + Own ptr = heap(123, 456, destructorCalled); + MultipleDerivedDynamic* addr = ptr.get(); + Own basePtr = kj::mv(ptr); + DynamicType2* baseAddr = basePtr.get(); + + // On most (all?) C++ ABIs, the second base class in a multiply-inherited class is offset from + // the beginning of the object (assuming the first base class has non-zero size). We use this + // fact here to verify that then casting to Own does in fact result in a pointer that + // points to the start of the overall object, not the base class. We expect that the pointers + // are different here to prove that the test below is non-trivial. + // + // If there is some other ABI where these pointers are the same, and thus this expectation + // fails, then it's no problem to #ifdef out the expectation on that platform. + KJ_EXPECT(static_cast(baseAddr) != static_cast(addr)); + + Own voidPtr = kj::mv(basePtr); + KJ_EXPECT(voidPtr.get() == static_cast(addr)); + + KJ_EXPECT(!destructorCalled); + voidPtr = nullptr; + KJ_EXPECT(destructorCalled); + } + + { + Maybe> maybe; + maybe = Own(&maybe, NullDisposer::instance); + KJ_EXPECT(KJ_ASSERT_NONNULL(maybe).get() == &maybe); + maybe = nullptr; + KJ_EXPECT(maybe == nullptr); + } +} + +TEST(Memory, OwnConstVoid) { + { + Own ptr = heap({123}); + StaticType* addr = ptr.get(); + Own voidPtr = kj::mv(ptr); + KJ_EXPECT(voidPtr.get() == implicitCast(addr)); + } + + { + bool destructorCalled = false; + Own ptr = heap(123, destructorCalled); + SingularDerivedDynamic* addr = ptr.get(); + Own voidPtr = kj::mv(ptr); + KJ_EXPECT(voidPtr.get() == implicitCast(addr)); + } + + { + bool destructorCalled = false; + Own ptr = heap(123, 456, destructorCalled); + MultipleDerivedDynamic* addr = ptr.get(); + Own voidPtr = kj::mv(ptr); + KJ_EXPECT(voidPtr.get() == implicitCast(addr)); + + KJ_EXPECT(!destructorCalled); + voidPtr = nullptr; + KJ_EXPECT(destructorCalled); + } + + { + bool destructorCalled = false; + Own ptr = heap(123, 456, destructorCalled); + MultipleDerivedDynamic* addr = ptr.get(); + Own basePtr = kj::mv(ptr); + DynamicType2* baseAddr = basePtr.get(); + + // On most (all?) C++ ABIs, the second base class in a multiply-inherited class is offset from + // the beginning of the object (assuming the first base class has non-zero size). We use this + // fact here to verify that then casting to Own does in fact result in a pointer that + // points to the start of the overall object, not the base class. We expect that the pointers + // are different here to prove that the test below is non-trivial. + // + // If there is some other ABI where these pointers are the same, and thus this expectation + // fails, then it's no problem to #ifdef out the expectation on that platform. + KJ_EXPECT(static_cast(baseAddr) != static_cast(addr)); + + Own voidPtr = kj::mv(basePtr); + KJ_EXPECT(voidPtr.get() == static_cast(addr)); + + KJ_EXPECT(!destructorCalled); + voidPtr = nullptr; + KJ_EXPECT(destructorCalled); + } + + { + Maybe> maybe; + maybe = Own(&maybe, NullDisposer::instance); + KJ_EXPECT(KJ_ASSERT_NONNULL(maybe).get() == &maybe); + maybe = nullptr; + KJ_EXPECT(maybe == nullptr); + } +} + +struct IncompleteType; +KJ_DECLARE_NON_POLYMORPHIC(IncompleteType) + +template +struct IncompleteTemplate; +template +KJ_DECLARE_NON_POLYMORPHIC(IncompleteTemplate) + +struct IncompleteDisposer: public Disposer { + mutable void* sawPtr = nullptr; + + virtual void disposeImpl(void* pointer) const { + sawPtr = pointer; + } +}; + +KJ_TEST("Own") { + static int i; + void* ptr = &i; + + { + IncompleteDisposer disposer; + + { + kj::Own foo(reinterpret_cast(ptr), disposer); + kj::Own bar = kj::mv(foo); + } + + KJ_EXPECT(disposer.sawPtr == ptr); + } + + { + IncompleteDisposer disposer; + + { + kj::Own> foo( + reinterpret_cast*>(ptr), disposer); + kj::Own> bar = kj::mv(foo); + } + + KJ_EXPECT(disposer.sawPtr == ptr); + } +} + +KJ_TEST("Own with static disposer") { + static int* disposedPtr = nullptr; + struct MyDisposer { + static void dispose(int* value) { + KJ_EXPECT(disposedPtr == nullptr); + disposedPtr = value; + }; + }; + + int i; + + { + Own ptr(&i); + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == &i); + disposedPtr = nullptr; + + { + Own ptr(&i); + KJ_EXPECT(disposedPtr == nullptr); + Own ptr2(kj::mv(ptr)); + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == &i); + disposedPtr = nullptr; + + { + Own ptr2; + { + Own ptr(&i); + KJ_EXPECT(disposedPtr == nullptr); + ptr2 = kj::mv(ptr); + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == nullptr); + } + KJ_EXPECT(disposedPtr == &i); +} + +KJ_TEST("Maybe>") { + Maybe> m = heap(123); + KJ_EXPECT(m != nullptr); + Maybe mRef = m; + KJ_EXPECT(KJ_ASSERT_NONNULL(mRef) == 123); + KJ_EXPECT(&KJ_ASSERT_NONNULL(mRef) == KJ_ASSERT_NONNULL(m).get()); +} + +#if __cplusplus > 201402L +int* sawIntPtr = nullptr; + +void freeInt(int* ptr) { + sawIntPtr = ptr; + delete ptr; +} + +void freeChar(char* c) { + delete c; +} + +void free(StaticType* ptr) { + delete ptr; +} + +void free(const char* ptr) {} + +KJ_TEST("disposeWith") { + auto i = new int(1); + { + auto p = disposeWith(i); + KJ_EXPECT(sawIntPtr == nullptr); + } + KJ_EXPECT(sawIntPtr == i); + { + auto c = new char('a'); + auto p = disposeWith(c); + } + { + // Explicit cast required to avoid ambiguity when overloads are present. + auto s = new StaticType{1}; + auto p = disposeWith(free)>(s); + } + { + const char c = 'a'; + auto p2 = disposeWith(free)>(&c); + } +} +#endif + // TODO(test): More tests. } // namespace diff --git a/c++/src/kj/memory.h b/c++/src/kj/memory.h index 60912b0a34..099489d4c5 100644 --- a/c++/src/kj/memory.h +++ b/c++/src/kj/memory.h @@ -19,17 +19,93 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_MEMORY_H_ -#define KJ_MEMORY_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "common.h" +KJ_BEGIN_HEADER + namespace kj { +template +inline constexpr bool _kj_internal_isPolymorphic(T*) { + // If you get a compiler error here complaining that T is incomplete, it's because you are trying + // to use kj::Own with a type that has only been forward-declared. Since KJ doesn't know if + // the type might be involved in inheritance (especially multiple inheritance), it doesn't know + // how to correctly call the disposer to destroy the type, since the object's true memory address + // may differ from the address used to point to a superclass. + // + // However, if you know for sure that T is NOT polymorphic (i.e. it doesn't have a vtable and + // isn't involved in inheritance), then you can use KJ_DECLARE_NON_POLYMORPHIC(T) to declare this + // to KJ without actually completing the type. Place this macro invocation either in the global + // scope, or in the same namespace as T is defined. + return __is_polymorphic(T); +} + +#define KJ_DECLARE_NON_POLYMORPHIC(...) \ + inline constexpr bool _kj_internal_isPolymorphic(__VA_ARGS__*) { \ + return false; \ + } +// If you want to use kj::Own for an incomplete type T that you know is not polymorphic, then +// write `KJ_DECLARE_NON_POLYMORPHIC(T)` either at the global scope or in the same namespace as +// T is declared. +// +// This also works for templates, e.g.: +// +// template +// struct MyType; +// template +// KJ_DECLARE_NON_POLYMORPHIC(MyType) + +namespace _ { // private + +template struct RefOrVoid_ { typedef T& Type; }; +template <> struct RefOrVoid_ { typedef void Type; }; +template <> struct RefOrVoid_ { typedef void Type; }; + +template +using RefOrVoid = typename RefOrVoid_::Type; +// Evaluates to T&, unless T is `void`, in which case evaluates to `void`. +// +// This is a hack needed to avoid defining Own as a totally separate class. + +template +struct CastToVoid_; + +template +struct CastToVoid_ { + static void* apply(T* ptr) { + return static_cast(ptr); + } + static const void* applyConst(T* ptr) { + const T* cptr = ptr; + return static_cast(cptr); + } +}; + +template +struct CastToVoid_ { + static void* apply(T* ptr) { + return dynamic_cast(ptr); + } + static const void* applyConst(T* ptr) { + const T* cptr = ptr; + return dynamic_cast(cptr); + } +}; + +template +void* castToVoid(T* ptr) { + return CastToVoid_::apply(ptr); +} + +template +const void* castToConstVoid(T* ptr) { + return CastToVoid_::applyConst(ptr); +} + +} // namespace _ (private) + // ======================================================================================= // Disposer -- Implementation details. @@ -64,7 +140,7 @@ class Disposer { // an exception. private: - template + template struct Dispose_; }; @@ -95,8 +171,11 @@ class NullDisposer: public Disposer { // ======================================================================================= // Own -- An owned pointer. +template +class Own; + template -class Own { +class Own { // A transferrable title to a T. When an Own goes out of scope, the object's Disposer is // called to dispose of it. An Own can be efficiently passed by move, without relocating the // underlying object; this transfers ownership. @@ -120,17 +199,18 @@ class Own { : disposer(other.disposer), ptr(other.ptr) { other.ptr = nullptr; } template ()>> inline Own(Own&& other) noexcept - : disposer(other.disposer), ptr(other.ptr) { - static_assert(__is_polymorphic(T), - "Casting owned pointers requires that the target type is polymorphic."); + : disposer(other.disposer), ptr(cast(other.ptr)) { other.ptr = nullptr; } + template ()>> + inline Own(Own&& other) noexcept; + // Convert statically-disposed Own to dynamically-disposed Own. inline Own(T* ptr, const Disposer& disposer) noexcept: disposer(&disposer), ptr(ptr) {} ~Own() noexcept(false) { dispose(); } inline Own& operator=(Own&& other) { - // Move-assingnment operator. + // Move-assignnment operator. // Careful, this might own `other`. Therefore we have to transfer the pointers first, then // dispose. @@ -150,6 +230,15 @@ class Own { return *this; } + template + Own attach(Attachments&&... attachments) KJ_WARN_UNUSED_RESULT; + // Returns an Own which points to the same object but which also ensures that all values + // passed to `attachments` remain alive until after this object is destroyed. Normally + // `attachments` are other Owns pointing to objects that this one depends on. + // + // Note that attachments will eventually be destroyed in the order they are listed. Hence, + // foo.attach(bar, baz) is equivalent to (but more efficient than) foo.attach(bar).attach(baz). + template Own downcast() { // Downcast the pointer to Own, destroying the original pointer. If this pointer does not @@ -168,8 +257,8 @@ class Own { #define NULLCHECK KJ_IREQUIRE(ptr != nullptr, "null Own<> dereference") inline T* operator->() { NULLCHECK; return ptr; } inline const T* operator->() const { NULLCHECK; return ptr; } - inline T& operator*() { NULLCHECK; return *ptr; } - inline const T& operator*() const { NULLCHECK; return *ptr; } + inline _::RefOrVoid operator*() { NULLCHECK; return *ptr; } + inline _::RefOrVoid operator*() const { NULLCHECK; return *ptr; } #undef NULLCHECK inline T* get() { return ptr; } inline const T* get() const { return ptr; } @@ -197,69 +286,203 @@ class Own { } template + static inline T* cast(U* ptr) { + static_assert(_kj_internal_isPolymorphic((T*)nullptr), + "Casting owned pointers requires that the target type is polymorphic."); + return ptr; + } + + template friend class Own; friend class Maybe>; }; +template <> +template +inline void* Own::cast(U* ptr) { + return _::castToVoid(ptr); +} + +template <> +template +inline const void* Own::cast(U* ptr) { + return _::castToConstVoid(ptr); +} + +template +class Own { + // If a `StaticDisposer` is specified (which is not the norm), then the object will be deleted + // by calling StaticDisposer::dispose(pointer). The pointer passed to `dispose()` could be a + // superclass of `T`, if the pointer has been upcast. + // + // This type can be useful for micro-optimization, if you've found that you are doing excessive + // heap allocations to the point where the virtual call on destruction is costing non-negligible + // resources. You should avoid this unless you have a specific need, because it precludes a lot + // of power. + +public: + KJ_DISALLOW_COPY(Own); + inline Own(): ptr(nullptr) {} + inline Own(Own&& other) noexcept + : ptr(other.ptr) { other.ptr = nullptr; } + inline Own(Own, StaticDisposer>&& other) noexcept + : ptr(other.ptr) { other.ptr = nullptr; } + template ()>> + inline Own(Own&& other) noexcept + : ptr(cast(other.ptr)) { + other.ptr = nullptr; + } + inline explicit Own(T* ptr) noexcept: ptr(ptr) {} + + ~Own() noexcept(false) { dispose(); } + + inline Own& operator=(Own&& other) { + // Move-assignnment operator. + + // Careful, this might own `other`. Therefore we have to transfer the pointers first, then + // dispose. + T* ptrCopy = ptr; + ptr = other.ptr; + other.ptr = nullptr; + if (ptrCopy != nullptr) { + StaticDisposer::dispose(ptrCopy); + } + return *this; + } + + inline Own& operator=(decltype(nullptr)) { + dispose(); + return *this; + } + + template + Own downcast() { + // Downcast the pointer to Own, destroying the original pointer. If this pointer does not + // actually point at an instance of U, the results are undefined (throws an exception in debug + // mode if RTTI is enabled, otherwise you're on your own). + + Own result; + if (ptr != nullptr) { + result.ptr = &kj::downcast(*ptr); + ptr = nullptr; + } + return result; + } + +#define NULLCHECK KJ_IREQUIRE(ptr != nullptr, "null Own<> dereference") + inline T* operator->() { NULLCHECK; return ptr; } + inline const T* operator->() const { NULLCHECK; return ptr; } + inline _::RefOrVoid operator*() { NULLCHECK; return *ptr; } + inline _::RefOrVoid operator*() const { NULLCHECK; return *ptr; } +#undef NULLCHECK + inline T* get() { return ptr; } + inline const T* get() const { return ptr; } + inline operator T*() { return ptr; } + inline operator const T*() const { return ptr; } + +private: + T* ptr; + + inline explicit Own(decltype(nullptr)): ptr(nullptr) {} + + inline bool operator==(decltype(nullptr)) { return ptr == nullptr; } + inline bool operator!=(decltype(nullptr)) { return ptr != nullptr; } + // Only called by Maybe>. + + inline void dispose() { + // Make sure that if an exception is thrown, we are left with a null ptr, so we won't possibly + // dispose again. + T* ptrCopy = ptr; + if (ptrCopy != nullptr) { + ptr = nullptr; + StaticDisposer::dispose(ptrCopy); + } + } + + template + static inline T* cast(U* ptr) { + return ptr; + } + + template + friend class Own; + friend class Maybe>; +}; + namespace _ { // private -template +template class OwnOwn { public: - inline OwnOwn(Own&& value) noexcept: value(kj::mv(value)) {} + inline OwnOwn(Own&& value) noexcept: value(kj::mv(value)) {} - inline Own& operator*() & { return value; } - inline const Own& operator*() const & { return value; } - inline Own&& operator*() && { return kj::mv(value); } - inline const Own&& operator*() const && { return kj::mv(value); } - inline Own* operator->() { return &value; } - inline const Own* operator->() const { return &value; } - inline operator Own*() { return value ? &value : nullptr; } - inline operator const Own*() const { return value ? &value : nullptr; } + inline Own& operator*() & { return value; } + inline const Own& operator*() const & { return value; } + inline Own&& operator*() && { return kj::mv(value); } + inline const Own&& operator*() const && { return kj::mv(value); } + inline Own* operator->() { return &value; } + inline const Own* operator->() const { return &value; } + inline operator Own*() { return value ? &value : nullptr; } + inline operator const Own*() const { return value ? &value : nullptr; } private: - Own value; + Own value; }; -template -OwnOwn readMaybe(Maybe>&& maybe) { return OwnOwn(kj::mv(maybe.ptr)); } -template -Own* readMaybe(Maybe>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; } -template -const Own* readMaybe(const Maybe>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; } +template +OwnOwn readMaybe(Maybe>&& maybe) { return OwnOwn(kj::mv(maybe.ptr)); } +template +Own* readMaybe(Maybe>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; } +template +const Own* readMaybe(const Maybe>& maybe) { + return maybe.ptr ? &maybe.ptr : nullptr; +} } // namespace _ (private) -template -class Maybe> { +template +class Maybe> { public: inline Maybe(): ptr(nullptr) {} - inline Maybe(Own&& t) noexcept: ptr(kj::mv(t)) {} + inline Maybe(Own&& t) noexcept: ptr(kj::mv(t)) {} inline Maybe(Maybe&& other) noexcept: ptr(kj::mv(other.ptr)) {} template - inline Maybe(Maybe>&& other): ptr(mv(other.ptr)) {} + inline Maybe(Maybe>&& other): ptr(mv(other.ptr)) {} template - inline Maybe(Own&& other): ptr(mv(other)) {} + inline Maybe(Own&& other): ptr(mv(other)) {} inline Maybe(decltype(nullptr)) noexcept: ptr(nullptr) {} - inline operator Maybe() { return ptr.get(); } - inline operator Maybe() const { return ptr.get(); } + inline Own& emplace(Own value) { + // Assign the Maybe to the given value and return the content. This avoids the need to do a + // KJ_ASSERT_NONNULL() immediately after setting the Maybe just to read it back again. + ptr = kj::mv(value); + return ptr; + } + + template + inline operator NoInfer>() { return ptr.get(); } + template + inline operator NoInfer>() const { return ptr.get(); } + // Implicit conversion to `Maybe`. The weird templating is to make sure that + // `Maybe>` can be instantiated with the compiler complaining about forming references + // to void -- the use of templates here will cause SFINAE to kick in and hide these, whereas if + // they are not templates then SFINAE isn't applied and so they are considered errors. inline Maybe& operator=(Maybe&& other) { ptr = kj::mv(other.ptr); return *this; } inline bool operator==(decltype(nullptr)) const { return ptr == nullptr; } inline bool operator!=(decltype(nullptr)) const { return ptr != nullptr; } - Own& orDefault(Own& defaultValue) { + Own& orDefault(Own& defaultValue) { if (ptr == nullptr) { return defaultValue; } else { return ptr; } } - const Own& orDefault(const Own& defaultValue) const { + const Own& orDefault(const Own& defaultValue) const { if (ptr == nullptr) { return defaultValue; } else { @@ -267,8 +490,18 @@ class Maybe> { } } + template () ? instance>() : instance()())> + Result orDefault(F&& lazyDefaultValue) && { + if (ptr == nullptr) { + return lazyDefaultValue(); + } else { + return kj::mv(ptr); + } + } + template - auto map(Func&& f) & -> Maybe&>()))> { + auto map(Func&& f) & -> Maybe&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -277,7 +510,7 @@ class Maybe> { } template - auto map(Func&& f) const & -> Maybe&>()))> { + auto map(Func&& f) const & -> Maybe&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -286,7 +519,7 @@ class Maybe> { } template - auto map(Func&& f) && -> Maybe&&>()))> { + auto map(Func&& f) && -> Maybe&&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -295,7 +528,7 @@ class Maybe> { } template - auto map(Func&& f) const && -> Maybe&&>()))> { + auto map(Func&& f) const && -> Maybe&&>()))> { if (ptr == nullptr) { return nullptr; } else { @@ -304,16 +537,16 @@ class Maybe> { } private: - Own ptr; + Own ptr; template friend class Maybe; - template - friend _::OwnOwn _::readMaybe(Maybe>&& maybe); - template - friend Own* _::readMaybe(Maybe>& maybe); - template - friend const Own* _::readMaybe(const Maybe>& maybe); + template + friend _::OwnOwn _::readMaybe(Maybe>&& maybe); + template + friend Own* _::readMaybe(Maybe>& maybe); + template + friend const Own* _::readMaybe(const Maybe>& maybe); }; namespace _ { // private @@ -326,8 +559,30 @@ class HeapDisposer final: public Disposer { static const HeapDisposer instance; }; +#if _MSC_VER && _MSC_VER < 1920 && !defined(__clang__) +template +__declspec(selectany) const HeapDisposer HeapDisposer::instance = HeapDisposer(); +// On MSVC 2017 we suddenly started seeing a linker error on one specific specialization of +// `HeapDisposer::instance` when seemingly-unrelated code was modified. Explicitly specifying +// `__declspec(selectany)` seems to fix it. But why? Shouldn't template members have `selectany` +// behavior by default? We don't know. It works and we're moving on. +#else template const HeapDisposer HeapDisposer::instance = HeapDisposer(); +#endif + +template +class CustomDisposer: public Disposer { +public: + static const CustomDisposer instance; + + void disposeImpl(void* pointer) const override { + (*F)(reinterpret_cast(pointer)); + } +}; + +template +const CustomDisposer CustomDisposer::instance = CustomDisposer(); } // namespace _ (private) @@ -352,6 +607,31 @@ Own> heap(T&& orig) { return Own(new T2(kj::fwd(orig)), _::HeapDisposer::instance); } +#if __cplusplus > 201402L +template +Own disposeWith(T* ptr) { + // Associate a pre-allocated raw pointer with a corresponding disposal function. + // The first template parameter should be a function pointer e.g. disposeWith(new int(0)). + + return Own(ptr, _::CustomDisposer::instance); +} +#endif + +template +Own> attachVal(T&& value, Attachments&&... attachments); +// Returns an Own that takes ownership of `value` and `attachments`, and points to `value`. +// +// This is equivalent to heap(value).attach(attachments), but only does one allocation rather than +// two. + +template +Own attachRef(T& value, Attachments&&... attachments); +// Like attach() but `value` is not moved; the resulting Own points to its existing location. +// This is preferred if `value` is already owned by one of `attachments`. +// +// This is equivalent to Own(&value, kj::NullDisposer::instance).attach(attachments), but +// is easier to write and allocates slightly less memory. + // ======================================================================================= // SpaceFor -- assists in manual allocation @@ -401,6 +681,94 @@ void Disposer::dispose(T* object) const { Dispose_::dispose(object, *this); } +namespace _ { // private + +template +struct OwnedBundle; + +template <> +struct OwnedBundle<> {}; + +template +struct OwnedBundle: public OwnedBundle { + OwnedBundle(First&& first, Rest&&... rest) + : OwnedBundle(kj::fwd(rest)...), first(kj::fwd(first)) {} + + // Note that it's intentional that `first` is destroyed before `rest`. This way, doing + // ptr.attach(foo, bar, baz) is equivalent to ptr.attach(foo).attach(bar).attach(baz) in terms + // of destruction order (although the former does fewer allocations). + Decay first; +}; + +template +struct DisposableOwnedBundle final: public Disposer, public OwnedBundle { + DisposableOwnedBundle(T&&... values): OwnedBundle(kj::fwd(values)...) {} + void disposeImpl(void* pointer) const override { delete this; } +}; + +template +class StaticDisposerAdapter final: public Disposer { + // Adapts a static disposer to be called dynamically. +public: + virtual void disposeImpl(void* pointer) const override { + StaticDisposer::dispose(reinterpret_cast(pointer)); + } + + static const StaticDisposerAdapter instance; +}; + +template +const StaticDisposerAdapter StaticDisposerAdapter::instance = + StaticDisposerAdapter(); + +} // namespace _ (private) + +template +template +Own Own::attach(Attachments&&... attachments) { + T* ptrCopy = ptr; + + KJ_IREQUIRE(ptrCopy != nullptr, "cannot attach to null pointer"); + + // HACK: If someone accidentally calls .attach() on a null pointer in opt mode, try our best to + // accomplish reasonable behavior: We turn the pointer non-null but still invalid, so that the + // disposer will still be called when the pointer goes out of scope. + if (ptrCopy == nullptr) ptrCopy = reinterpret_cast(1); + + auto bundle = new _::DisposableOwnedBundle, Attachments...>( + kj::mv(*this), kj::fwd(attachments)...); + return Own(ptrCopy, *bundle); +} + +template +Own attachRef(T& value, Attachments&&... attachments) { + auto bundle = new _::DisposableOwnedBundle(kj::fwd(attachments)...); + return Own(&value, *bundle); +} + +template +Own> attachVal(T&& value, Attachments&&... attachments) { + auto bundle = new _::DisposableOwnedBundle( + kj::fwd(value), kj::fwd(attachments)...); + return Own>(&bundle->first, *bundle); +} + +template +template +inline Own::Own(Own&& other) noexcept + : ptr(cast(other.ptr)) { + if (_::castToVoid(other.ptr) != reinterpret_cast(other.ptr)) { + // Oh dangit, there's some sort of multiple inheritance going on and `StaticDisposerAdapter` + // won't actually work because it'll receive a pointer pointing to the top of the object, which + // isn't exactly the same as the `U*` pointer it wants. We have no choice but to allocate + // a dynamic disposer here. + disposer = new _::DisposableOwnedBundle>(kj::mv(other)); + } else { + disposer = &_::StaticDisposerAdapter::instance; + other.ptr = nullptr; + } +} + } // namespace kj -#endif // KJ_MEMORY_H_ +KJ_END_HEADER diff --git a/c++/src/kj/miniposix.h b/c++/src/kj/miniposix.h index 111c9bc379..e9ae848d38 100644 --- a/c++/src/kj/miniposix.h +++ b/c++/src/kj/miniposix.h @@ -19,16 +19,11 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_MINIPOSIX_H_ -#define KJ_MINIPOSIX_H_ +#pragma once // This header provides a small subset of the POSIX API which also happens to be available on // Windows under slightly-different names. -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif - #if _WIN32 #include #include @@ -48,6 +43,11 @@ #include #endif +// To get KJ_BEGIN_HEADER/KJ_END_HEADER +#include "common.h" + +KJ_BEGIN_HEADER + namespace kj { namespace miniposix { @@ -101,7 +101,7 @@ using ::close; // We're on Windows, including MinGW. pipe() and mkdir() are non-standard even on MinGW. inline int pipe(int fds[2]) { - return ::_pipe(fds, 8192, _O_BINARY); + return ::_pipe(fds, 8192, _O_BINARY | _O_NOINHERIT); } inline int mkdir(const char* path, int mode) { return ::_mkdir(path); @@ -113,40 +113,32 @@ inline int mkdir(const char* path, int mode) { using ::pipe; using ::mkdir; -inline size_t iovMax(size_t count) { - // Apparently, there is a maximum number of iovecs allowed per call. I don't understand why. - // Most platforms define IOV_MAX but Linux defines only UIO_MAXIOV and others, like Hurd, - // define neither. - // - // On platforms where both IOV_MAX and UIO_MAXIOV are undefined, we poke sysconf(_SC_IOV_MAX), - // then try to fall back to the POSIX-mandated minimum of _XOPEN_IOV_MAX if that fails. - // - // http://pubs.opengroup.org/onlinepubs/9699919799/basedefs/limits.h.html#tag_13_23_03_01 +// Apparently, there is a maximum number of iovecs allowed per call. I don't understand why. +// Most platforms define IOV_MAX but Linux defines only UIO_MAXIOV and others, like Hurd, +// define neither. +// +// On platforms where both IOV_MAX and UIO_MAXIOV are undefined, we poke sysconf(_SC_IOV_MAX), +// then try to fall back to the POSIX-mandated minimum of _XOPEN_IOV_MAX if that fails. +// +// http://pubs.opengroup.org/onlinepubs/9699919799/basedefs/limits.h.html#tag_13_23_03_01 #if defined(IOV_MAX) - // Solaris (and others?) +// Solaris, MacOS (& all other BSD-variants?) (and others?) +static constexpr inline size_t iovMax() { return IOV_MAX; -#elif defined(UIO_MAXIOV) - // Linux - return UIO_MAXIOV; +} +#elif defined(UIO_MAX_IOV) +// Linux +static constexpr inline size_t iovMax() { + return UIO_MAX_IOV; +} #else - // POSIX mystery meat - - long iovmax; - - errno = 0; - if ((iovmax = sysconf(_SC_IOV_MAX)) == -1) { - // assume iovmax == -1 && errno == 0 means "unbounded" - return errno ? _XOPEN_IOV_MAX : count; - } else { - return (size_t) iovmax; - } +#error "Please determine the appropriate constant for IOV_MAX on your system." #endif -} #endif } // namespace miniposix } // namespace kj -#endif // KJ_MINIPOSIX_H_ +KJ_END_HEADER diff --git a/c++/src/kj/mutex-test.c++ b/c++/src/kj/mutex-test.c++ index cbf3486767..1b51e3e7c0 100644 --- a/c++/src/kj/mutex-test.c++ +++ b/c++/src/kj/mutex-test.c++ @@ -19,13 +19,22 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if _WIN32 +#include "win32-api-version.h" +#define NOGDI // NOGDI is needed to make EXPECT_EQ(123u, *lock) compile for some reason +#endif + +#include "time.h" + +#define KJ_MUTEX_TEST 1 + #include "mutex.h" #include "debug.h" #include "thread.h" #include +#include #if _WIN32 -#define NOGDI // NOGDI is needed to make EXPECT_EQ(123u, *lock) compile for some reason #include #undef NOGDI #else @@ -33,6 +42,17 @@ #include #endif +#ifdef KJ_CONTENTION_WARNING_THRESHOLD +#include +#endif + +#if KJ_TRACK_LOCK_BLOCKING +#include +#include +#include +#include +#endif + namespace kj { namespace { @@ -50,6 +70,29 @@ TEST(Mutex, MutexGuarded) { EXPECT_EQ(123u, *lock); EXPECT_EQ(123u, value.getAlreadyLockedExclusive()); +#if KJ_USE_FUTEX + auto timeout = MILLISECONDS * 50; + + auto startTime = systemPreciseMonotonicClock().now(); + EXPECT_TRUE(value.lockExclusiveWithTimeout(timeout) == nullptr); + auto duration = startTime - systemPreciseMonotonicClock().now(); + EXPECT_TRUE(duration < timeout); + + startTime = systemPreciseMonotonicClock().now(); + EXPECT_TRUE(value.lockSharedWithTimeout(timeout) == nullptr); + duration = startTime - systemPreciseMonotonicClock().now(); + EXPECT_TRUE(duration < timeout); + + // originally, upon timing out, the exclusive requested flag would be removed + // from the futex state. if we did remove the exclusive request flag this test + // would hang. + Thread lockTimeoutThread([&]() { + // try to timeout during 10 ms delay + Maybe> maybeLock = value.lockExclusiveWithTimeout(MILLISECONDS * 8); + EXPECT_TRUE(maybeLock == nullptr); + }); +#endif + Thread thread([&]() { Locked threadLock = value.lockExclusive(); EXPECT_EQ(456u, *threadLock); @@ -62,6 +105,11 @@ TEST(Mutex, MutexGuarded) { auto earlyRelease = kj::mv(lock); } +#if KJ_USE_FUTEX + EXPECT_EQ(789u, *KJ_ASSERT_NONNULL(value.lockExclusiveWithTimeout(MILLISECONDS * 50))); + EXPECT_EQ(789u, *KJ_ASSERT_NONNULL(value.lockSharedWithTimeout(MILLISECONDS * 50))); +#endif + EXPECT_EQ(789u, *value.lockExclusive()); { @@ -111,13 +159,322 @@ TEST(Mutex, MutexGuarded) { EXPECT_EQ(321u, *value.lockExclusive()); -#if !_WIN32 // Not checked on win32. +#if !_WIN32 && !__CYGWIN__ // Not checked on win32. EXPECT_DEBUG_ANY_THROW(value.getAlreadyLockedExclusive()); EXPECT_DEBUG_ANY_THROW(value.getAlreadyLockedShared()); #endif EXPECT_EQ(321u, value.getWithoutLock()); } +TEST(Mutex, When) { + MutexGuarded value(123); + + { + uint m = value.when([](uint n) { return n < 200; }, [](uint& n) { + ++n; + return n + 2; + }); + KJ_EXPECT(m == 126); + + KJ_EXPECT(*value.lockShared() == 124); + } + + { + kj::Thread thread([&]() { + delay(); + *value.lockExclusive() = 321; + }); + + uint m = value.when([](uint n) { return n > 200; }, [](uint& n) { + ++n; + return n + 2; + }); + KJ_EXPECT(m == 324); + + KJ_EXPECT(*value.lockShared() == 322); + } + + { + // Stress test. 100 threads each wait for a value and then set the next value. + *value.lockExclusive() = 0; + + auto threads = kj::heapArrayBuilder>(100); + for (auto i: kj::zeroTo(100)) { + threads.add(kj::heap([i,&value]() { + if (i % 2 == 0) delay(); + uint m = value.when([i](const uint& n) { return n == i; }, + [](uint& n) { return n++; }); + KJ_ASSERT(m == i); + })); + } + + uint m = value.when([](uint n) { return n == 100; }, [](uint& n) { + return n++; + }); + KJ_EXPECT(m == 100); + + KJ_EXPECT(*value.lockShared() == 101); + } + +#if !KJ_NO_EXCEPTIONS + { + // Throw from predicate. + KJ_EXPECT_THROW_MESSAGE("oops threw", value.when([](uint n) -> bool { + KJ_FAIL_ASSERT("oops threw"); + }, [](uint& n) { + KJ_FAIL_EXPECT("shouldn't get here"); + })); + + // Throw from predicate later on. + kj::Thread thread([&]() { + delay(); + *value.lockExclusive() = 321; + }); + + KJ_EXPECT_THROW_MESSAGE("oops threw", value.when([](uint n) -> bool { + KJ_ASSERT(n != 321, "oops threw"); + return false; + }, [](uint& n) { + KJ_FAIL_EXPECT("shouldn't get here"); + })); + } + + { + // Verify the exceptions didn't break the mutex. + uint m = value.when([](uint n) { return n > 0; }, [](uint& n) { + return n; + }); + KJ_EXPECT(m == 321); + + kj::Thread thread([&]() { + delay(); + *value.lockExclusive() = 654; + }); + + m = value.when([](uint n) { return n > 500; }, [](uint& n) { + return n; + }); + KJ_EXPECT(m == 654); + } +#endif +} + +TEST(Mutex, WhenWithTimeout) { + auto& clock = systemPreciseMonotonicClock(); + MutexGuarded value(123); + + // A timeout that won't expire. + static constexpr Duration LONG_TIMEOUT = 10 * kj::SECONDS; + + { + uint m = value.when([](uint n) { return n < 200; }, [](uint& n) { + ++n; + return n + 2; + }, LONG_TIMEOUT); + KJ_EXPECT(m == 126); + + KJ_EXPECT(*value.lockShared() == 124); + } + + { + kj::Thread thread([&]() { + delay(); + *value.lockExclusive() = 321; + }); + + uint m = value.when([](uint n) { return n > 200; }, [](uint& n) { + ++n; + return n + 2; + }, LONG_TIMEOUT); + KJ_EXPECT(m == 324); + + KJ_EXPECT(*value.lockShared() == 322); + } + + { + // Stress test. 100 threads each wait for a value and then set the next value. + *value.lockExclusive() = 0; + + auto threads = kj::heapArrayBuilder>(100); + for (auto i: kj::zeroTo(100)) { + threads.add(kj::heap([i,&value]() { + if (i % 2 == 0) delay(); + uint m = value.when([i](const uint& n) { return n == i; }, + [](uint& n) { return n++; }, LONG_TIMEOUT); + KJ_ASSERT(m == i); + })); + } + + uint m = value.when([](uint n) { return n == 100; }, [](uint& n) { + return n++; + }, LONG_TIMEOUT); + KJ_EXPECT(m == 100); + + KJ_EXPECT(*value.lockShared() == 101); + } + + { + auto start = clock.now(); + uint m = value.when([](uint n) { return n == 0; }, [&](uint& n) { + KJ_ASSERT(n == 101); + auto t = clock.now() - start; + KJ_EXPECT(t >= 10 * kj::MILLISECONDS, t); + return 12; + }, 10 * kj::MILLISECONDS); + KJ_EXPECT(m == 12); + + m = value.when([](uint n) { return n == 0; }, [&](uint& n) { + KJ_ASSERT(n == 101); + auto t = clock.now() - start; + KJ_EXPECT(t >= 20 * kj::MILLISECONDS, t); + return 34; + }, 10 * kj::MILLISECONDS); + KJ_EXPECT(m == 34); + + m = value.when([](uint n) { return n > 0; }, [&](uint& n) { + KJ_ASSERT(n == 101); + return 56; + }, LONG_TIMEOUT); + KJ_EXPECT(m == 56); + } + +#if !KJ_NO_EXCEPTIONS + { + // Throw from predicate. + KJ_EXPECT_THROW_MESSAGE("oops threw", value.when([](uint n) -> bool { + KJ_FAIL_ASSERT("oops threw"); + }, [](uint& n) { + KJ_FAIL_EXPECT("shouldn't get here"); + }, LONG_TIMEOUT)); + + // Throw from predicate later on. + kj::Thread thread([&]() { + delay(); + *value.lockExclusive() = 321; + }); + + KJ_EXPECT_THROW_MESSAGE("oops threw", value.when([](uint n) -> bool { + KJ_ASSERT(n != 321, "oops threw"); + return false; + }, [](uint& n) { + KJ_FAIL_EXPECT("shouldn't get here"); + }, LONG_TIMEOUT)); + } + + { + // Verify the exceptions didn't break the mutex. + uint m = value.when([](uint n) { return n > 0; }, [](uint& n) { + return n; + }, LONG_TIMEOUT); + KJ_EXPECT(m == 321); + + auto start = clock.now(); + m = value.when([](uint n) { return n == 0; }, [&](uint& n) { + KJ_EXPECT(clock.now() - start >= 10 * kj::MILLISECONDS); + return n + 1; + }, 10 * kj::MILLISECONDS); + KJ_EXPECT(m == 322); + + kj::Thread thread([&]() { + delay(); + *value.lockExclusive() = 654; + }); + + m = value.when([](uint n) { return n > 500; }, [](uint& n) { + return n; + }, LONG_TIMEOUT); + KJ_EXPECT(m == 654); + } +#endif +} + +TEST(Mutex, WhenWithTimeoutPreciseTiming) { + // Test that MutexGuarded::when() with a timeout sleeps for precisely the right amount of time. + + auto& clock = systemPreciseMonotonicClock(); + + for (uint retryCount = 0; retryCount < 20; retryCount++) { + MutexGuarded value(123); + + auto start = clock.now(); + uint m = value.when([&value](uint n) { + // HACK: Reset the value as a way of testing what happens when the waiting thread is woken + // up but then finds it's not ready yet. + value.getWithoutLock() = 123; + return n == 321; + }, [](uint& n) { + return 456; + }, 100 * kj::MILLISECONDS); + + KJ_EXPECT(m == 456); + + auto t = clock.now() - start; + KJ_EXPECT(t >= 100 * kj::MILLISECONDS); + // Provide a large margin of error here because some operating systems (e.g. Windows) can have + // long timeslices (13ms) and won't schedule more precisely than a timeslice. + if (t <= 120 * kj::MILLISECONDS) { + return; + } + } + KJ_FAIL_ASSERT("time not within expected bounds even after retries"); +} + +TEST(Mutex, WhenWithTimeoutPreciseTimingAfterInterrupt) { + // Test that MutexGuarded::when() with a timeout sleeps for precisely the right amount of time, + // even if the thread is spuriously woken in the middle. + + auto& clock = systemPreciseMonotonicClock(); + + for (uint retryCount = 0; retryCount < 20; retryCount++) { + MutexGuarded value(123); + + kj::Thread thread([&]() { + delay(); + value.lockExclusive().induceSpuriousWakeupForTest(); + }); + + auto start = clock.now(); + uint m = value.when([](uint n) { + return n == 321; + }, [](uint& n) { + return 456; + }, 100 * kj::MILLISECONDS); + + KJ_EXPECT(m == 456); + + auto t = clock.now() - start; + KJ_EXPECT(t >= 100 * kj::MILLISECONDS, t / kj::MILLISECONDS); + // Provide a large margin of error here because some operating systems (e.g. Windows) can have + // long timeslices (13ms) and won't schedule more precisely than a timeslice. + if (t <= 120 * kj::MILLISECONDS) { + return; + } + } + KJ_FAIL_ASSERT("time not within expected bounds even after retries"); +} + +KJ_TEST("wait()s wake each other") { + MutexGuarded value(0); + + { + kj::Thread thread([&]() { + auto lock = value.lockExclusive(); + ++*lock; + lock.wait([](uint value) { return value == 2; }); + ++*lock; + lock.wait([](uint value) { return value == 4; }); + }); + + { + auto lock = value.lockExclusive(); + lock.wait([](uint value) { return value == 1; }); + ++*lock; + lock.wait([](uint value) { return value == 3; }); + ++*lock; + } + } +} + TEST(Mutex, Lazy) { Lazy lazy; volatile bool initStarted = false; @@ -167,5 +524,407 @@ TEST(Mutex, LazyException) { #endif } +class OnlyTouchUnderLock { +public: + OnlyTouchUnderLock(): ptr(nullptr) {} + OnlyTouchUnderLock(MutexGuarded& ref): ptr(&ref) { + ptr->getAlreadyLockedExclusive()++; + } + OnlyTouchUnderLock(OnlyTouchUnderLock&& other): ptr(other.ptr) { + other.ptr = nullptr; + if (ptr) { + // Just verify it's locked. Don't increment because different compilers may or may not + // elide moves. + ptr->getAlreadyLockedExclusive(); + } + } + OnlyTouchUnderLock& operator=(OnlyTouchUnderLock&& other) { + if (ptr) { + ptr->getAlreadyLockedExclusive()++; + } + ptr = other.ptr; + other.ptr = nullptr; + if (ptr) { + // Just verify it's locked. Don't increment because different compilers may or may not + // elide moves. + ptr->getAlreadyLockedExclusive(); + } + return *this; + } + ~OnlyTouchUnderLock() noexcept(false) { + if (ptr != nullptr) { + ptr->getAlreadyLockedExclusive()++; + } + } + + void frob() { + ptr->getAlreadyLockedExclusive()++; + } + +private: + MutexGuarded* ptr; +}; + +KJ_TEST("ExternalMutexGuarded destroy after release") { + MutexGuarded guarded(0); + + { + ExternalMutexGuarded ext; + + { + auto lock = guarded.lockExclusive(); + ext.set(lock, guarded); + KJ_EXPECT(*lock == 1, *lock); + ext.get(lock).frob(); + KJ_EXPECT(*lock == 2, *lock); + } + + { + auto lock = guarded.lockExclusive(); + auto released = ext.release(lock); + KJ_EXPECT(*lock == 2, *lock); + released.frob(); + KJ_EXPECT(*lock == 3, *lock); + } + } + + { + auto lock = guarded.lockExclusive(); + KJ_EXPECT(*lock == 4, *lock); + } +} + +KJ_TEST("ExternalMutexGuarded destroy without release") { + MutexGuarded guarded(0); + + { + ExternalMutexGuarded ext; + + { + auto lock = guarded.lockExclusive(); + ext.set(lock, guarded); + KJ_EXPECT(*lock == 1); + ext.get(lock).frob(); + KJ_EXPECT(*lock == 2); + } + } + + { + auto lock = guarded.lockExclusive(); + KJ_EXPECT(*lock == 3); + } +} + +KJ_TEST("condvar wait with flapping predicate") { + // This used to deadlock under some implementations due to a wait() checking its own predicate + // as part of unlock()ing the mutex. Adding `waiterToSkip` fixed this (and also eliminated a + // redundant call to the predicate). + + MutexGuarded guarded(0); + + Thread thread([&]() { + delay(); + *guarded.lockExclusive() = 1; + }); + + { + auto lock = guarded.lockExclusive(); + bool flap = true; + lock.wait([&](uint i) { + flap = !flap; + return i == 1 || flap; + }); + } +} + +#if KJ_TRACK_LOCK_BLOCKING +#if !__GLIBC_PREREQ(2, 30) +#ifndef SYS_gettid +#error SYS_gettid is unavailable on this system +#endif + +#define gettid() ((pid_t)syscall(SYS_gettid)) +#endif + +KJ_TEST("tracking blocking on mutex acquisition") { + // SIGEV_THREAD is supposed to be "private" to the pthreads implementation, but, as + // usual, the higher-level POSIX API that we're supposed to use sucks: the "handler" runs on + // some other thread, which means the stack trace it prints won't be useful. + // + // So, we cheat and work around libc. + MutexGuarded foo(5); + auto lock = foo.lockExclusive(); + + struct BlockDetected { + volatile bool blockedOnMutexAcquisition; + SourceLocation blockLocation; + } blockingInfo = {}; + + struct sigaction handler; + memset(&handler, 0, sizeof(handler)); + handler.sa_sigaction = [](int, siginfo_t* info, void*) { + auto& blockage = *reinterpret_cast(info->si_value.sival_ptr); + KJ_IF_MAYBE(r, blockedReason()) { + KJ_SWITCH_ONEOF(*r) { + KJ_CASE_ONEOF(b, BlockedOnMutexAcquisition) { + blockage.blockedOnMutexAcquisition = true; + blockage.blockLocation = b.origin; + } + KJ_CASE_ONEOF_DEFAULT {} + } + } + }; + handler.sa_flags = SA_SIGINFO | SA_RESTART; + + sigaction(SIGINT, &handler, nullptr); + + timer_t timer; + struct sigevent event; + memset(&event, 0, sizeof(event)); + event.sigev_notify = SIGEV_THREAD_ID; + event.sigev_signo = SIGINT; + event.sigev_value.sival_ptr = &blockingInfo; + KJ_SYSCALL(event._sigev_un._tid = gettid()); + KJ_SYSCALL(timer_create(CLOCK_MONOTONIC, &event, &timer)); + KJ_DEFER(timer_delete(timer)); + + kj::Duration timeout = 50 * MILLISECONDS; + struct itimerspec spec; + memset(&spec, 0, sizeof(spec)); + spec.it_value.tv_sec = timeout / kj::SECONDS; + spec.it_value.tv_nsec = timeout % kj::SECONDS / kj::NANOSECONDS; + // We can't use KJ_SYSCALL() because it is not async-signal-safe. + KJ_REQUIRE(-1 != timer_settime(timer, 0, &spec, nullptr)); + + kj::SourceLocation expectedBlockLocation; + KJ_REQUIRE(foo.lockSharedWithTimeout(100 * MILLISECONDS, expectedBlockLocation) == nullptr); + + KJ_EXPECT(blockingInfo.blockedOnMutexAcquisition); + KJ_EXPECT(blockingInfo.blockLocation == expectedBlockLocation); +} + +KJ_TEST("tracking blocked on CondVar::wait") { + // SIGEV_THREAD is supposed to be "private" to the pthreads implementation, but, as + // usual, the higher-level POSIX API that we're supposed to use sucks: the "handler" runs on + // some other thread, which means the stack trace it prints won't be useful. + // + // So, we cheat and work around libc. + MutexGuarded foo(5); + auto lock = foo.lockExclusive(); + + struct BlockDetected { + volatile bool blockedOnCondVar; + SourceLocation blockLocation; + } blockingInfo = {}; + + struct sigaction handler; + memset(&handler, 0, sizeof(handler)); + handler.sa_sigaction = [](int, siginfo_t* info, void*) { + auto& blockage = *reinterpret_cast(info->si_value.sival_ptr); + KJ_IF_MAYBE(r, blockedReason()) { + KJ_SWITCH_ONEOF(*r) { + KJ_CASE_ONEOF(b, BlockedOnCondVarWait) { + blockage.blockedOnCondVar = true; + blockage.blockLocation = b.origin; + } + KJ_CASE_ONEOF_DEFAULT {} + } + } + }; + handler.sa_flags = SA_SIGINFO | SA_RESTART; + + sigaction(SIGINT, &handler, nullptr); + + timer_t timer; + struct sigevent event; + memset(&event, 0, sizeof(event)); + event.sigev_notify = SIGEV_THREAD_ID; + event.sigev_signo = SIGINT; + event.sigev_value.sival_ptr = &blockingInfo; + KJ_SYSCALL(event._sigev_un._tid = gettid()); + KJ_SYSCALL(timer_create(CLOCK_MONOTONIC, &event, &timer)); + KJ_DEFER(timer_delete(timer)); + + kj::Duration timeout = 50 * MILLISECONDS; + struct itimerspec spec; + memset(&spec, 0, sizeof(spec)); + spec.it_value.tv_sec = timeout / kj::SECONDS; + spec.it_value.tv_nsec = timeout % kj::SECONDS / kj::NANOSECONDS; + // We can't use KJ_SYSCALL() because it is not async-signal-safe. + KJ_REQUIRE(-1 != timer_settime(timer, 0, &spec, nullptr)); + + SourceLocation waitLocation; + + lock.wait([](const int& value) { + return false; + }, 100 * MILLISECONDS, waitLocation); + + KJ_EXPECT(blockingInfo.blockedOnCondVar); + KJ_EXPECT(blockingInfo.blockLocation == waitLocation); +} + +KJ_TEST("tracking blocked on Once::init") { + // SIGEV_THREAD is supposed to be "private" to the pthreads implementation, but, as + // usual, the higher-level POSIX API that we're supposed to use sucks: the "handler" runs on + // some other thread, which means the stack trace it prints won't be useful. + // + // So, we cheat and work around libc. + struct BlockDetected { + volatile bool blockedOnOnceInit; + SourceLocation blockLocation; + } blockingInfo = {}; + + struct sigaction handler; + memset(&handler, 0, sizeof(handler)); + handler.sa_sigaction = [](int, siginfo_t* info, void*) { + auto& blockage = *reinterpret_cast(info->si_value.sival_ptr); + KJ_IF_MAYBE(r, blockedReason()) { + KJ_SWITCH_ONEOF(*r) { + KJ_CASE_ONEOF(b, BlockedOnOnceInit) { + blockage.blockedOnOnceInit = true; + blockage.blockLocation = b.origin; + } + KJ_CASE_ONEOF_DEFAULT {} + } + } + }; + handler.sa_flags = SA_SIGINFO | SA_RESTART; + + sigaction(SIGINT, &handler, nullptr); + + timer_t timer; + struct sigevent event; + memset(&event, 0, sizeof(event)); + event.sigev_notify = SIGEV_THREAD_ID; + event.sigev_signo = SIGINT; + event.sigev_value.sival_ptr = &blockingInfo; + KJ_SYSCALL(event._sigev_un._tid = gettid()); + KJ_SYSCALL(timer_create(CLOCK_MONOTONIC, &event, &timer)); + KJ_DEFER(timer_delete(timer)); + + Lazy once; + MutexGuarded onceInitializing(false); + + Thread backgroundInit([&] { + once.get([&](SpaceFor& x) { + *onceInitializing.lockExclusive() = true; + usleep(100 * 1000); // 100 ms + return x.construct(5); + }); + }); + + kj::Duration timeout = 50 * MILLISECONDS; + struct itimerspec spec; + memset(&spec, 0, sizeof(spec)); + spec.it_value.tv_sec = timeout / kj::SECONDS; + spec.it_value.tv_nsec = timeout % kj::SECONDS / kj::NANOSECONDS; + // We can't use KJ_SYSCALL() because it is not async-signal-safe. + KJ_REQUIRE(-1 != timer_settime(timer, 0, &spec, nullptr)); + + kj::SourceLocation onceInitializingBlocked; + + onceInitializing.lockExclusive().wait([](const bool& initializing) { + return initializing; + }); + + once.get([](SpaceFor& x) { + return x.construct(5); + }, onceInitializingBlocked); + + KJ_EXPECT(blockingInfo.blockedOnOnceInit); + KJ_EXPECT(blockingInfo.blockLocation == onceInitializingBlocked); +} + +#if KJ_SAVE_ACQUIRED_LOCK_INFO +KJ_TEST("get location of exclusive mutex") { + _::Mutex mutex; + kj::SourceLocation lockAcquisition; + mutex.lock(_::Mutex::EXCLUSIVE, nullptr, lockAcquisition); + KJ_DEFER(mutex.unlock(_::Mutex::EXCLUSIVE)); + + const auto& lockedInfo = mutex.lockedInfo(); + const auto& lockInfo = lockedInfo.get<_::HoldingExclusively>(); + EXPECT_EQ(gettid(), lockInfo.threadHoldingLock()); + KJ_EXPECT(lockInfo.lockAcquiredAt() == lockAcquisition); +} + +KJ_TEST("get location of shared mutex") { + _::Mutex mutex; + kj::SourceLocation lockLocation; + mutex.lock(_::Mutex::SHARED, nullptr, lockLocation); + KJ_DEFER(mutex.unlock(_::Mutex::SHARED)); + + const auto& lockedInfo = mutex.lockedInfo(); + const auto& lockInfo = lockedInfo.get<_::HoldingShared>(); + KJ_EXPECT(lockInfo.lockAcquiredAt() == lockLocation); +} +#endif + +#endif + +#ifdef KJ_CONTENTION_WARNING_THRESHOLD +KJ_TEST("make sure contended mutex warns") { + class Expectation final: public ExceptionCallback { + public: + Expectation(LogSeverity severity, StringPtr substring) : + severity(severity), substring(substring), seen(false) {} + + void logMessage(LogSeverity severity, const char* file, int line, int contextDepth, + String&& text) override { + if (!seen && severity == this->severity) { + if (_::hasSubstring(text, substring)) { + // Match. Ignore it. + seen = true; + return; + } + } + + // Pass up the chain. + ExceptionCallback::logMessage(severity, file, line, contextDepth, kj::mv(text)); + } + + bool hasSeen() const { + return seen; + } + + private: + LogSeverity severity; + StringPtr substring; + bool seen; + UnwindDetector unwindDetector; + }; + + _::Mutex mutex; + LockSourceLocation exclusiveLockLocation; + mutex.lock(_::Mutex::EXCLUSIVE, nullptr, exclusiveLockLocation); + + bool seenContendedLockLog = false; + + auto threads = kj::heapArrayBuilder>(KJ_CONTENTION_WARNING_THRESHOLD); + for (auto i: kj::zeroTo(KJ_CONTENTION_WARNING_THRESHOLD)) { + (void)i; + threads.add(kj::heap([&mutex, &seenContendedLockLog]() { + Expectation expectation(LogSeverity::WARNING, "Acquired contended lock"); + LockSourceLocation sharedLockLocation; + mutex.lock(_::Mutex::SHARED, nullptr, sharedLockLocation); + seenContendedLockLog = seenContendedLockLog || expectation.hasSeen(); + mutex.unlock(_::Mutex::SHARED); + })); + } + + while (mutex.numReadersWaitingForTest() < KJ_CONTENTION_WARNING_THRESHOLD) { + usleep(5 * kj::MILLISECONDS / kj::MICROSECONDS); + } + + { + KJ_EXPECT_LOG(WARNING, "excessively many readers were waiting on this lock"); + mutex.unlock(_::Mutex::EXCLUSIVE); + } + + threads.clear(); + + KJ_ASSERT(seenContendedLockLog); +} +#endif } // namespace } // namespace kj diff --git a/c++/src/kj/mutex.c++ b/c++/src/kj/mutex.c++ index c232a66ecb..edccdf4060 100644 --- a/c++/src/kj/mutex.c++ +++ b/c++/src/kj/mutex.c++ @@ -19,15 +19,18 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#if _WIN32 -#define WIN32_LEAN_AND_MEAN 1 // lolz -#define WINVER 0x0600 -#define _WIN32_WINNT 0x0600 +#if _WIN32 || __CYGWIN__ +#include "win32-api-version.h" #endif #include "mutex.h" #include "debug.h" +#if !_WIN32 && !__CYGWIN__ +#include +#include +#endif + #if KJ_USE_FUTEX #include #include @@ -36,7 +39,13 @@ #ifndef SYS_futex // Missing on Android/Bionic. +#ifdef __NR_futex #define SYS_futex __NR_futex +#elif defined(SYS_futex_time64) +#define SYS_futex SYS_futex_time64 +#else +#error "Need working SYS_futex" +#endif #endif #ifndef FUTEX_WAIT_PRIVATE @@ -45,30 +54,182 @@ #define FUTEX_WAKE_PRIVATE FUTEX_WAKE #endif -#elif _WIN32 +#elif _WIN32 || __CYGWIN__ #include #endif namespace kj { +#if KJ_TRACK_LOCK_BLOCKING +static thread_local const BlockedOnReason* tlsBlockReason __attribute((tls_model("initial-exec"))); +// The initial-exec model ensures that even if this code is part of a shared library built PIC, then +// we still place this variable in the appropriate ELF section so that __tls_get_addr is avoided. +// It's unclear if __tls_get_addr is still not async signal safe in glibc. The only negative +// downside of this approach is that a shared library built with kj & lock tracking will fail if +// dlopen'ed which isn't an intended use-case for the initial implementation. + +Maybe blockedReason() noexcept { + if (tlsBlockReason == nullptr) { + return nullptr; + } + return *tlsBlockReason; +} + +static void setCurrentThreadIsWaitingFor(const BlockedOnReason* meta) { + tlsBlockReason = meta; +} + +static void setCurrentThreadIsNoLongerWaiting() { + tlsBlockReason = nullptr; +} +#elif KJ_USE_FUTEX +struct BlockedOnMutexAcquisition { + constexpr BlockedOnMutexAcquisition(const _::Mutex& mutex, LockSourceLocationArg) {} +}; + +struct BlockedOnCondVarWait { + constexpr BlockedOnCondVarWait(const _::Mutex& mutex, const void *waiter, + LockSourceLocationArg) {} +}; + +struct BlockedOnOnceInit { + constexpr BlockedOnOnceInit(const _::Once& once, LockSourceLocationArg) {} +}; + +struct BlockedOnReason { + constexpr BlockedOnReason(const BlockedOnMutexAcquisition&) {} + constexpr BlockedOnReason(const BlockedOnCondVarWait&) {} + constexpr BlockedOnReason(const BlockedOnOnceInit&) {} +}; + +static void setCurrentThreadIsWaitingFor(const BlockedOnReason* meta) {} +static void setCurrentThreadIsNoLongerWaiting() {} +#endif + namespace _ { // private +#if KJ_USE_FUTEX +constexpr uint Mutex::EXCLUSIVE_HELD; +constexpr uint Mutex::EXCLUSIVE_REQUESTED; +constexpr uint Mutex::SHARED_COUNT_MASK; +#endif + +inline void Mutex::addWaiter(Waiter& waiter) { +#ifdef KJ_DEBUG + assertLockedByCaller(EXCLUSIVE); +#endif + *waitersTail = waiter; + waitersTail = &waiter.next; +} +inline void Mutex::removeWaiter(Waiter& waiter) { +#ifdef KJ_DEBUG + assertLockedByCaller(EXCLUSIVE); +#endif + *waiter.prev = waiter.next; + KJ_IF_MAYBE(next, waiter.next) { + next->prev = waiter.prev; + } else { + KJ_DASSERT(waitersTail == &waiter.next); + waitersTail = waiter.prev; + } +} + +bool Mutex::checkPredicate(Waiter& waiter) { + // Run the predicate from a thread other than the waiting thread, returning true if it's time to + // signal the waiting thread. This is not only when the predicate passes, but also when it + // throws, in which case we want to propagate the exception to the waiting thread. + + if (waiter.exception != nullptr) return true; // don't run again after an exception + + bool result = false; + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + result = waiter.predicate.check(); + })) { + // Exception thrown. + result = true; + waiter.exception = kj::heap(kj::mv(*exception)); + }; + return result; +} + +#if !_WIN32 && !__CYGWIN__ +namespace { + +TimePoint toTimePoint(struct timespec ts) { + return kj::origin() + ts.tv_sec * kj::SECONDS + ts.tv_nsec * kj::NANOSECONDS; +} +TimePoint now() { + struct timespec now; + KJ_SYSCALL(clock_gettime(CLOCK_MONOTONIC, &now)); + return toTimePoint(now); +} +struct timespec toRelativeTimespec(Duration timeout) { + struct timespec ts; + ts.tv_sec = timeout / kj::SECONDS; + ts.tv_nsec = timeout % kj::SECONDS / kj::NANOSECONDS; + return ts; +} +struct timespec toAbsoluteTimespec(TimePoint time) { + return toRelativeTimespec(time - kj::origin()); +} + +} // namespace +#endif + #if KJ_USE_FUTEX // ======================================================================================= // Futex-based implementation (Linux-only) +#if KJ_SAVE_ACQUIRED_LOCK_INFO +#if !__GLIBC_PREREQ(2, 30) +#ifndef SYS_gettid +#error SYS_gettid is unavailable on this system +#endif + +#define gettid() ((pid_t)syscall(SYS_gettid)) +#endif + +static thread_local pid_t tlsTid = gettid(); +#define TRACK_ACQUIRED_TID() tlsTid + +Mutex::AcquiredMetadata Mutex::lockedInfo() const { + auto state = __atomic_load_n(&futex, __ATOMIC_RELAXED); + auto tid = lockedExclusivelyByThread; + auto location = lockAcquiredLocation; + + if (state & EXCLUSIVE_HELD) { + return HoldingExclusively{tid, location}; + } else { + return HoldingShared{location}; + } +} + +#else +#define TRACK_ACQUIRED_TID() 0 +#endif + Mutex::Mutex(): futex(0) {} Mutex::~Mutex() { // This will crash anyway, might as well crash with a nice error message. KJ_ASSERT(futex == 0, "Mutex destroyed while locked.") { break; } } -void Mutex::lock(Exclusivity exclusivity) { +bool Mutex::lock(Exclusivity exclusivity, Maybe timeout, LockSourceLocationArg location) { + BlockedOnReason blockReason = BlockedOnMutexAcquisition{*this, location}; + KJ_DEFER(setCurrentThreadIsNoLongerWaiting()); + + auto spec = timeout.map([](Duration d) { return toRelativeTimespec(d); }); + struct timespec* specp = nullptr; + KJ_IF_MAYBE(s, spec) { + specp = s; + } + switch (exclusivity) { case EXCLUSIVE: for (;;) { uint state = 0; if (KJ_LIKELY(__atomic_compare_exchange_n(&futex, &state, EXCLUSIVE_HELD, false, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED))) { + // Acquired. break; } @@ -84,31 +245,167 @@ void Mutex::lock(Exclusivity exclusivity) { state |= EXCLUSIVE_REQUESTED; } - syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, state, NULL, NULL, 0); + setCurrentThreadIsWaitingFor(&blockReason); + + auto result = syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, state, specp, nullptr, 0); + if (result < 0) { + if (errno == ETIMEDOUT) { + setCurrentThreadIsNoLongerWaiting(); + // We timed out, we can't remove the exclusive request flag (since others might be waiting) + // so we just return false. + return false; + } + } } + acquiredExclusive(TRACK_ACQUIRED_TID(), location); +#if KJ_CONTENTION_WARNING_THRESHOLD + printContendedReader = false; +#endif break; case SHARED: { +#if KJ_CONTENTION_WARNING_THRESHOLD + kj::Maybe contentionWaitStart; +#endif + uint state = __atomic_add_fetch(&futex, 1, __ATOMIC_ACQUIRE); + for (;;) { if (KJ_LIKELY((state & EXCLUSIVE_HELD) == 0)) { // Acquired. break; } +#if KJ_CONTENTION_WARNING_THRESHOLD + if (contentionWaitStart == nullptr) { + // We could have the exclusive mutex tell us how long it was holding the lock. That would + // be the nicest. However, I'm hesitant to bloat the structure. I suspect having a reader + // tell us how long it was waiting for is probably a good proxy. + contentionWaitStart = kj::systemPreciseMonotonicClock().now(); + } +#endif + + setCurrentThreadIsWaitingFor(&blockReason); + // The mutex is exclusively locked by another thread. Since we incremented the counter // already, we just have to wait for it to be unlocked. - syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, state, NULL, NULL, 0); + auto result = syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, state, specp, nullptr, 0); + if (result < 0) { + // If we timeout though, we need to signal that we're not waiting anymore. + if (errno == ETIMEDOUT) { + setCurrentThreadIsNoLongerWaiting(); + state = __atomic_sub_fetch(&futex, 1, __ATOMIC_RELAXED); + + // We may have unlocked since we timed out. So act like we just unlocked the mutex + // and maybe send a wait signal if needed. See Mutex::unlock SHARED case. + if (KJ_UNLIKELY(state == EXCLUSIVE_REQUESTED)) { + if (__atomic_compare_exchange_n( + &futex, &state, 0, false, __ATOMIC_RELAXED, __ATOMIC_RELAXED)) { + // Wake all exclusive waiters. We have to wake all of them because one of them will + // grab the lock while the others will re-establish the exclusive-requested bit. + syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); + } + } + return false; + } + } state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE); } + +#ifdef KJ_CONTENTION_WARNING_THRESHOLD + KJ_IF_MAYBE(start, contentionWaitStart) { + if (__atomic_load_n(&printContendedReader, __ATOMIC_RELAXED)) { + // Double-checked lock avoids the CPU needing to acquire the lock in most cases. + if (__atomic_exchange_n(&printContendedReader, false, __ATOMIC_RELAXED)) { + auto contentionDuration = kj::systemPreciseMonotonicClock().now() - *start; + KJ_LOG(WARNING, "Acquired contended lock", location, contentionDuration, + kj::getStackTrace()); + } + } + } +#endif + + // We just want to record the lock being acquired somewhere but the specific location doesn't + // matter. This does mean that race conditions could occur where a thread might read this + // inconsistently (e.g. filename from 1 lock & function from another). This currently is just + // meant to be a debugging aid for manual analysis so it's OK for that purpose. If it's ever + // required for this to be used for anything else, then this should probably be changed to + // use an additional atomic variable that can ensure only one writer updates this. Or use the + // futex variable to ensure that this is only done for the first one to acquire the lock, + // although there may be thundering herd problems with that whereby there's a long wallclock + // time between when the lock is acquired and when the location is updated (since the first + // locker isn't really guaranteed to be the first one unlocked). + acquiredShared(location); + break; } } + return true; } -void Mutex::unlock(Exclusivity exclusivity) { +void Mutex::unlock(Exclusivity exclusivity, Waiter* waiterToSkip) { switch (exclusivity) { case EXCLUSIVE: { KJ_DASSERT(futex & EXCLUSIVE_HELD, "Unlocked a mutex that wasn't locked."); + +#ifdef KJ_CONTENTION_WARNING_THRESHOLD + auto acquiredLocation = releasingExclusive(); +#endif + + // First check if there are any conditional waiters. Note we only do this when unlocking an + // exclusive lock since under a shared lock the state couldn't have changed. + auto nextWaiter = waitersHead; + for (;;) { + KJ_IF_MAYBE(waiter, nextWaiter) { + nextWaiter = waiter->next; + + if (waiter != waiterToSkip && checkPredicate(*waiter)) { + // This waiter's predicate now evaluates true, so wake it up. + if (waiter->hasTimeout) { + // In this case we need to be careful to make sure the target thread isn't already + // processing a timeout, so we need to do an atomic CAS rather than just a store. + uint expected = 0; + if (__atomic_compare_exchange_n(&waiter->futex, &expected, 1, false, + __ATOMIC_RELEASE, __ATOMIC_RELAXED)) { + // Good, we set it to 1, transferring ownership of the mutex. Continue on below. + } else { + // Looks like the thread already timed out and set its own futex to 1. In that + // case it is going to try to lock the mutex itself, so we should NOT attempt an + // ownership transfer as this will deadlock. + // + // We have two options here: We can continue along the waiter list looking for + // another waiter that's ready to be signaled, or we could drop out of the list + // immediately since we know that another thread is already waiting for the lock + // and will re-evaluate the waiter queue itself when it is done. It feels cleaner + // to me to continue. + continue; + } + } else { + __atomic_store_n(&waiter->futex, 1, __ATOMIC_RELEASE); + } + syscall(SYS_futex, &waiter->futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); + + // We transferred ownership of the lock to this waiter, so we're done now. + return; + } + } else { + // No more waiters. + break; + } + } + +#ifdef KJ_CONTENTION_WARNING_THRESHOLD + uint readerCount; + { + uint oldState = __atomic_load_n(&futex, __ATOMIC_RELAXED); + readerCount = oldState & SHARED_COUNT_MASK; + if (readerCount >= KJ_CONTENTION_WARNING_THRESHOLD) { + // Atomic not needed because we're still holding the exclusive lock. + printContendedReader = true; + } + } +#endif + + // Didn't wake any waiters, so wake normally. uint oldState = __atomic_fetch_and( &futex, ~(EXCLUSIVE_HELD | EXCLUSIVE_REQUESTED), __ATOMIC_RELEASE); @@ -117,7 +414,14 @@ void Mutex::unlock(Exclusivity exclusivity) { // the lock, and we must wake them up. If there are any exclusive waiters, we must wake // them up even if readers are waiting so that at the very least they may re-establish the // EXCLUSIVE_REQUESTED bit that we just removed. - syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0); + syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); + +#ifdef KJ_CONTENTION_WARNING_THRESHOLD + if (readerCount >= KJ_CONTENTION_WARNING_THRESHOLD) { + KJ_LOG(WARNING, "excessively many readers were waiting on this lock", readerCount, + acquiredLocation, kj::getStackTrace()); + } +#endif } break; } @@ -133,7 +437,7 @@ void Mutex::unlock(Exclusivity exclusivity) { &futex, &state, 0, false, __ATOMIC_RELAXED, __ATOMIC_RELAXED)) { // Wake all exclusive waiters. We have to wake all of them because one of them will // grab the lock while the others will re-establish the exclusive-requested bit. - syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0); + syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); } } break; @@ -141,7 +445,7 @@ void Mutex::unlock(Exclusivity exclusivity) { } } -void Mutex::assertLockedByCaller(Exclusivity exclusivity) { +void Mutex::assertLockedByCaller(Exclusivity exclusivity) const { switch (exclusivity) { case EXCLUSIVE: KJ_ASSERT(futex & EXCLUSIVE_HELD, @@ -154,7 +458,120 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) { } } -void Once::runOnce(Initializer& init) { +void Mutex::wait(Predicate& predicate, Maybe timeout, LockSourceLocationArg location) { + // Add waiter to list. + Waiter waiter { nullptr, waitersTail, predicate, nullptr, 0, timeout != nullptr }; + addWaiter(waiter); + + BlockedOnReason blockReason = BlockedOnCondVarWait{*this, &waiter, location}; + KJ_DEFER(setCurrentThreadIsNoLongerWaiting()); + + // To guarantee that we've re-locked the mutex before scope exit, keep track of whether it is + // currently. + bool currentlyLocked = true; + KJ_DEFER({ + // Infinite timeout for re-obtaining the lock is on purpose because the post-condition for this + // function has to be that the lock state hasn't changed (& we have to be locked when we enter + // since that's how condvars work). + if (!currentlyLocked) lock(EXCLUSIVE, nullptr, location); + removeWaiter(waiter); + }); + + if (!predicate.check()) { + unlock(EXCLUSIVE, &waiter); + currentlyLocked = false; + + struct timespec ts; + struct timespec* tsp = nullptr; + KJ_IF_MAYBE(t, timeout) { + ts = toAbsoluteTimespec(now() + *t); + tsp = &ts; + } + + setCurrentThreadIsWaitingFor(&blockReason); + + // Wait for someone to set our futex to 1. + for (;;) { + // Note we use FUTEX_WAIT_BITSET_PRIVATE + FUTEX_BITSET_MATCH_ANY to get the same effect as + // FUTEX_WAIT_PRIVATE except that the timeout is specified as an absolute time based on + // CLOCK_MONOTONIC. Otherwise, FUTEX_WAIT_PRIVATE interprets it as a relative time, forcing + // us to recompute the time after every iteration. + KJ_SYSCALL_HANDLE_ERRORS(syscall(SYS_futex, + &waiter.futex, FUTEX_WAIT_BITSET_PRIVATE, 0, tsp, nullptr, FUTEX_BITSET_MATCH_ANY)) { + case EAGAIN: + // Indicates that the futex was already non-zero by the time the kernel looked at it. + // Not an error. + break; + case ETIMEDOUT: { + // Wait timed out. This leaves us in a bit of a pickle: Ownership of the mutex was not + // transferred to us from another thread. So, we need to lock it ourselves. But, another + // thread might be in the process of signaling us and transferring ownership. So, we + // first must atomically take control of our destiny. + KJ_ASSERT(timeout != nullptr); + uint expected = 0; + if (__atomic_compare_exchange_n(&waiter.futex, &expected, 1, false, + __ATOMIC_ACQUIRE, __ATOMIC_ACQUIRE)) { + // OK, we set our own futex to 1. That means no other thread will, and so we won't be + // receiving a mutex ownership transfer. We have to lock the mutex ourselves. + setCurrentThreadIsNoLongerWaiting(); + lock(EXCLUSIVE, nullptr, location); + currentlyLocked = true; + return; + } else { + // Oh, someone else actually did signal us, apparently. Let's move on as if the futex + // call told us so. + break; + } + } + default: + KJ_FAIL_SYSCALL("futex(FUTEX_WAIT_PRIVATE)", error); + } + + setCurrentThreadIsNoLongerWaiting(); + + if (__atomic_load_n(&waiter.futex, __ATOMIC_ACQUIRE)) { + // We received a lock ownership transfer from another thread. + currentlyLocked = true; + + // The other thread checked the predicate before the transfer. +#ifdef KJ_DEBUG + assertLockedByCaller(EXCLUSIVE); +#endif + + KJ_IF_MAYBE(exception, waiter.exception) { + // The predicate threw an exception, apparently. Propagate it. + // TODO(someday): Could we somehow have this be a recoverable exception? Presumably we'd + // then want MutexGuarded::when() to skip calling the callback, but then what should it + // return, since it normally returns the callback's result? Or maybe people who disable + // exceptions just really should not write predicates that can throw. + kj::throwFatalException(kj::mv(**exception)); + } + + return; + } + } + } +} + +void Mutex::induceSpuriousWakeupForTest() { + auto nextWaiter = waitersHead; + for (;;) { + KJ_IF_MAYBE(waiter, nextWaiter) { + nextWaiter = waiter->next; + syscall(SYS_futex, &waiter->futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); + } else { + // No more waiters. + break; + } + } +} + +uint Mutex::numReadersWaitingForTest() const { + assertLockedByCaller(EXCLUSIVE); + return futex & SHARED_COUNT_MASK; +} + +void Once::runOnce(Initializer& init, LockSourceLocationArg location) { startOver: uint state = UNINITIALIZED; if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false, @@ -166,7 +583,7 @@ startOver: if (__atomic_exchange_n(&futex, UNINITIALIZED, __ATOMIC_RELEASE) == INITIALIZING_WITH_WAITERS) { // Someone was waiting for us to finish. - syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0); + syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); } }); @@ -175,9 +592,12 @@ startOver: if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) == INITIALIZING_WITH_WAITERS) { // Someone was waiting for us to finish. - syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0); + syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0); } } else { + BlockedOnReason blockReason = BlockedOnOnceInit{*this, location}; + KJ_DEFER(setCurrentThreadIsNoLongerWaiting()); + for (;;) { if (state == INITIALIZED) { break; @@ -193,7 +613,9 @@ startOver: } // Wait for initialization. - syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0); + setCurrentThreadIsWaitingFor(&blockReason); + syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, + nullptr, nullptr, 0); state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE); if (state == UNINITIALIZED) { @@ -213,12 +635,13 @@ void Once::reset() { } } -#elif _WIN32 +#elif _WIN32 || __CYGWIN__ // ======================================================================================= // Win32 implementation #define coercedSrwLock (*reinterpret_cast(&srwLock)) #define coercedInitOnce (*reinterpret_cast(&initOnce)) +#define coercedCondvar(var) (*reinterpret_cast(&var)) Mutex::Mutex() { static_assert(sizeof(SRWLOCK) == sizeof(srwLock), "SRWLOCK is not a pointer?"); @@ -226,7 +649,10 @@ Mutex::Mutex() { } Mutex::~Mutex() {} -void Mutex::lock(Exclusivity exclusivity) { +bool Mutex::lock(Exclusivity exclusivity, Maybe timeout, NoopSourceLocation) { + if (timeout != nullptr) { + KJ_UNIMPLEMENTED("Locking a mutex with a timeout is only supported on Linux."); + } switch (exclusivity) { case EXCLUSIVE: AcquireSRWLockExclusive(&coercedSrwLock); @@ -235,26 +661,157 @@ void Mutex::lock(Exclusivity exclusivity) { AcquireSRWLockShared(&coercedSrwLock); break; } + return true; +} + +void Mutex::wakeReadyWaiter(Waiter* waiterToSkip) { + // Look for a waiter whose predicate is now evaluating true, and wake it. We wake no more than + // one waiter because only one waiter could get the lock anyway, and once it releases that lock + // it will awake the next waiter if necessary. + + auto nextWaiter = waitersHead; + for (;;) { + KJ_IF_MAYBE(waiter, nextWaiter) { + nextWaiter = waiter->next; + + if (waiter != waiterToSkip && checkPredicate(*waiter)) { + // This waiter's predicate now evaluates true, so wake it up. It doesn't matter if we + // use Wake vs. WakeAll here since there's always only one thread waiting. + WakeConditionVariable(&coercedCondvar(waiter->condvar)); + + // We only need to wake one waiter. Note that unlike the futex-based implementation, we + // cannot "transfer ownership" of the lock to the waiter, therefore we cannot guarantee + // that the condition is still true when that waiter finally awakes. However, if the + // condition is no longer true at that point, the waiter will re-check all other + // waiters' conditions and possibly wake up any other waiter who is now ready, hence we + // still only need to wake one waiter here. + return; + } + } else { + // No more waiters. + break; + } + } } -void Mutex::unlock(Exclusivity exclusivity) { +void Mutex::unlock(Exclusivity exclusivity, Waiter* waiterToSkip) { switch (exclusivity) { - case EXCLUSIVE: - ReleaseSRWLockExclusive(&coercedSrwLock); + case EXCLUSIVE: { + KJ_DEFER(ReleaseSRWLockExclusive(&coercedSrwLock)); + + // Check if there are any conditional waiters. Note we only do this when unlocking an + // exclusive lock since under a shared lock the state couldn't have changed. + wakeReadyWaiter(waiterToSkip); break; + } + case SHARED: ReleaseSRWLockShared(&coercedSrwLock); break; } } -void Mutex::assertLockedByCaller(Exclusivity exclusivity) { +void Mutex::assertLockedByCaller(Exclusivity exclusivity) const { // We could use TryAcquireSRWLock*() here like we do with the pthread version. However, as of // this writing, my version of Wine (1.6.2) doesn't implement these functions and will abort if // they are called. Since we were only going to use them as a hacky way to check if the lock is // held for debug purposes anyway, we just don't bother. } +void Mutex::wait(Predicate& predicate, Maybe timeout, NoopSourceLocation) { + // Add waiter to list. + Waiter waiter { nullptr, waitersTail, predicate, nullptr, 0 }; + static_assert(sizeof(waiter.condvar) == sizeof(CONDITION_VARIABLE), + "CONDITION_VARIABLE is not a pointer?"); + InitializeConditionVariable(&coercedCondvar(waiter.condvar)); + + addWaiter(waiter); + KJ_DEFER(removeWaiter(waiter)); + + DWORD sleepMs; + + // Only initialized if `timeout` is non-null. + const MonotonicClock* clock = nullptr; + kj::Maybe endTime; + + KJ_IF_MAYBE(t, timeout) { + // Windows sleeps are inaccurate -- they can be longer *or shorter* than the requested amount. + // For many use cases of our API, a too-short sleep would be unacceptable. Experimentally, it + // seems like sleeps can be up to half a millisecond short, so we'll add half a millisecond + // (and then we round up, below). + *t += 500 * kj::MICROSECONDS; + + // Compute initial sleep time. + sleepMs = *t / kj::MILLISECONDS; + if (*t % kj::MILLISECONDS > 0 * kj::SECONDS) { + // We guarantee we won't wake up too early. + ++sleepMs; + } + + clock = &systemPreciseMonotonicClock(); + endTime = clock->now() + *t; + } else { + sleepMs = INFINITE; + } + + while (!predicate.check()) { + // SleepConditionVariableSRW() will temporarily release the lock, so we need to signal other + // waiters that are now ready. + wakeReadyWaiter(&waiter); + + if (SleepConditionVariableSRW(&coercedCondvar(waiter.condvar), &coercedSrwLock, sleepMs, 0)) { + // Normal result. Continue loop to check predicate. + } else { + DWORD error = GetLastError(); + if (error == ERROR_TIMEOUT) { + // Windows may have woken us up too early, so don't return yet. Instead, proceed through the + // loop and rely on our sleep time recalculation to detect if we timed out. + } else { + KJ_FAIL_WIN32("SleepConditionVariableSRW()", error); + } + } + + KJ_IF_MAYBE(exception, waiter.exception) { + // The predicate threw an exception, apparently. Propagate it. + // TODO(someday): Could we somehow have this be a recoverable exception? Presumably we'd + // then want MutexGuarded::when() to skip calling the callback, but then what should it + // return, since it normally returns the callback's result? Or maybe people who disable + // exceptions just really should not write predicates that can throw. + kj::throwFatalException(kj::mv(**exception)); + } + + // Recompute sleep time. + KJ_IF_MAYBE(e, endTime) { + auto now = clock->now(); + + if (*e > now) { + auto sleepTime = *e - now; + sleepMs = sleepTime / kj::MILLISECONDS; + if (sleepTime % kj::MILLISECONDS > 0 * kj::SECONDS) { + // We guarantee we won't wake up too early. + ++sleepMs; + } + } else { + // Oops, already timed out. + return; + } + } + } +} + +void Mutex::induceSpuriousWakeupForTest() { + auto nextWaiter = waitersHead; + for (;;) { + KJ_IF_MAYBE(waiter, nextWaiter) { + nextWaiter = waiter->next; + WakeConditionVariable(&coercedCondvar(waiter->condvar)); + } else { + // No more waiters. + break; + } + } +} + static BOOL WINAPI nullInitializer(PINIT_ONCE initOnce, PVOID parameter, PVOID* context) { return true; } @@ -268,7 +825,7 @@ Once::Once(bool startInitialized) { } Once::~Once() {} -void Once::runOnce(Initializer& init) { +void Once::runOnce(Initializer& init, NoopSourceLocation) { BOOL needInit; while (!InitOnceBeginInitialize(&coercedInitOnce, 0, &needInit, nullptr)) { // Init was occurring in another thread, but then failed with an exception. Retry. @@ -313,14 +870,21 @@ void Once::reset() { } \ } -Mutex::Mutex() { - KJ_PTHREAD_CALL(pthread_rwlock_init(&mutex, nullptr)); +Mutex::Mutex(): mutex(PTHREAD_RWLOCK_INITIALIZER) { +#if defined(__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__) && __ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ < 1070 + // In older versions of MacOS, mutexes initialized statically cannot be destroyed, + // so we must call the init function. + KJ_PTHREAD_CALL(pthread_rwlock_init(&mutex, NULL)); +#endif } Mutex::~Mutex() { KJ_PTHREAD_CLEANUP(pthread_rwlock_destroy(&mutex)); } -void Mutex::lock(Exclusivity exclusivity) { +bool Mutex::lock(Exclusivity exclusivity, Maybe timeout, NoopSourceLocation) { + if (timeout != nullptr) { + KJ_UNIMPLEMENTED("Locking a mutex with a timeout is only supported on Linux."); + } switch (exclusivity) { case EXCLUSIVE: KJ_PTHREAD_CALL(pthread_rwlock_wrlock(&mutex)); @@ -329,13 +893,44 @@ void Mutex::lock(Exclusivity exclusivity) { KJ_PTHREAD_CALL(pthread_rwlock_rdlock(&mutex)); break; } + return true; } -void Mutex::unlock(Exclusivity exclusivity) { - KJ_PTHREAD_CALL(pthread_rwlock_unlock(&mutex)); +void Mutex::unlock(Exclusivity exclusivity, Waiter* waiterToSkip) { + KJ_DEFER(KJ_PTHREAD_CALL(pthread_rwlock_unlock(&mutex))); + + if (exclusivity == EXCLUSIVE) { + // Check if there are any conditional waiters. Note we only do this when unlocking an + // exclusive lock since under a shared lock the state couldn't have changed. + auto nextWaiter = waitersHead; + for (;;) { + KJ_IF_MAYBE(waiter, nextWaiter) { + nextWaiter = waiter->next; + + if (waiter != waiterToSkip && checkPredicate(*waiter)) { + // This waiter's predicate now evaluates true, so wake it up. It doesn't matter if we + // use _signal() vs. _broadcast() here since there's always only one thread waiting. + KJ_PTHREAD_CALL(pthread_mutex_lock(&waiter->stupidMutex)); + KJ_PTHREAD_CALL(pthread_cond_signal(&waiter->condvar)); + KJ_PTHREAD_CALL(pthread_mutex_unlock(&waiter->stupidMutex)); + + // We only need to wake one waiter. Note that unlike the futex-based implementation, we + // cannot "transfer ownership" of the lock to the waiter, therefore we cannot guarantee + // that the condition is still true when that waiter finally awakes. However, if the + // condition is no longer true at that point, the waiter will re-check all other waiters' + // conditions and possibly wake up any other waiter who is now ready, hence we still only + // need to wake one waiter here. + break; + } + } else { + // No more waiters. + break; + } + } + } } -void Mutex::assertLockedByCaller(Exclusivity exclusivity) { +void Mutex::assertLockedByCaller(Exclusivity exclusivity) const { switch (exclusivity) { case EXCLUSIVE: // A read lock should fail if the mutex is already held for writing. @@ -355,14 +950,141 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) { } } -Once::Once(bool startInitialized): state(startInitialized ? INITIALIZED : UNINITIALIZED) { - KJ_PTHREAD_CALL(pthread_mutex_init(&mutex, nullptr)); +void Mutex::wait(Predicate& predicate, Maybe timeout, NoopSourceLocation) { + // Add waiter to list. + Waiter waiter { + nullptr, waitersTail, predicate, nullptr, + PTHREAD_COND_INITIALIZER, PTHREAD_MUTEX_INITIALIZER + }; + +#if defined(__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__) && __ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ < 1070 + // In older versions of MacOS, mutexes initialized statically cannot be destroyed, + // so we must call the init function. + KJ_PTHREAD_CALL(pthread_cond_init(&waiter.condvar, NULL)); + KJ_PTHREAD_CALL(pthread_mutex_init(&waiter.stupidMutex, NULL)); +#endif + + addWaiter(waiter); + + // To guarantee that we've re-locked the mutex before scope exit, keep track of whether it is + // currently. + bool currentlyLocked = true; + KJ_DEFER({ + if (!currentlyLocked) lock(EXCLUSIVE, nullptr, NoopSourceLocation{}); + removeWaiter(waiter); + + // Destroy pthread objects. + KJ_PTHREAD_CLEANUP(pthread_mutex_destroy(&waiter.stupidMutex)); + KJ_PTHREAD_CLEANUP(pthread_cond_destroy(&waiter.condvar)); + }); + +#if !__APPLE__ + if (timeout != nullptr) { + // Oops, the default condvar uses the wall clock, which is dumb... fix it to use the monotonic + // clock. (Except not on macOS, where pthread_condattr_setclock() is unimplemented, but there's + // a bizarre pthread_cond_timedwait_relative_np() method we can use instead...) + pthread_condattr_t attr; + KJ_PTHREAD_CALL(pthread_condattr_init(&attr)); + KJ_PTHREAD_CALL(pthread_condattr_setclock(&attr, CLOCK_MONOTONIC)); + pthread_cond_init(&waiter.condvar, &attr); + KJ_PTHREAD_CALL(pthread_condattr_destroy(&attr)); + } +#endif + + Maybe endTime = timeout.map([](Duration d) { + return toAbsoluteTimespec(now() + d); + }); + + while (!predicate.check()) { + // pthread condvars only work with basic mutexes, not rwlocks. So, we need to lock a basic + // mutex before we unlock the real mutex, and the signaling thread also needs to lock this + // mutex, in order to ensure that this thread is actually waiting on the condvar before it is + // signaled. + KJ_PTHREAD_CALL(pthread_mutex_lock(&waiter.stupidMutex)); + + // OK, now we can unlock the main mutex. + unlock(EXCLUSIVE, &waiter); + currentlyLocked = false; + + bool timedOut = false; + + // Wait for someone to signal the condvar. + KJ_IF_MAYBE(t, endTime) { +#if __APPLE__ + // On macOS, the absolute timeout can only be specified in wall time, not monotonic time, + // which means modifying the system clock will break the wait. However, macOS happens to + // provide an alternative relative-time wait function, so I guess we'll use that. It does + // require recomputing the time every iteration... + struct timespec ts = toRelativeTimespec(kj::max(toTimePoint(*t) - now(), 0 * kj::SECONDS)); + int error = pthread_cond_timedwait_relative_np(&waiter.condvar, &waiter.stupidMutex, &ts); +#else + int error = pthread_cond_timedwait(&waiter.condvar, &waiter.stupidMutex, t); +#endif + if (error != 0) { + if (error == ETIMEDOUT) { + timedOut = true; + } else { + KJ_FAIL_SYSCALL("pthread_cond_timedwait", error); + } + } + } else { + KJ_PTHREAD_CALL(pthread_cond_wait(&waiter.condvar, &waiter.stupidMutex)); + } + + // We have to be very careful about lock ordering here. We need to unlock stupidMutex before + // re-locking the main mutex, because another thread may have a lock on the main mutex already + // and be waiting for a lock on stupidMutex. Note that other thread may signal the condvar + // right after we unlock stupidMutex but before we re-lock the main mutex. That is fine, + // because we've already been signaled. + KJ_PTHREAD_CALL(pthread_mutex_unlock(&waiter.stupidMutex)); + + lock(EXCLUSIVE, nullptr, NoopSourceLocation{}); + currentlyLocked = true; + + KJ_IF_MAYBE(exception, waiter.exception) { + // The predicate threw an exception, apparently. Propagate it. + // TODO(someday): Could we somehow have this be a recoverable exception? Presumably we'd + // then want MutexGuarded::when() to skip calling the callback, but then what should it + // return, since it normally returns the callback's result? Or maybe people who disable + // exceptions just really should not write predicates that can throw. + kj::throwFatalException(kj::mv(**exception)); + } + + if (timedOut) { + return; + } + } +} + +void Mutex::induceSpuriousWakeupForTest() { + auto nextWaiter = waitersHead; + for (;;) { + KJ_IF_MAYBE(waiter, nextWaiter) { + nextWaiter = waiter->next; + KJ_PTHREAD_CALL(pthread_mutex_lock(&waiter->stupidMutex)); + KJ_PTHREAD_CALL(pthread_cond_signal(&waiter->condvar)); + KJ_PTHREAD_CALL(pthread_mutex_unlock(&waiter->stupidMutex)); + } else { + // No more waiters. + break; + } + } +} + +Once::Once(bool startInitialized) + : state(startInitialized ? INITIALIZED : UNINITIALIZED), + mutex(PTHREAD_MUTEX_INITIALIZER) { +#if defined(__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__) && __ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ < 1070 + // In older versions of MacOS, mutexes initialized statically cannot be destroyed, + // so we must call the init function. + KJ_PTHREAD_CALL(pthread_mutex_init(&mutex, NULL)); +#endif } Once::~Once() { KJ_PTHREAD_CLEANUP(pthread_mutex_destroy(&mutex)); } -void Once::runOnce(Initializer& init) { +void Once::runOnce(Initializer& init, NoopSourceLocation) { KJ_PTHREAD_CALL(pthread_mutex_lock(&mutex)); KJ_DEFER(KJ_PTHREAD_CALL(pthread_mutex_unlock(&mutex))); diff --git a/c++/src/kj/mutex.h b/c++/src/kj/mutex.h index d211ebfeb1..619e7f95d4 100644 --- a/c++/src/kj/mutex.h +++ b/c++/src/kj/mutex.h @@ -19,54 +19,164 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_MUTEX_H_ -#define KJ_MUTEX_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once +#include "debug.h" #include "memory.h" #include +#include "time.h" +#include "source-location.h" +#include "one-of.h" + +KJ_BEGIN_HEADER #if __linux__ && !defined(KJ_USE_FUTEX) #define KJ_USE_FUTEX 1 #endif -#if !KJ_USE_FUTEX && !_WIN32 -// On Linux we use futex. On other platforms we wrap pthreads. +#if !KJ_USE_FUTEX && !_WIN32 && !__CYGWIN__ +// We fall back to pthreads when we don't have a better platform-specific primitive. pthreads +// mutexes are bloated, though, so we like to avoid them. Hence on Linux we use futex(), and on +// Windows we use SRW locks and friends. On Cygwin we prefer the Win32 primitives both because they +// are more efficient and because I ran into problems with Cygwin's implementation of RW locks +// seeming to allow multiple threads to lock the same mutex (but I didn't investigate very +// closely). +// // TODO(someday): Write efficient low-level locking primitives for other platforms. #include #endif +// There are 3 macros controlling lock tracking: +// KJ_TRACK_LOCK_BLOCKING will set up async signal safe TLS variables that can be used to identify +// the KJ primitive blocking the current thread. +// KJ_SAVE_ACQUIRED_LOCK_INFO will allow introspection of a Mutex to get information about what is +// currently holding the lock. +// KJ_TRACK_LOCK_ACQUISITION is automatically enabled by either one of them. + +#if KJ_TRACK_LOCK_BLOCKING +// Lock tracking is required to keep track of what blocked. +#define KJ_TRACK_LOCK_ACQUISITION 1 +#endif + +#if KJ_SAVE_ACQUIRED_LOCK_INFO +#define KJ_TRACK_LOCK_ACQUISITION 1 +#include +#endif + namespace kj { +#if KJ_TRACK_LOCK_ACQUISITION +#if !KJ_USE_FUTEX +#error Lock tracking is only currently supported for futex-based mutexes. +#endif + +#if !KJ_COMPILER_SUPPORTS_SOURCE_LOCATION +#error C++20 or newer is required (or the use of clang/gcc). +#endif + +using LockSourceLocation = SourceLocation; +using LockSourceLocationArg = const SourceLocation&; +// On x86-64 the codegen is optimal if the argument has type const& for the location. However, +// since this conflicts with the optimal call signature for NoopSourceLocation, +// LockSourceLocationArg is used to conditionally select the right type without polluting the usage +// themselves. Interestingly this makes no difference on ARM. +// https://godbolt.org/z/q6G8ee5a3 +#else +using LockSourceLocation = NoopSourceLocation; +using LockSourceLocationArg = NoopSourceLocation; +#endif + + +class Exception; // ======================================================================================= // Private details -- public interfaces follow below. namespace _ { // private +#if KJ_SAVE_ACQUIRED_LOCK_INFO +class HoldingExclusively { + // The lock is being held in exclusive mode. +public: + constexpr HoldingExclusively(pid_t tid, const SourceLocation& location) + : heldBy(tid), acquiredAt(location) {} + + pid_t threadHoldingLock() const { return heldBy; } + const SourceLocation& lockAcquiredAt() const { return acquiredAt; } + +private: + pid_t heldBy; + SourceLocation acquiredAt; +}; + +class HoldingShared { + // The lock is being held in shared mode currently. Which threads are holding this lock open + // is unknown. +public: + constexpr HoldingShared(const SourceLocation& location) : acquiredAt(location) {} + + const SourceLocation& lockAcquiredAt() const { return acquiredAt; } + +private: + SourceLocation acquiredAt; +}; +#endif + class Mutex { // Internal implementation details. See `MutexGuarded`. + struct Waiter; + public: Mutex(); ~Mutex(); - KJ_DISALLOW_COPY(Mutex); + KJ_DISALLOW_COPY_AND_MOVE(Mutex); enum Exclusivity { EXCLUSIVE, SHARED }; - void lock(Exclusivity exclusivity); - void unlock(Exclusivity exclusivity); + bool lock(Exclusivity exclusivity, Maybe timeout, LockSourceLocationArg location); + void unlock(Exclusivity exclusivity, Waiter* waiterToSkip = nullptr); - void assertLockedByCaller(Exclusivity exclusivity); + void assertLockedByCaller(Exclusivity exclusivity) const; // In debug mode, assert that the mutex is locked by the calling thread, or if that is // non-trivial, assert that the mutex is locked (which should be good enough to catch problems // in unit tests). In non-debug builds, do nothing. + class Predicate { + public: + virtual bool check() = 0; + }; + + void wait(Predicate& predicate, Maybe timeout, LockSourceLocationArg location); + // If predicate.check() returns false, unlock the mutex until predicate.check() returns true, or + // when the timeout (if any) expires. The mutex is always re-locked when this returns regardless + // of whether the timeout expired, and including if it throws. + // + // Requires that the mutex is already exclusively locked before calling. + + void induceSpuriousWakeupForTest(); + // Utility method for mutex-test.c++ which causes a spurious thread wakeup on all threads that + // are waiting for a wait() condition. Assuming correct implementation, all those threads + // should immediately go back to sleep. + +#if KJ_USE_FUTEX + uint numReadersWaitingForTest() const; + // The number of reader locks that are currently blocked on this lock (must be called while + // holding the writer lock). This is really only a utility method for mutex-test.c++ so it can + // validate certain invariants. +#endif + +#if KJ_SAVE_ACQUIRED_LOCK_INFO + using AcquiredMetadata = kj::OneOf; + KJ_DISABLE_TSAN AcquiredMetadata lockedInfo() const; + // Returns metadata about this lock when its held. This method is async signal safe. It must also + // be called in a state where it's guaranteed that the lock state won't be released by another + // thread. In other words this has to be called from the signal handler within the thread that's + // holding the lock. +#endif + private: #if KJ_USE_FUTEX uint futex; @@ -76,16 +186,75 @@ class Mutex { // waiting for a read lock, otherwise it is the count of threads that currently hold a read // lock. +#ifdef KJ_CONTENTION_WARNING_THRESHOLD + bool printContendedReader = false; +#endif + static constexpr uint EXCLUSIVE_HELD = 1u << 31; static constexpr uint EXCLUSIVE_REQUESTED = 1u << 30; static constexpr uint SHARED_COUNT_MASK = EXCLUSIVE_REQUESTED - 1; -#elif _WIN32 +#elif _WIN32 || __CYGWIN__ uintptr_t srwLock; // Actually an SRWLOCK, but don't want to #include in header. #else mutable pthread_rwlock_t mutex; #endif + +#if KJ_SAVE_ACQUIRED_LOCK_INFO + pid_t lockedExclusivelyByThread = 0; + SourceLocation lockAcquiredLocation; + + KJ_DISABLE_TSAN void acquiredExclusive(pid_t tid, const SourceLocation& location) noexcept { + lockAcquiredLocation = location; + __atomic_store_n(&lockedExclusivelyByThread, tid, __ATOMIC_RELAXED); + } + + KJ_DISABLE_TSAN void acquiredShared(const SourceLocation& location) noexcept { + lockAcquiredLocation = location; + } + + KJ_DISABLE_TSAN SourceLocation releasingExclusive() noexcept { + auto tmp = lockAcquiredLocation; + lockAcquiredLocation = SourceLocation{}; + lockedExclusivelyByThread = 0; + return tmp; + } +#else + static constexpr void acquiredExclusive(uint, LockSourceLocationArg) {} + static constexpr void acquiredShared(LockSourceLocationArg) {} + static constexpr NoopSourceLocation releasingExclusive() { return NoopSourceLocation{}; } +#endif + struct Waiter { + kj::Maybe next; + kj::Maybe* prev; + Predicate& predicate; + Maybe> exception; +#if KJ_USE_FUTEX + uint futex; + bool hasTimeout; +#elif _WIN32 || __CYGWIN__ + uintptr_t condvar; + // Actually CONDITION_VARIABLE, but don't want to #include in header. +#else + pthread_cond_t condvar; + + pthread_mutex_t stupidMutex; + // pthread condvars are only compatible with basic pthread mutexes, not rwlocks, for no + // particularly good reason. To work around this, we need an extra mutex per condvar. +#endif + }; + + kj::Maybe waitersHead = nullptr; + kj::Maybe* waitersTail = &waitersHead; + // linked list of waiters; can only modify under lock + + inline void addWaiter(Waiter& waiter); + inline void removeWaiter(Waiter& waiter); + bool checkPredicate(Waiter& waiter); +#if _WIN32 || __CYGWIN__ + void wakeReadyWaiter(Waiter* waiterToSkip); +#endif }; class Once { @@ -99,16 +268,16 @@ class Once { Once(bool startInitialized = false); ~Once(); #endif - KJ_DISALLOW_COPY(Once); + KJ_DISALLOW_COPY_AND_MOVE(Once); class Initializer { public: virtual void run() = 0; }; - void runOnce(Initializer& init); + void runOnce(Initializer& init, LockSourceLocationArg location); -#if _WIN32 // TODO(perf): Can we make this inline on win32 somehow? +#if _WIN32 || __CYGWIN__ // TODO(perf): Can we make this inline on win32 somehow? bool isInitialized() noexcept; #else @@ -138,7 +307,7 @@ class Once { INITIALIZED }; -#elif _WIN32 +#elif _WIN32 || __CYGWIN__ uintptr_t initOnce; // Actually an INIT_ONCE, but don't want to #include in header. #else @@ -196,6 +365,32 @@ class Locked { inline operator T*() { return ptr; } inline operator const T*() const { return ptr; } + template + void wait(Cond&& condition, Maybe timeout = nullptr, + LockSourceLocationArg location = {}) { + // Unlocks the lock until `condition(state)` evaluates true (where `state` is type `const T&` + // referencing the object protected by the lock). + + // We can't wait on a shared lock because the internal bookkeeping needed for a wait requires + // the protection of an exclusive lock. + static_assert(!isConst(), "cannot wait() on shared lock"); + + struct PredicateImpl final: public _::Mutex::Predicate { + bool check() override { + return condition(value); + } + + Cond&& condition; + const T& value; + + PredicateImpl(Cond&& condition, const T& value) + : condition(kj::fwd(condition)), value(value) {} + }; + + PredicateImpl impl(kj::fwd(condition), *ptr); + mutex->wait(impl, timeout, location); + } + private: _::Mutex* mutex; T* ptr; @@ -204,6 +399,16 @@ class Locked { template friend class MutexGuarded; + template + friend class ExternalMutexGuarded; + +#if KJ_MUTEX_TEST +public: +#endif + void induceSpuriousWakeupForTest() { mutex->induceSpuriousWakeupForTest(); } + // Utility method for mutex-test.c++ which causes a spurious thread wakeup on all threads that + // are waiting for a when() condition. Assuming correct implementation, all those threads should + // immediately go back to sleep. }; template @@ -225,7 +430,7 @@ class MutexGuarded { explicit MutexGuarded(Params&&... params); // Initialize the mutex-bounded object by passing the given parameters to its constructor. - Locked lockExclusive() const; + Locked lockExclusive(LockSourceLocationArg location = {}) const; // Exclusively locks the object and returns it. The returned `Locked` can be passed by // move, similar to `Own`. // @@ -235,10 +440,20 @@ class MutexGuarded { // be shared between threads, its methods should be const, even though locking it produces a // non-const pointer to the contained object. - Locked lockShared() const; + Locked lockShared(LockSourceLocationArg location = {}) const; // Lock the value for shared access. Multiple shared locks can be taken concurrently, but cannot // be held at the same time as a non-shared lock. + Maybe> lockExclusiveWithTimeout(Duration timeout, + LockSourceLocationArg location = {}) const; + // Attempts to exclusively lock the object. If the timeout elapses before the lock is acquired, + // this returns null. + + Maybe> lockSharedWithTimeout(Duration timeout, + LockSourceLocationArg location = {}) const; + // Attempts to lock the value for shared access. If the timeout elapses before the lock is acquired, + // this returns null. + inline const T& getWithoutLock() const { return value; } inline T& getWithoutLock() { return value; } // Escape hatch for cases where some external factor guarantees that it's safe to get the @@ -249,6 +464,32 @@ class MutexGuarded { inline T& getAlreadyLockedExclusive() const; // Like `getWithoutLock()`, but asserts that the lock is already held by the calling thread. + template + auto when(Cond&& condition, Func&& callback, Maybe timeout = nullptr, + LockSourceLocationArg location = {}) const + -> decltype(callback(instance())) { + // Waits until condition(state) returns true, then calls callback(state) under lock. + // + // `condition`, when called, receives as its parameter a const reference to the state, which is + // locked (either shared or exclusive). `callback` receives a mutable reference, which is + // exclusively locked. + // + // `condition()` may be called multiple times, from multiple threads, while waiting for the + // condition to become true. It may even return true once, but then be called more times. + // It is guaranteed, though, that at the time `callback()` is finally called, `condition()` + // would currently return true (assuming it is a pure function of the guarded data). + // + // If `timeout` is specified, then after the given amount of time, the callback will be called + // regardless of whether the condition is true. In this case, when `callback()` is called, + // `condition()` may in fact evaluate false, but *only* if the timeout was reached. + // + // TODO(cleanup): lock->wait() is a better interface. Can we deprecate this one? + + auto lock = lockExclusive(); + lock.wait(kj::fwd(condition), timeout, location); + return callback(value); + } + private: mutable _::Mutex mutex; mutable T value; @@ -261,15 +502,115 @@ class MutexGuarded { static_assert(sizeof(T) < 0, "MutexGuarded's type cannot be const."); }; +template +class ExternalMutexGuarded { + // Holds a value that can only be manipulated while some other mutex is locked. + // + // The ExternalMutexGuarded lives *outside* the scope of any lock on the mutex, but ensures + // that the value it holds can only be accessed under lock by forcing the caller to present a + // lock before accessing the value. + // + // Additionally, ExternalMutexGuarded's destructor will take an exclusive lock on the mutex + // while destroying the held value, unless the value has been release()ed before hand. + // + // The type T must have the following properties (which probably all movable types satisfy): + // - T is movable. + // - Immediately after any of the following has happened, T's destructor is effectively a no-op + // (hence certainly not requiring locks): + // - The value has been default-constructed. + // - The value has been initialized by-move from a default-constructed T. + // - The value has been moved away. + // - If ExternalMutexGuarded is ever moved, then T must have a move constructor and move + // assignment operator that do not follow any pointers, therefore do not need to take a lock. + // + // Inherits from LockSourceLocation to perform an empty base class optimization when lock tracking + // is compiled out. Once the minimum C++ standard for the KJ library is C++20, this optimization + // could be replaced by a member variable with a [[no_unique_address]] annotation. +public: + ExternalMutexGuarded(LockSourceLocationArg location = {}) + : location(location) {} + + template + ExternalMutexGuarded(Locked lock, Params&&... params, LockSourceLocationArg location = {}) + : mutex(lock.mutex), + value(kj::fwd(params)...), + location(location) {} + // Construct the value in-place. This constructor requires passing ownership of the lock into + // the constructor. Normally this should be a lock that you take on the line calling the + // constructor, like: + // + // ExternalMutexGuarded foo(someMutexGuarded.lockExclusive()); + // + // The reason this constructor does not accept an lvalue reference to an existing lock is because + // this would be deadlock-prone: If an exception were thrown immediately after the constructor + // completed, then the destructor would deadlock, because the lock would still be held. An + // ExternalMutexGuarded must live outside the scope of any locks to avoid such a deadlock. + + ~ExternalMutexGuarded() noexcept(false) { + if (mutex != nullptr) { + mutex->lock(_::Mutex::EXCLUSIVE, nullptr, location); + KJ_DEFER(mutex->unlock(_::Mutex::EXCLUSIVE)); + value = T(); + } + } + + ExternalMutexGuarded(ExternalMutexGuarded&& other) + : mutex(other.mutex), value(kj::mv(other.value)), location(other.location) { + other.mutex = nullptr; + } + ExternalMutexGuarded& operator=(ExternalMutexGuarded&& other) { + mutex = other.mutex; + value = kj::mv(other.value); + location = other.location; + other.mutex = nullptr; + return *this; + } + + template + void set(Locked& lock, T&& newValue) { + KJ_IREQUIRE(mutex == nullptr); + mutex = lock.mutex; + value = kj::mv(newValue); + } + + template + T& get(Locked& lock) { + KJ_IREQUIRE(lock.mutex == mutex); + return value; + } + + template + const T& get(Locked& lock) const { + KJ_IREQUIRE(lock.mutex == mutex); + return value; + } + + template + T release(Locked& lock) { + // Release (move away) the value. This allows the destructor to skip locking the mutex. + KJ_IREQUIRE(lock.mutex == mutex); + T result = kj::mv(value); + mutex = nullptr; + return result; + } + +private: + _::Mutex* mutex = nullptr; + T value; + KJ_NO_UNIQUE_ADDRESS LockSourceLocation location; + // When built against C++20 (or clang >= 9.0), the overhead of this is elided. Otherwise this + // struct will be 1 byte larger than it would otherwise be. +}; + template class Lazy { // A lazily-initialized value. public: template - T& get(Func&& init); + T& get(Func&& init, LockSourceLocationArg location = {}); template - const T& get(Func&& init) const; + const T& get(Func&& init, LockSourceLocationArg location = {}) const; // The first thread to call get() will invoke the given init function to construct the value. // Other threads will block until construction completes, then return the same value. // @@ -296,17 +637,38 @@ inline MutexGuarded::MutexGuarded(Params&&... params) : value(kj::fwd(params)...) {} template -inline Locked MutexGuarded::lockExclusive() const { - mutex.lock(_::Mutex::EXCLUSIVE); +inline Locked MutexGuarded::lockExclusive(LockSourceLocationArg location) + const { + mutex.lock(_::Mutex::EXCLUSIVE, nullptr, location); return Locked(mutex, value); } template -inline Locked MutexGuarded::lockShared() const { - mutex.lock(_::Mutex::SHARED); +inline Locked MutexGuarded::lockShared(LockSourceLocationArg location) const { + mutex.lock(_::Mutex::SHARED, nullptr, location); return Locked(mutex, value); } +template +inline Maybe> MutexGuarded::lockExclusiveWithTimeout(Duration timeout, + LockSourceLocationArg location) const { + if (mutex.lock(_::Mutex::EXCLUSIVE, timeout, location)) { + return Locked(mutex, value); + } else { + return nullptr; + } +} + +template +inline Maybe> MutexGuarded::lockSharedWithTimeout(Duration timeout, + LockSourceLocationArg location) const { + if (mutex.lock(_::Mutex::SHARED, timeout, location)) { + return Locked(mutex, value); + } else { + return nullptr; + } +} + template inline const T& MutexGuarded::getAlreadyLockedShared() const { #ifdef KJ_DEBUG @@ -346,24 +708,71 @@ class Lazy::InitImpl: public _::Once::Initializer { template template -inline T& Lazy::get(Func&& init) { +inline T& Lazy::get(Func&& init, LockSourceLocationArg location) { if (!once.isInitialized()) { InitImpl initImpl(*this, kj::fwd(init)); - once.runOnce(initImpl); + once.runOnce(initImpl, location); } return *value; } template template -inline const T& Lazy::get(Func&& init) const { +inline const T& Lazy::get(Func&& init, LockSourceLocationArg location) const { if (!once.isInitialized()) { InitImpl initImpl(*this, kj::fwd(init)); - once.runOnce(initImpl); + once.runOnce(initImpl, location); } return *value; } +#if KJ_TRACK_LOCK_BLOCKING +struct BlockedOnMutexAcquisition { + const _::Mutex& mutex; + // The mutex we are blocked on. + + const SourceLocation& origin; + // Where did the blocking operation originate from. +}; + +struct BlockedOnCondVarWait { + const _::Mutex& mutex; + // The mutex the condition variable is using (may or may not be locked). + + const void* waiter; + // Pointer to the waiter that's being waited on. + + const SourceLocation& origin; + // Where did the blocking operation originate from. +}; + +struct BlockedOnOnceInit { + const _::Once& once; + + const SourceLocation& origin; + // Where did the blocking operation originate from. +}; + +using BlockedOnReason = OneOf; + +Maybe blockedReason() noexcept; +// Returns the information about the reason the current thread is blocked synchronously on KJ +// lock primitives. Returns nullptr if the current thread is not currently blocked on such +// primitives. This is intended to be called from a signal handler to check whether the current +// thread is blocked. Outside of a signal handler there is little value to this function. In those +// cases by definition the thread is not blocked. This includes the callable used as part of a +// condition variable since that happens after the lock is acquired & the current thread is no +// longer blocked). The utility could be made useful for non-signal handler use-cases by being able +// to fetch the pointer to the TLS variable directly (i.e. const BlockedOnReason&*). However, there +// would have to be additional changes/complexity to try make that work since you'd need +// synchronization to ensure that the memory you'd try to reference is still valid. The likely +// solution would be to make these mutually exclusive options where you can use either the fast +// async-safe option, or a mutex-guarded TLS variable you can get a reference to that isn't +// async-safe. That being said, maybe someone can come up with a way to make something that works +// in both use-cases which would of course be more preferable. +#endif + + } // namespace kj -#endif // KJ_MUTEX_H_ +KJ_END_HEADER diff --git a/c++/src/kj/one-of-test.c++ b/c++/src/kj/one-of-test.c++ index d73e05ba5e..d7bec4d108 100644 --- a/c++/src/kj/one-of-test.c++ +++ b/c++/src/kj/one-of-test.c++ @@ -31,6 +31,9 @@ TEST(OneOf, Basic) { EXPECT_FALSE(var.is()); EXPECT_FALSE(var.is()); EXPECT_FALSE(var.is()); + EXPECT_TRUE(var.tryGet() == nullptr); + EXPECT_TRUE(var.tryGet() == nullptr); + EXPECT_TRUE(var.tryGet() == nullptr); var.init(123); @@ -44,6 +47,10 @@ TEST(OneOf, Basic) { EXPECT_ANY_THROW(var.get()); #endif + EXPECT_EQ(123, KJ_ASSERT_NONNULL(var.tryGet())); + EXPECT_TRUE(var.tryGet() == nullptr); + EXPECT_TRUE(var.tryGet() == nullptr); + var.init(kj::str("foo")); EXPECT_FALSE(var.is()); @@ -52,6 +59,10 @@ TEST(OneOf, Basic) { EXPECT_EQ("foo", var.get()); + EXPECT_TRUE(var.tryGet() == nullptr); + EXPECT_TRUE(var.tryGet() == nullptr); + EXPECT_EQ("foo", KJ_ASSERT_NONNULL(var.tryGet())); + OneOf var2 = kj::mv(var); EXPECT_EQ("", var.get()); EXPECT_EQ("foo", var2.get()); @@ -59,6 +70,11 @@ TEST(OneOf, Basic) { var = kj::mv(var2); EXPECT_EQ("foo", var.get()); EXPECT_EQ("", var2.get()); + + auto canCompile KJ_UNUSED = [&]() { + var.allHandled<3>(); + // var.allHandled<2>(); // doesn't compile + }; } TEST(OneOf, Copy) { @@ -82,4 +98,137 @@ TEST(OneOf, Copy) { EXPECT_STREQ("foo", var2.get()); } +TEST(OneOf, Switch) { + OneOf var; + var = "foo"; + uint count = 0; + + { + KJ_SWITCH_ONEOF(var) { + KJ_CASE_ONEOF(i, int) { + KJ_FAIL_ASSERT("expected char*, got int", i); + } + KJ_CASE_ONEOF(s, const char*) { + KJ_EXPECT(kj::StringPtr(s) == "foo"); + ++count; + } + KJ_CASE_ONEOF(n, float) { + KJ_FAIL_ASSERT("expected char*, got float", n); + } + } + } + + KJ_EXPECT(count == 1); + + { + KJ_SWITCH_ONEOF(kj::cp(var)) { + KJ_CASE_ONEOF(i, int) { + KJ_FAIL_ASSERT("expected char*, got int", i); + } + KJ_CASE_ONEOF(s, const char*) { + KJ_EXPECT(kj::StringPtr(s) == "foo"); + } + KJ_CASE_ONEOF(n, float) { + KJ_FAIL_ASSERT("expected char*, got float", n); + } + } + } + + { + // At one time this failed to compile. + const auto& constVar = var; + KJ_SWITCH_ONEOF(constVar) { + KJ_CASE_ONEOF(i, int) { + KJ_FAIL_ASSERT("expected char*, got int", i); + } + KJ_CASE_ONEOF(s, const char*) { + KJ_EXPECT(kj::StringPtr(s) == "foo"); + } + KJ_CASE_ONEOF(n, float) { + KJ_FAIL_ASSERT("expected char*, got float", n); + } + } + } +} + +TEST(OneOf, Maybe) { + Maybe> var; + var = OneOf(123); + + KJ_IF_MAYBE(v, var) { + // At one time this failed to compile. Note that a Maybe> isn't necessarily great + // style -- you might be better off with an explicit OneOf. Nevertheless, it should + // compile. + KJ_SWITCH_ONEOF(*v) { + KJ_CASE_ONEOF(i, int) { + KJ_EXPECT(i == 123); + } + KJ_CASE_ONEOF(n, float) { + KJ_FAIL_ASSERT("expected int, got float", n); + } + } + } +} + +KJ_TEST("OneOf copy/move from alternative variants") { + { + // Test const copy. + const OneOf src = 23.5f; + OneOf dst = src; + KJ_ASSERT(dst.is()); + KJ_EXPECT(dst.get() == 23.5); + } + + { + // Test case that requires non-const copy. + int arr[3] = {1, 2, 3}; + OneOf> src = ArrayPtr(arr); + OneOf> dst = src; + KJ_ASSERT(dst.is>()); + KJ_EXPECT(dst.get>().begin() == arr); + KJ_EXPECT(dst.get>().size() == kj::size(arr)); + } + + { + // Test move. + OneOf src = kj::str("foo"); + OneOf dst = kj::mv(src); + KJ_ASSERT(dst.is()); + KJ_EXPECT(dst.get() == "foo"); + + String s = kj::mv(dst).get(); + KJ_EXPECT(s == "foo"); + } + + { + // We can still have nested OneOfs. + OneOf src = 23.5f; + OneOf> dst = src; + KJ_ASSERT((dst.is>())); + KJ_ASSERT((dst.get>().is())); + KJ_EXPECT((dst.get>().get() == 23.5)); + } +} + +template +struct T { + unsigned int n = N; +}; + +TEST(OneOf, MaxVariants) { + kj::OneOf< + T<1>, T<2>, T<3>, T<4>, T<5>, T<6>, T<7>, T<8>, T<9>, T<10>, + T<11>, T<12>, T<13>, T<14>, T<15>, T<16>, T<17>, T<18>, T<19>, T<20>, + T<21>, T<22>, T<23>, T<24>, T<25>, T<26>, T<27>, T<28>, T<29>, T<30>, + T<31>, T<32>, T<33>, T<34>, T<35>, T<36>, T<37>, T<38>, T<39>, T<40>, + T<41>, T<42>, T<43>, T<44>, T<45>, T<46>, T<47>, T<48>, T<49>, T<50> + > v; + + v = T<1>(); + EXPECT_TRUE(v.is>()); + + v = T<50>(); + EXPECT_TRUE(v.is>()); +} + } // namespace kj diff --git a/c++/src/kj/one-of.h b/c++/src/kj/one-of.h index 6e143c44cf..6cc24a0f03 100644 --- a/c++/src/kj/one-of.h +++ b/c++/src/kj/one-of.h @@ -19,36 +19,369 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_ONE_OF_H_ -#define KJ_ONE_OF_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "common.h" +KJ_BEGIN_HEADER + namespace kj { namespace _ { // private -template -struct TypeIndex_ { static constexpr uint value = TypeIndex_::value; }; -template -struct TypeIndex_ { static constexpr uint value = i; }; +template class Fail, typename Key, typename... Variants> +struct TypeIndex_; +template class Fail, typename Key, typename First, typename... Rest> +struct TypeIndex_ { + static constexpr uint value = TypeIndex_::value; +}; +template class Fail, typename Key, typename... Rest> +struct TypeIndex_ { static constexpr uint value = i; }; +template class Fail, typename Key> +struct TypeIndex_: public Fail {}; + +template +struct OneOfFailError_ { + static_assert(i == -1, "type does not match any in OneOf"); +}; +template +struct OneOfFailZero_ { + static constexpr int value = 0; +}; + +template +struct SuccessIfNotZero { + typedef int Success; +}; +template <> +struct SuccessIfNotZero<0> {}; + +enum class Variants0 {}; +enum class Variants1 { _variant0 }; +enum class Variants2 { _variant0, _variant1 }; +enum class Variants3 { _variant0, _variant1, _variant2 }; +enum class Variants4 { _variant0, _variant1, _variant2, _variant3 }; +enum class Variants5 { _variant0, _variant1, _variant2, _variant3, _variant4 }; +enum class Variants6 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5 }; +enum class Variants7 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6 }; +enum class Variants8 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7 }; +enum class Variants9 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8 }; +enum class Variants10 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9 }; +enum class Variants11 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10 }; +enum class Variants12 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11 }; +enum class Variants13 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12 }; +enum class Variants14 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13 }; +enum class Variants15 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14 }; +enum class Variants16 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15 }; +enum class Variants17 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16 }; +enum class Variants18 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17 }; +enum class Variants19 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18 }; +enum class Variants20 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19 }; +enum class Variants21 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20 }; +enum class Variants22 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21 }; +enum class Variants23 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22 }; +enum class Variants24 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23 }; +enum class Variants25 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24 }; +enum class Variants26 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25 }; +enum class Variants27 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26 }; +enum class Variants28 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27 }; +enum class Variants29 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28 }; +enum class Variants30 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29 }; +enum class Variants31 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30 }; +enum class Variants32 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31 }; +enum class Variants33 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32 }; +enum class Variants34 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33 }; +enum class Variants35 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34 }; +enum class Variants36 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35 }; +enum class Variants37 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36 }; +enum class Variants38 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37 }; +enum class Variants39 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38 }; +enum class Variants40 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39 }; +enum class Variants41 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40 }; +enum class Variants42 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41 }; +enum class Variants43 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42 }; +enum class Variants44 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43 }; +enum class Variants45 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44 }; +enum class Variants46 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45 }; +enum class Variants47 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46 }; +enum class Variants48 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46, _variant47 }; +enum class Variants49 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46, _variant47, _variant48 }; +enum class Variants50 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6, + _variant7, _variant8, _variant9, _variant10, _variant11, _variant12, + _variant13, _variant14, _variant15, _variant16, _variant17, _variant18, + _variant19, _variant20, _variant21, _variant22, _variant23, _variant24, + _variant25, _variant26, _variant27, _variant28, _variant29, _variant30, + _variant31, _variant32, _variant33, _variant34, _variant35, _variant36, + _variant37, _variant38, _variant39, _variant40, _variant41, _variant42, + _variant43, _variant44, _variant45, _variant46, _variant47, _variant48, + _variant49 }; + +template struct Variants_; +template <> struct Variants_<0> { typedef Variants0 Type; }; +template <> struct Variants_<1> { typedef Variants1 Type; }; +template <> struct Variants_<2> { typedef Variants2 Type; }; +template <> struct Variants_<3> { typedef Variants3 Type; }; +template <> struct Variants_<4> { typedef Variants4 Type; }; +template <> struct Variants_<5> { typedef Variants5 Type; }; +template <> struct Variants_<6> { typedef Variants6 Type; }; +template <> struct Variants_<7> { typedef Variants7 Type; }; +template <> struct Variants_<8> { typedef Variants8 Type; }; +template <> struct Variants_<9> { typedef Variants9 Type; }; +template <> struct Variants_<10> { typedef Variants10 Type; }; +template <> struct Variants_<11> { typedef Variants11 Type; }; +template <> struct Variants_<12> { typedef Variants12 Type; }; +template <> struct Variants_<13> { typedef Variants13 Type; }; +template <> struct Variants_<14> { typedef Variants14 Type; }; +template <> struct Variants_<15> { typedef Variants15 Type; }; +template <> struct Variants_<16> { typedef Variants16 Type; }; +template <> struct Variants_<17> { typedef Variants17 Type; }; +template <> struct Variants_<18> { typedef Variants18 Type; }; +template <> struct Variants_<19> { typedef Variants19 Type; }; +template <> struct Variants_<20> { typedef Variants20 Type; }; +template <> struct Variants_<21> { typedef Variants21 Type; }; +template <> struct Variants_<22> { typedef Variants22 Type; }; +template <> struct Variants_<23> { typedef Variants23 Type; }; +template <> struct Variants_<24> { typedef Variants24 Type; }; +template <> struct Variants_<25> { typedef Variants25 Type; }; +template <> struct Variants_<26> { typedef Variants26 Type; }; +template <> struct Variants_<27> { typedef Variants27 Type; }; +template <> struct Variants_<28> { typedef Variants28 Type; }; +template <> struct Variants_<29> { typedef Variants29 Type; }; +template <> struct Variants_<30> { typedef Variants30 Type; }; +template <> struct Variants_<31> { typedef Variants31 Type; }; +template <> struct Variants_<32> { typedef Variants32 Type; }; +template <> struct Variants_<33> { typedef Variants33 Type; }; +template <> struct Variants_<34> { typedef Variants34 Type; }; +template <> struct Variants_<35> { typedef Variants35 Type; }; +template <> struct Variants_<36> { typedef Variants36 Type; }; +template <> struct Variants_<37> { typedef Variants37 Type; }; +template <> struct Variants_<38> { typedef Variants38 Type; }; +template <> struct Variants_<39> { typedef Variants39 Type; }; +template <> struct Variants_<40> { typedef Variants40 Type; }; +template <> struct Variants_<41> { typedef Variants41 Type; }; +template <> struct Variants_<42> { typedef Variants42 Type; }; +template <> struct Variants_<43> { typedef Variants43 Type; }; +template <> struct Variants_<44> { typedef Variants44 Type; }; +template <> struct Variants_<45> { typedef Variants45 Type; }; +template <> struct Variants_<46> { typedef Variants46 Type; }; +template <> struct Variants_<47> { typedef Variants47 Type; }; +template <> struct Variants_<48> { typedef Variants48 Type; }; +template <> struct Variants_<49> { typedef Variants49 Type; }; +template <> struct Variants_<50> { typedef Variants50 Type; }; + +template +using Variants = typename Variants_::Type; } // namespace _ (private) template class OneOf { template - static inline constexpr uint typeIndex() { return _::TypeIndex_<1, Key, Variants...>::value; } - // Get the 1-based index of Key within the type list Types. + static inline constexpr uint typeIndex() { + return _::TypeIndex_<1, _::OneOfFailError_, Key, Variants...>::value; + } + // Get the 1-based index of Key within the type list Types, or static_assert with a nice error. + + template + static inline constexpr uint typeIndexOrZero() { + return _::TypeIndex_<1, _::OneOfFailZero_, Key, Variants...>::value; + } + + template + struct HasAll; + // Has a member type called "Success" if and only if all of `OtherVariants` are types that + // appear in `Variants`. Used with SFINAE to enable subset constructors. public: inline OneOf(): tag(0) {} + OneOf(const OneOf& other) { copyFrom(other); } + OneOf(OneOf& other) { copyFrom(other); } OneOf(OneOf&& other) { moveFrom(other); } + // Copy/move from same OneOf type. + + template ::Success> + OneOf(const OneOf& other) { copyFromSubset(other); } + template ::Success> + OneOf(OneOf& other) { copyFromSubset(other); } + template ::Success> + OneOf(OneOf&& other) { moveFromSubset(other); } + // Copy/move from OneOf that contains a subset of the types we do. + + template >::Success> + OneOf(T&& other): tag(typeIndex>()) { + ctor(*reinterpret_cast*>(space), kj::fwd(other)); + } + // Copy/move from a value that matches one of the individual types in the OneOf. + ~OneOf() { destroy(); } OneOf& operator=(const OneOf& other) { if (tag != 0) destroy(); copyFrom(other); return *this; } @@ -63,15 +396,25 @@ class OneOf { } template - T& get() { + T& get() & { KJ_IREQUIRE(is(), "Must check OneOf::is() before calling get()."); return *reinterpret_cast(space); } template - const T& get() const { + T&& get() && { + KJ_IREQUIRE(is(), "Must check OneOf::is() before calling get()."); + return kj::mv(*reinterpret_cast(space)); + } + template + const T& get() const& { KJ_IREQUIRE(is(), "Must check OneOf::is() before calling get()."); return *reinterpret_cast(space); } + template + const T&& get() const&& { + KJ_IREQUIRE(is(), "Must check OneOf::is() before calling get()."); + return kj::mv(*reinterpret_cast(space)); + } template T& init(Params&&... params) { @@ -81,6 +424,45 @@ class OneOf { return *reinterpret_cast(space); } + template + Maybe tryGet() { + if (is()) { + return *reinterpret_cast(space); + } else { + return nullptr; + } + } + template + Maybe tryGet() const { + if (is()) { + return *reinterpret_cast(space); + } else { + return nullptr; + } + } + + template + KJ_NORETURN(void allHandled()); + // After a series of if/else blocks handling each variant of the OneOf, have the final else + // block call allHandled() where n is the number of variants. This will fail to compile + // if new variants are added in the future. + + typedef _::Variants Tag; + + Tag which() const { + KJ_IREQUIRE(tag != 0, "Can't KJ_SWITCH_ONEOF() on uninitialized value."); + return static_cast(tag - 1); + } + + template + static constexpr Tag tagFor() { + return static_cast(typeIndex() - 1); + } + + OneOf* _switchSubject() & { return this; } + const OneOf* _switchSubject() const& { return this; } + _::NullableValue _switchSubject() && { return kj::mv(*this); } + private: uint tag; @@ -135,6 +517,20 @@ class OneOf { doAll(copyVariantFrom(other)...); } + template + inline bool copyVariantFrom(OneOf& other) { + if (other.is()) { + ctor(*reinterpret_cast(space), other.get()); + } + return false; + } + void copyFrom(OneOf& other) { + // Initialize as a copy of `other`. Expects that `this` starts out uninitialized, so the tag + // is invalid. + tag = other.tag; + doAll(copyVariantFrom(other)...); + } + template inline bool moveVariantFrom(OneOf& other) { if (other.is()) { @@ -148,8 +544,125 @@ class OneOf { tag = other.tag; doAll(moveVariantFrom(other)...); } + + template + inline bool copySubsetVariantFrom(const OneOf& other) { + if (other.template is()) { + tag = typeIndex>(); + ctor(*reinterpret_cast(space), other.template get()); + } + return false; + } + template + void copyFromSubset(const OneOf& other) { + doAll(copySubsetVariantFrom(other)...); + } + + template + inline bool copySubsetVariantFrom(OneOf& other) { + if (other.template is()) { + tag = typeIndex>(); + ctor(*reinterpret_cast(space), other.template get()); + } + return false; + } + template + void copyFromSubset(OneOf& other) { + doAll(copySubsetVariantFrom(other)...); + } + + template + inline bool moveSubsetVariantFrom(OneOf& other) { + if (other.template is()) { + tag = typeIndex>(); + ctor(*reinterpret_cast(space), kj::mv(other.template get())); + } + return false; + } + template + void moveFromSubset(OneOf& other) { + doAll(moveSubsetVariantFrom(other)...); + } }; +template +template +struct OneOf::HasAll + : public HasAll(), Rest...> {}; +template +template +struct OneOf::HasAll: public _::SuccessIfNotZero {}; + +template +template +void OneOf::allHandled() { + // After a series of if/else blocks handling each variant of the OneOf, have the final else + // block call allHandled() where n is the number of variants. This will fail to compile + // if new variants are added in the future. + + static_assert(i == sizeof...(Variants), "new OneOf variants need to be handled here"); + KJ_UNREACHABLE; +} + +#if __cplusplus > 201402L +#define KJ_SWITCH_ONEOF(value) \ + switch (auto _kj_switch_subject = (value)._switchSubject(); _kj_switch_subject->which()) +#else +#define KJ_SWITCH_ONEOF(value) \ + /* Without C++17, we can only support one switch per containing block. Deal with it. */ \ + auto _kj_switch_subject = (value)._switchSubject(); \ + switch (_kj_switch_subject->which()) +#endif +#if !_MSC_VER || defined(__clang__) +#define KJ_CASE_ONEOF(name, ...) \ + break; \ + case ::kj::Decay::template tagFor<__VA_ARGS__>(): \ + for (auto& name = _kj_switch_subject->template get<__VA_ARGS__>(), *_kj_switch_done = &name; \ + _kj_switch_done; _kj_switch_done = nullptr) +#else +// TODO(msvc): The latest MSVC which ships with VS2019 now ICEs on the implementation above. It +// appears we can hack around the problem by moving the `->template get<>()` syntax to an outer +// `if`. (This unfortunately allows wonky syntax like `KJ_CASE_ONEOF(a, B) { } else { }`.) +// https://developercommunity.visualstudio.com/content/problem/1143733/internal-compiler-error-on-v1670.html +#define KJ_CASE_ONEOF(name, ...) \ + break; \ + case ::kj::Decay::template tagFor<__VA_ARGS__>(): \ + if (auto* _kj_switch_done = &_kj_switch_subject->template get<__VA_ARGS__>()) \ + for (auto& name = *_kj_switch_done; _kj_switch_done; _kj_switch_done = nullptr) +#endif +#define KJ_CASE_ONEOF_DEFAULT break; default: +// Allows switching over a OneOf. +// +// Example: +// +// kj::OneOf variant; +// KJ_SWITCH_ONEOF(variant) { +// KJ_CASE_ONEOF(i, int) { +// doSomethingWithInt(i); +// } +// KJ_CASE_ONEOF(s, const char*) { +// doSomethingWithString(s); +// } +// KJ_CASE_ONEOF_DEFAULT { +// doSomethingElse(); +// } +// } +// +// Notes: +// - If you don't handle all possible types and don't include a default branch, you'll get a +// compiler warning, just like a regular switch() over an enum where one of the enum values is +// missing. +// - There's no need for a `break` statement in a KJ_CASE_ONEOF; it is implied. +// - Under C++11 and C++14, only one KJ_SWITCH_ONEOF() can appear in a block. Wrap the switch in +// a pair of braces if you need a second switch in the same block. If C++17 is enabled, this is +// not an issue. +// +// Implementation notes: +// - The use of __VA_ARGS__ is to account for template types that have commas separating type +// parameters, since macros don't recognize <> as grouping. +// - _kj_switch_done is really used as a boolean flag to prevent the for() loop from actually +// looping, but it's defined as a pointer since that's all we can define in this context. + } // namespace kj -#endif // KJ_ONE_OF_H_ +KJ_END_HEADER diff --git a/c++/src/kj/parse/char.c++ b/c++/src/kj/parse/char.c++ index 1c6c77d5e3..4db99142fa 100644 --- a/c++/src/kj/parse/char.c++ +++ b/c++/src/kj/parse/char.c++ @@ -61,7 +61,8 @@ double ParseFloat::operator()(const Array& digits, *pos++ = '\0'; KJ_DASSERT(pos == buf.end()); - return strtod(buf.begin(), nullptr); + // The above construction should always produce a valid double, so this should never throw... + return StringPtr(buf.begin(), bufSize).parseAs(); } } // namespace _ (private) diff --git a/c++/src/kj/parse/char.h b/c++/src/kj/parse/char.h index 2e6d51921d..74e4e6c82b 100644 --- a/c++/src/kj/parse/char.h +++ b/c++/src/kj/parse/char.h @@ -22,17 +22,14 @@ // This file contains parsers useful for character stream inputs, including parsers to parse // common kinds of tokens like identifiers, numbers, and quoted strings. -#ifndef KJ_PARSE_CHAR_H_ -#define KJ_PARSE_CHAR_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "common.h" #include "../string.h" #include +KJ_BEGIN_HEADER + namespace kj { namespace parse { @@ -111,6 +108,13 @@ class CharGroup_ { return (bits[c / 64] & (1ll << (c % 64))) != 0; } + inline bool containsAll(ArrayPtr text) const { + for (char c: text) { + if (!contains(c)) return false; + } + return true; + } + template Maybe operator()(Input& input) const { if (input.atEnd()) return nullptr; @@ -152,7 +156,7 @@ constexpr inline CharGroup_ charRange(char first, char last) { return CharGroup_().orRange(first, last); } -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) #define anyOfChars(chars) CharGroup_().orAny(chars) // TODO(msvc): MSVC ICEs on the proper definition of `anyOfChars()`, which in turn prevents us from // building the compiler or schema parser. We don't know why this happens, but Harris found that @@ -212,6 +216,7 @@ namespace _ { // private struct IdentifierToString { inline String operator()(char first, const Array& rest) const { + if (rest.size() == 0) return heapString(&first, 1); String result = heapString(rest.size() + 1); result[0] = first; memcpy(result.begin() + 1, rest.begin(), rest.size()); @@ -358,4 +363,4 @@ constexpr auto doubleQuotedHexBinary = sequence( } // namespace parse } // namespace kj -#endif // KJ_PARSE_CHAR_H_ +KJ_END_HEADER diff --git a/c++/src/kj/parse/common.h b/c++/src/kj/parse/common.h index fa30d93954..cfb97299ca 100644 --- a/c++/src/kj/parse/common.h +++ b/c++/src/kj/parse/common.h @@ -33,22 +33,24 @@ // will have updated the input cursor to point to the position just past the end of what was parsed. // On failure, the cursor position is unspecified. -#ifndef KJ_PARSE_COMMON_H_ -#define KJ_PARSE_COMMON_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "../common.h" #include "../memory.h" #include "../array.h" #include "../tuple.h" #include "../vector.h" -#if _MSC_VER && !__clang__ + +#if _MSC_VER && _MSC_VER < 1920 && !__clang__ +#define KJ_MSVC_BROKEN_DECLTYPE 1 +#endif + +#if KJ_MSVC_BROKEN_DECLTYPE #include // result_of_t #endif +KJ_BEGIN_HEADER + namespace kj { namespace parse { @@ -66,7 +68,7 @@ class IteratorInput { parent->best = kj::max(kj::max(pos, best), parent->best); } } - KJ_DISALLOW_COPY(IteratorInput); + KJ_DISALLOW_COPY_AND_MOVE(IteratorInput); void advanceParent() { parent->pos = pos; @@ -104,10 +106,9 @@ template struct OutputType_; template struct OutputType_> { typedef T Type; }; template using OutputType = typename OutputType_< -#if _MSC_VER && !__clang__ +#if KJ_MSVC_BROKEN_DECLTYPE std::result_of_t - // The instance() based version below results in: - // C2064: term does not evaluate to a function taking 1 arguments + // The instance() based version below results in many compiler errors on MSVC2017. #else decltype(instance()(instance())) #endif @@ -821,4 +822,4 @@ constexpr EndOfInput_ endOfInput = EndOfInput_(); } // namespace parse } // namespace kj -#endif // KJ_PARSE_COMMON_H_ +KJ_END_HEADER diff --git a/c++/src/kj/refcount-test.c++ b/c++/src/kj/refcount-test.c++ index 81f5e5fb37..c3b39d4577 100644 --- a/c++/src/kj/refcount-test.c++ +++ b/c++/src/kj/refcount-test.c++ @@ -57,4 +57,81 @@ TEST(Refcount, Basic) { #endif } +struct SetTrueInDestructor2 { + // Like above but doesn't inherit Refcounted. + + SetTrueInDestructor2(bool* ptr): ptr(ptr) {} + ~SetTrueInDestructor2() { *ptr = true; } + + bool* ptr; +}; + +KJ_TEST("RefcountedWrapper") { + { + bool b = false; + Own> w = refcountedWrapper(&b); + KJ_EXPECT(!b); + + Own ref1 = w->addWrappedRef(); + Own ref2 = w->addWrappedRef(); + + KJ_EXPECT(ref1.get() == &w->getWrapped()); + KJ_EXPECT(ref1.get() == ref2.get()); + + KJ_EXPECT(!b); + + w = nullptr; + ref1 = nullptr; + + KJ_EXPECT(!b); + + ref2 = nullptr; + + KJ_EXPECT(b); + } + + // Wrap Own. + { + bool b = false; + Own>> w = + refcountedWrapper(kj::heap(&b)); + KJ_EXPECT(!b); + + Own ref1 = w->addWrappedRef(); + Own ref2 = w->addWrappedRef(); + + KJ_EXPECT(ref1.get() == &w->getWrapped()); + KJ_EXPECT(ref1.get() == ref2.get()); + + KJ_EXPECT(!b); + + w = nullptr; + ref1 = nullptr; + + KJ_EXPECT(!b); + + ref2 = nullptr; + + KJ_EXPECT(b); + } + + // Try wrapping an `int` to really demonstrate the wrapped type can be anything. + { + Own> w = refcountedWrapper(123); + int* ptr = &w->getWrapped(); + KJ_EXPECT(*ptr == 123); + + Own ref1 = w->addWrappedRef(); + Own ref2 = w->addWrappedRef(); + + KJ_EXPECT(ref1.get() == ptr); + KJ_EXPECT(ref2.get() == ptr); + + w = nullptr; + ref1 = nullptr; + + KJ_EXPECT(*ref2 == 123); + } +} + } // namespace kj diff --git a/c++/src/kj/refcount.c++ b/c++/src/kj/refcount.c++ index dc1e3808b2..33de86fb72 100644 --- a/c++/src/kj/refcount.c++ +++ b/c++/src/kj/refcount.c++ @@ -21,10 +21,19 @@ #include "refcount.h" #include "debug.h" -#include + +#if _MSC_VER && !defined(__clang__) +// Annoyingly, MSVC only implements the C++ atomic libs, not the C libs, so the only useful +// thing we can get from seems to be atomic_thread_fence... but that one function is +// indeed not implemented by the intrinsics, so... +#include +#endif namespace kj { +// ======================================================================================= +// Non-atomic (thread-unsafe) refcounting + Refcounted::~Refcounted() noexcept(false) { KJ_ASSERT(refcount == 0, "Refcounted object deleted with non-zero refcount."); } @@ -35,4 +44,59 @@ void Refcounted::disposeImpl(void* pointer) const { } } +// ======================================================================================= +// Atomic (thread-safe) refcounting + +AtomicRefcounted::~AtomicRefcounted() noexcept(false) { + KJ_ASSERT(refcount == 0, "Refcounted object deleted with non-zero refcount."); +} + +void AtomicRefcounted::disposeImpl(void* pointer) const { +#if _MSC_VER && !defined(__clang__) + if (KJ_MSVC_INTERLOCKED(Decrement, rel)(&refcount) == 0) { + std::atomic_thread_fence(std::memory_order_acquire); + delete this; + } +#else + if (__atomic_sub_fetch(&refcount, 1, __ATOMIC_RELEASE) == 0) { + __atomic_thread_fence(__ATOMIC_ACQUIRE); + delete this; + } +#endif +} + +bool AtomicRefcounted::addRefWeakInternal() const { +#if _MSC_VER && !defined(__clang__) + long orig = refcount; + + for (;;) { + if (orig == 0) { + // Refcount already hit zero. Destructor is already running so we can't revive the object. + return false; + } + + unsigned long old = KJ_MSVC_INTERLOCKED(CompareExchange, nf)(&refcount, orig + 1, orig); + if (old == orig) { + return true; + } + orig = old; + } +#else + uint orig = __atomic_load_n(&refcount, __ATOMIC_RELAXED); + + for (;;) { + if (orig == 0) { + // Refcount already hit zero. Destructor is already running so we can't revive the object. + return false; + } + + if (__atomic_compare_exchange_n(&refcount, &orig, orig + 1, true, + __ATOMIC_RELAXED, __ATOMIC_RELAXED)) { + // Successfully incremented refcount without letting it hit zero. + return true; + } + } +#endif +} + } // namespace kj diff --git a/c++/src/kj/refcount.h b/c++/src/kj/refcount.h index a24e4bf5b9..03b5234d8d 100644 --- a/c++/src/kj/refcount.h +++ b/c++/src/kj/refcount.h @@ -19,17 +19,25 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#include "memory.h" +#pragma once -#ifndef KJ_REFCOUNT_H_ -#define KJ_REFCOUNT_H_ +#include "memory.h" -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header +#if _MSC_VER +#if _MSC_VER < 1910 +#include +#else +#include +#endif #endif +KJ_BEGIN_HEADER + namespace kj { +// ======================================================================================= +// Non-atomic (thread-unsafe) refcounting + class Refcounted: private Disposer { // Subclass this to create a class that contains a reference count. Then, use // `kj::refcounted()` to allocate a new refcounted pointer. @@ -57,7 +65,9 @@ class Refcounted: private Disposer { // Own could also be nice. public: + Refcounted() = default; virtual ~Refcounted() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(Refcounted); inline bool isShared() const { return refcount > 1; } // Check if there are multiple references to this object. This is sometimes useful for deciding @@ -75,6 +85,9 @@ class Refcounted: private Disposer { friend Own addRef(T& object); template friend Own refcounted(Params&&... params); + + template + friend class RefcountedWrapper; }; template @@ -102,6 +115,169 @@ Own Refcounted::addRefInternal(T* object) { return Own(object, *refcounted); } +template +class RefcountedWrapper: public Refcounted { + // Adds refcounting as a wrapper around an existing type, allowing you to construct references + // with type Own that appears to point directly to the underlying object. + +public: + template + RefcountedWrapper(Params&&... params): wrapped(kj::fwd(params)...) {} + + T& getWrapped() { return wrapped; } + const T& getWrapped() const { return wrapped; } + + Own addWrappedRef() { + // Return an owned reference to the wrapped object that is backed by a refcount. + ++refcount; + return Own(&wrapped, *this); + } + +private: + T wrapped; +}; + +template +class RefcountedWrapper>: public Refcounted { + // Specialization for when the wrapped type is itself Own. We don't want this to result in + // Own>. + +public: + RefcountedWrapper(Own wrapped): wrapped(kj::mv(wrapped)) {} + + T& getWrapped() { return *wrapped; } + const T& getWrapped() const { return *wrapped; } + + Own addWrappedRef() { + // Return an owned reference to the wrapped object that is backed by a refcount. + ++refcount; + return Own(wrapped.get(), *this); + } + +private: + Own wrapped; +}; + +template +Own> refcountedWrapper(Params&&... params) { + return refcounted>(kj::fwd(params)...); +} + +template +Own>> refcountedWrapper(Own&& wrapped) { + return refcounted>>(kj::mv(wrapped)); +} + +// ======================================================================================= +// Atomic (thread-safe) refcounting +// +// Warning: Atomic ops are SLOW. + +#if _MSC_VER && !defined(__clang__) +#if _M_ARM +#define KJ_MSVC_INTERLOCKED(OP, MEM) _Interlocked##OP##_##MEM +#else +#define KJ_MSVC_INTERLOCKED(OP, MEM) _Interlocked##OP +#endif +#endif + +class AtomicRefcounted: private kj::Disposer { +public: + AtomicRefcounted() = default; + virtual ~AtomicRefcounted() noexcept(false); + KJ_DISALLOW_COPY_AND_MOVE(AtomicRefcounted); + + inline bool isShared() const { +#if _MSC_VER && !defined(__clang__) + return KJ_MSVC_INTERLOCKED(Or, acq)(&refcount, 0) > 1; +#else + return __atomic_load_n(&refcount, __ATOMIC_ACQUIRE) > 1; +#endif + } + +private: +#if _MSC_VER && !defined(__clang__) + mutable volatile long refcount = 0; +#else + mutable volatile uint refcount = 0; +#endif + + bool addRefWeakInternal() const; + + void disposeImpl(void* pointer) const override; + template + static kj::Own addRefInternal(T* object); + template + static kj::Own addRefInternal(const T* object); + + template + friend kj::Own atomicAddRef(T& object); + template + friend kj::Own atomicAddRef(const T& object); + template + friend kj::Maybe> atomicAddRefWeak(const T& object); + template + friend kj::Own atomicRefcounted(Params&&... params); +}; + +template +inline kj::Own atomicRefcounted(Params&&... params) { + return AtomicRefcounted::addRefInternal(new T(kj::fwd(params)...)); +} + +template +kj::Own atomicAddRef(T& object) { + KJ_IREQUIRE(object.AtomicRefcounted::refcount > 0, + "Object not allocated with kj::atomicRefcounted()."); + return AtomicRefcounted::addRefInternal(&object); +} + +template +kj::Own atomicAddRef(const T& object) { + KJ_IREQUIRE(object.AtomicRefcounted::refcount > 0, + "Object not allocated with kj::atomicRefcounted()."); + return AtomicRefcounted::addRefInternal(&object); +} + +template +kj::Maybe> atomicAddRefWeak(const T& object) { + // Try to addref an object whose refcount could have already reached zero in another thread, and + // whose destructor could therefore already have started executing. The destructor must contain + // some synchronization that guarantees that said destructor has not yet completed when + // attomicAddRefWeak() is called (so that the object is still valid). Since the destructor cannot + // be canceled once it has started, in the case that it has already started, this function + // returns nullptr. + + const AtomicRefcounted* refcounted = &object; + if (refcounted->addRefWeakInternal()) { + return kj::Own(&object, *refcounted); + } else { + return nullptr; + } +} + +template +kj::Own AtomicRefcounted::addRefInternal(T* object) { + AtomicRefcounted* refcounted = object; +#if _MSC_VER && !defined(__clang__) + KJ_MSVC_INTERLOCKED(Increment, nf)(&refcounted->refcount); +#else + __atomic_add_fetch(&refcounted->refcount, 1, __ATOMIC_RELAXED); +#endif + return kj::Own(object, *refcounted); +} + +template +kj::Own AtomicRefcounted::addRefInternal(const T* object) { + const AtomicRefcounted* refcounted = object; +#if _MSC_VER && !defined(__clang__) + KJ_MSVC_INTERLOCKED(Increment, nf)(&refcounted->refcount); +#else + __atomic_add_fetch(&refcounted->refcount, 1, __ATOMIC_RELAXED); +#endif + return kj::Own(object, *refcounted); +} + } // namespace kj -#endif // KJ_REFCOUNT_H_ +KJ_END_HEADER diff --git a/c++/src/kj/source-location.c++ b/c++/src/kj/source-location.c++ new file mode 100644 index 0000000000..fda19323e3 --- /dev/null +++ b/c++/src/kj/source-location.c++ @@ -0,0 +1,28 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "source-location.h" + +namespace kj { +kj::String KJ_STRINGIFY(const SourceLocation& l) { + return kj::str(l.fileName, ":", l.lineNumber, ":", l.columnNumber, " in ", l.function); +} +} // namespace kj diff --git a/c++/src/kj/source-location.h b/c++/src/kj/source-location.h new file mode 100644 index 0000000000..0c587a38ac --- /dev/null +++ b/c++/src/kj/source-location.h @@ -0,0 +1,111 @@ +// Copyright (c) 2021 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "string.h" + +KJ_BEGIN_HEADER + +// GCC does not implement __builtin_COLUMN() as that's non-standard but MSVC & clang do. +// MSVC does as of version https://github.com/microsoft/STL/issues/54) but there's currently not any +// pressing need for this for MSVC & writing the write compiler version check is annoying. +// Checking for clang version is problematic due to the way that XCode lies about __clang_major__. +// Instead we use __has_builtin as the feature check to check clang. +// Context: https://github.com/capnproto/capnproto/issues/1305 +#ifdef __has_builtin +#if __has_builtin(__builtin_COLUMN) +#define KJ_CALLER_COLUMN() __builtin_COLUMN() +#else +#define KJ_CALLER_COLUMN() 0 +#endif +#else +#define KJ_CALLER_COLUMN() 0 +#endif + +#if __cplusplus > 201703L +#define KJ_COMPILER_SUPPORTS_SOURCE_LOCATION 1 +#elif defined(__has_builtin) +// Clang 9 added these builtins: https://releases.llvm.org/9.0.0/tools/clang/docs/LanguageExtensions.html +// Use __has_builtin as the way to detect this because __clang_major__ is unreliable (see above +// about issue with Xcode-provided clang). +#define KJ_COMPILER_SUPPORTS_SOURCE_LOCATION ( \ + __has_builtin(__builtin_FILE) && \ + __has_builtin(__builtin_LINE) && \ + __has_builtin(__builtin_FUNCTION) \ + ) +#elif __GNUC__ >= 5 +// GCC 5 supports the required builtins: https://gcc.gnu.org/onlinedocs/gcc-5.1.0/gcc/Other-Builtins.html +#define KJ_COMPILER_SUPPORTS_SOURCE_LOCATION 1 +#endif + +namespace kj { +class SourceLocation { + // libc++ doesn't seem to implement (or even ), so + // this is a non-STL wrapper over the compiler primitives (these are the same across MSVC/clang/ + // gcc). Additionally this uses kj::StringPtr for holding the strings instead of const char* which + // makes it integrate a little more nicely into KJ. + + struct Badge { explicit constexpr Badge() = default; }; + // Neat little trick to make sure we can never call SourceLocation with explicit arguments. +public: +#if !KJ_COMPILER_SUPPORTS_SOURCE_LOCATION + constexpr SourceLocation() : fileName("??"), function("??"), lineNumber(0), columnNumber(0) {} + // Constructs a dummy source location that's not pointing at anything. +#else + constexpr SourceLocation(Badge = Badge{}, const char* file = __builtin_FILE(), + const char* func = __builtin_FUNCTION(), uint line = __builtin_LINE(), + uint column = KJ_CALLER_COLUMN()) + : fileName(file), function(func), lineNumber(line), columnNumber(column) + {} +#endif + +#if KJ_COMPILER_SUPPORTS_SOURCE_LOCATION + // This can only be exposed if we actually generate valid SourceLocation objects as otherwise all + // SourceLocation objects would confusingly (and likely problematically) be equated equal. + constexpr bool operator==(const SourceLocation& o) const { + // Pointer equality is fine here based on how SourceLocation operates & how compilers will + // intern all duplicate string constants. + return fileName == o.fileName && function == o.function && lineNumber == o.lineNumber && + columnNumber == o.columnNumber; + } +#endif + + const char* fileName; + const char* function; + uint lineNumber; + uint columnNumber; +}; + +kj::String KJ_STRINGIFY(const SourceLocation& l); + +class NoopSourceLocation { + // This is used in places where we want to conditionally compile out tracking the source location. + // As such it intentionally lacks all the features but the default constructor so that the API + // isn't accidentally used in the wrong compilation context. +}; + +KJ_UNUSED static kj::String KJ_STRINGIFY(const NoopSourceLocation& l) { + return kj::String(); +} +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/std/iostream.h b/c++/src/kj/std/iostream.h index 627e0fcf86..f909caff88 100644 --- a/c++/src/kj/std/iostream.h +++ b/c++/src/kj/std/iostream.h @@ -23,16 +23,13 @@ * Compatibility layer for stdlib iostream */ -#ifndef KJ_STD_IOSTREAM_H_ -#define KJ_STD_IOSTREAM_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "../io.h" #include +KJ_BEGIN_HEADER + namespace kj { namespace std { @@ -85,4 +82,4 @@ class StdInputStream: public kj::InputStream { } // namespace std } // namespace kj -#endif // KJ_STD_IOSTREAM_H_ +KJ_END_HEADER diff --git a/c++/src/kj/string-test.c++ b/c++/src/kj/string-test.c++ index 700bb5b770..620c0585d1 100644 --- a/c++/src/kj/string-test.c++ +++ b/c++/src/kj/string-test.c++ @@ -23,6 +23,8 @@ #include #include #include "vector.h" +#include +#include namespace kj { namespace _ { // private @@ -35,6 +37,12 @@ TEST(String, Str) { EXPECT_EQ("foo", str('f', 'o', 'o')); EXPECT_EQ("123 234 -123 e7", str((int8_t)123, " ", (uint8_t)234, " ", (int8_t)-123, " ", hex((uint8_t)0xe7))); + EXPECT_EQ("-128 -32768 -2147483648 -9223372036854775808", + str((signed char)-128, ' ', (signed short)-32768, ' ', + ((int)-2147483647) - 1, ' ', ((long long)-9223372036854775807ll) - 1)) + EXPECT_EQ("ff ffff ffffffff ffffffffffffffff", + str(hex((uint8_t)0xff), ' ', hex((uint16_t)0xffff), ' ', hex((uint32_t)0xffffffffu), ' ', + hex((uint64_t)0xffffffffffffffffull))); char buf[3] = {'f', 'o', 'o'}; ArrayPtr a = buf; @@ -54,6 +62,12 @@ TEST(String, Str) { EXPECT_EQ("foo", str(mv(f))); } +TEST(String, Nullptr) { + EXPECT_EQ(String(nullptr), ""); + EXPECT_EQ(StringPtr(String(nullptr)).size(), 0u); + EXPECT_EQ(StringPtr(String(nullptr))[0], '\0'); +} + TEST(String, StartsEndsWith) { EXPECT_TRUE(StringPtr("foobar").startsWith("foo")); EXPECT_FALSE(StringPtr("foobar").startsWith("bar")); @@ -151,6 +165,91 @@ TEST(String, parseAs) { EXPECT_EQ(heapString("1").parseAs(), 1); } +TEST(String, tryParseAs) { + KJ_EXPECT(StringPtr("0").tryParseAs() == 0.0); + KJ_EXPECT(StringPtr("0").tryParseAs() == 0.0); + KJ_EXPECT(StringPtr("0.0").tryParseAs() == 0.0); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1.0); + KJ_EXPECT(StringPtr("1.0").tryParseAs() == 1.0); + KJ_EXPECT(StringPtr("1e100").tryParseAs() == 1e100); + KJ_EXPECT(StringPtr("inf").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("infinity").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("INF").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("INFINITY").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("1e100000").tryParseAs() == inf()); + KJ_EXPECT(StringPtr("-inf").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-infinity").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-INF").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-INFINITY").tryParseAs() == -inf()); + KJ_EXPECT(StringPtr("-1e100000").tryParseAs() == -inf()); + KJ_EXPECT(isNaN(StringPtr("nan").tryParseAs().orDefault(0.0)) == true); + KJ_EXPECT(isNaN(StringPtr("NAN").tryParseAs().orDefault(0.0)) == true); + KJ_EXPECT(isNaN(StringPtr("NaN").tryParseAs().orDefault(0.0)) == true); + KJ_EXPECT(StringPtr("").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("1a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("+-1").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1.0); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("9223372036854775807").tryParseAs() == 9223372036854775807LL); + KJ_EXPECT(StringPtr("-9223372036854775808").tryParseAs() == -9223372036854775808ULL); + KJ_EXPECT(StringPtr("9223372036854775808").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("-9223372036854775809").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("1a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("+-1").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("010").tryParseAs() == 10); + KJ_EXPECT(StringPtr("0010").tryParseAs() == 10); + KJ_EXPECT(StringPtr("0x10").tryParseAs() == 16); + KJ_EXPECT(StringPtr("0X10").tryParseAs() == 16); + KJ_EXPECT(StringPtr("-010").tryParseAs() == -10); + KJ_EXPECT(StringPtr("-0x10").tryParseAs() == -16); + KJ_EXPECT(StringPtr("-0X10").tryParseAs() == -16); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("0").tryParseAs() == 0); + KJ_EXPECT(StringPtr("18446744073709551615").tryParseAs() == 18446744073709551615ULL); + KJ_EXPECT(StringPtr("-1").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("18446744073709551616").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("1a").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("+-1").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("2147483647").tryParseAs() == 2147483647); + KJ_EXPECT(StringPtr("-2147483648").tryParseAs() == -2147483648); + KJ_EXPECT(StringPtr("2147483648").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("-2147483649").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("0").tryParseAs() == 0U); + KJ_EXPECT(StringPtr("4294967295").tryParseAs() == 4294967295U); + KJ_EXPECT(StringPtr("-1").tryParseAs() == nullptr); + KJ_EXPECT(StringPtr("4294967296").tryParseAs() == nullptr); + + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + KJ_EXPECT(StringPtr("1").tryParseAs() == 1); + + KJ_EXPECT(heapString("1").tryParseAs() == 1); +} + #if KJ_COMPILER_SUPPORTS_STL_STRING_INTEROP TEST(String, StlInterop) { std::string foo = "foo"; @@ -253,6 +352,39 @@ KJ_TEST("String == String") { KJ_EXPECT_NOMAGIC(a != c); } +KJ_TEST("float stringification and parsing is not locale-dependent") { + // Remember the old locale, set it back when we're done. + char* oldLocaleCstr = setlocale(LC_NUMERIC, nullptr); + KJ_ASSERT(oldLocaleCstr != nullptr); + auto oldLocale = kj::str(oldLocaleCstr); + KJ_DEFER(setlocale(LC_NUMERIC, oldLocale.cStr())); + + // Set the locale to "C". + KJ_ASSERT(setlocale(LC_NUMERIC, "C") != nullptr); + + KJ_ASSERT(kj::str(1.5) == "1.5"); + KJ_ASSERT(kj::str(1.5f) == "1.5"); + KJ_EXPECT("1.5"_kj.parseAs() == 1.5); + KJ_EXPECT("1.5"_kj.parseAs() == 1.5); + + if (setlocale(LC_NUMERIC, "es_ES") == nullptr && + setlocale(LC_NUMERIC, "es_ES.utf8") == nullptr && + setlocale(LC_NUMERIC, "es_ES.UTF-8") == nullptr) { + // Some systems may not have the desired locale available. + KJ_LOG(WARNING, "Couldn't set locale to es_ES. Skipping this test."); + } else { + KJ_EXPECT(kj::str(1.5) == "1.5"); + KJ_EXPECT(kj::str(1.5f) == "1.5"); + KJ_EXPECT("1.5"_kj.parseAs() == 1.5); + KJ_EXPECT("1.5"_kj.parseAs() == 1.5); + } +} + +KJ_TEST("ConstString") { + kj::ConstString theString = "it's a const string!"_kjc; + KJ_EXPECT(theString == "it's a const string!"); +} + } // namespace } // namespace _ (private) } // namespace kj diff --git a/c++/src/kj/string-tree.c++ b/c++/src/kj/string-tree.c++ index c4f23977cb..94d2055f60 100644 --- a/c++/src/kj/string-tree.c++ +++ b/c++/src/kj/string-tree.c++ @@ -53,11 +53,21 @@ String StringTree::flatten() const { return result; } -void StringTree::flattenTo(char* __restrict__ target) const { +char* StringTree::flattenTo(char* __restrict__ target) const { visit([&target](ArrayPtr text) { memcpy(target, text.begin(), text.size()); target += text.size(); }); + return target; +} + +char* StringTree::flattenTo(char* __restrict__ target, char* limit) const { + visit([&target,limit](ArrayPtr text) { + size_t size = kj::min(text.size(), limit - target); + memcpy(target, text.begin(), size); + target += size; + }); + return target; } } // namespace kj diff --git a/c++/src/kj/string-tree.h b/c++/src/kj/string-tree.h index 70a46319ef..19281bdc18 100644 --- a/c++/src/kj/string-tree.h +++ b/c++/src/kj/string-tree.h @@ -19,15 +19,12 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_STRING_TREE_H_ -#define KJ_STRING_TREE_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "string.h" +KJ_BEGIN_HEADER + namespace kj { class StringTree { @@ -63,8 +60,10 @@ class StringTree { // TODO(someday): flatten() when *this is an rvalue and when branches.size() == 0 could simply // return `kj::mv(text)`. Requires reference qualifiers (Clang 3.3 / GCC 4.8). - void flattenTo(char* __restrict__ target) const; - // Copy the contents to the given character array. Does not add a NUL terminator. + char* flattenTo(char* __restrict__ target) const; + char* flattenTo(char* __restrict__ target, char* limit) const; + // Copy the contents to the given character array. Does not add a NUL terminator. Returns a + // pointer just past the end of what was filled. private: size_t size_; @@ -125,6 +124,14 @@ char* fill(char* __restrict__ target, const StringTree& first, Rest&&... rest) { return fill(target + first.size(), kj::fwd(rest)...); } +template +char* fillLimited(char* __restrict__ target, char* limit, const StringTree& first, Rest&&... rest) { + // Make str() work with stringifiers that return StringTree by patching fill(). + + target = first.flattenTo(target, limit); + return fillLimited(target + first.size(), limit, kj::fwd(rest)...); +} + template constexpr bool isStringTree() { return false; } template <> constexpr bool isStringTree() { return true; } @@ -209,4 +216,4 @@ StringTree strTree(Params&&... params) { } // namespace kj -#endif // KJ_STRING_TREE_H_ +KJ_END_HEADER diff --git a/c++/src/kj/string.c++ b/c++/src/kj/string.c++ index dd8be9e937..cf2c5fcaa9 100644 --- a/c++/src/kj/string.c++ +++ b/c++/src/kj/string.c++ @@ -25,14 +25,10 @@ #include #include #include +#include namespace kj { -#if _MSC_VER -#pragma warning(disable: 4996) -// Warns that sprintf() is buffer-overrunny. We know that, it's cool. -#endif - namespace { bool isHex(const char *s) { if (*s == '-') s++; @@ -50,6 +46,17 @@ long long parseSigned(const StringPtr& s, long long min, long long max) { return value; } +Maybe tryParseSigned(const StringPtr& s, long long min, long long max) { + if (s == nullptr) { return nullptr; } // String does not contain valid number. + char *endPtr; + errno = 0; + auto value = strtoll(s.begin(), &endPtr, isHex(s.cStr()) ? 16 : 10); + if (endPtr != s.end() || errno == ERANGE || value < min || max < value) { + return nullptr; + } + return value; +} + unsigned long long parseUnsigned(const StringPtr& s, unsigned long long max) { KJ_REQUIRE(s != nullptr, "String does not contain valid number", s) { return 0; } char *endPtr; @@ -63,6 +70,15 @@ unsigned long long parseUnsigned(const StringPtr& s, unsigned long long max) { return value; } +Maybe tryParseUnsigned(const StringPtr& s, unsigned long long max) { + if (s == nullptr) { return nullptr; } // String does not contain valid number. + char *endPtr; + errno = 0; + auto value = strtoull(s.begin(), &endPtr, isHex(s.cStr()) ? 16 : 10); + if (endPtr != s.end() || errno == ERANGE || max < value || s[0] == '-') { return nullptr; } + return value; +} + template T parseInteger(const StringPtr& s) { if (static_cast(minValue) < 0) { @@ -75,13 +91,16 @@ T parseInteger(const StringPtr& s) { } } -double parseDouble(const StringPtr& s) { - KJ_REQUIRE(s != nullptr, "String does not contain valid number", s) { return 0; } - char *endPtr; - errno = 0; - auto value = strtod(s.begin(), &endPtr); - KJ_REQUIRE(endPtr == s.end(), "String does not contain valid floating number", s) { return 0; } - return value; +template +Maybe tryParseInteger(const StringPtr& s) { + if (static_cast(minValue) < 0) { + long long min = static_cast(minValue); + long long max = static_cast(maxValue); + return static_cast>(tryParseSigned(s, min, max)); + } else { + unsigned long long max = static_cast(maxValue); + return static_cast>(tryParseUnsigned(s, max)); + } } } // namespace @@ -100,8 +119,21 @@ PARSE_AS_INTEGER(unsigned long); PARSE_AS_INTEGER(long long); PARSE_AS_INTEGER(unsigned long long); #undef PARSE_AS_INTEGER -template <> double StringPtr::parseAs() const { return parseDouble(*this); } -template <> float StringPtr::parseAs() const { return parseDouble(*this); } + +#define TRY_PARSE_AS_INTEGER(T) \ + template <> Maybe StringPtr::tryParseAs() const { return tryParseInteger(*this); } +TRY_PARSE_AS_INTEGER(char); +TRY_PARSE_AS_INTEGER(signed char); +TRY_PARSE_AS_INTEGER(unsigned char); +TRY_PARSE_AS_INTEGER(short); +TRY_PARSE_AS_INTEGER(unsigned short); +TRY_PARSE_AS_INTEGER(int); +TRY_PARSE_AS_INTEGER(unsigned int); +TRY_PARSE_AS_INTEGER(long); +TRY_PARSE_AS_INTEGER(unsigned long); +TRY_PARSE_AS_INTEGER(long long); +TRY_PARSE_AS_INTEGER(unsigned long long); +#undef TRY_PARSE_AS_INTEGER String heapString(size_t size) { char* buffer = _::HeapArrayDisposer::allocate(size + 1); @@ -111,23 +143,46 @@ String heapString(size_t size) { String heapString(const char* value, size_t size) { char* buffer = _::HeapArrayDisposer::allocate(size + 1); - memcpy(buffer, value, size); + if (size != 0u) { + memcpy(buffer, value, size); + } buffer[size] = '\0'; return String(buffer, size, _::HeapArrayDisposer::instance); } -#define HEXIFY_INT(type, format) \ +template +static CappedArray hexImpl(T i) { + // We don't use sprintf() because it's not async-signal-safe (for strPreallocated()). + CappedArray result; + uint8_t reverse[sizeof(T) * 2]; + uint8_t* p = reverse; + if (i == 0) { + *p++ = 0; + } else { + while (i > 0) { + *p++ = i % 16; + i /= 16; + } + } + + char* p2 = result.begin(); + while (p > reverse) { + *p2++ = "0123456789abcdef"[*--p]; + } + result.setSize(p2 - result.begin()); + return result; +} + +#define HEXIFY_INT(type) \ CappedArray hex(type i) { \ - CappedArray result; \ - result.setSize(sprintf(result.begin(), format, i)); \ - return result; \ + return hexImpl(i); \ } -HEXIFY_INT(unsigned char, "%x"); -HEXIFY_INT(unsigned short, "%x"); -HEXIFY_INT(unsigned int, "%x"); -HEXIFY_INT(unsigned long, "%lx"); -HEXIFY_INT(unsigned long long, "%llx"); +HEXIFY_INT(unsigned char); +HEXIFY_INT(unsigned short); +HEXIFY_INT(unsigned int); +HEXIFY_INT(unsigned long); +HEXIFY_INT(unsigned long long); #undef HEXIFY_INT @@ -141,27 +196,58 @@ StringPtr Stringifier::operator*(bool b) const { return b ? StringPtr("true") : StringPtr("false"); } -#define STRINGIFY_INT(type, format) \ +template +static CappedArray stringifyImpl(T i) { + // We don't use sprintf() because it's not async-signal-safe (for strPreallocated()). + CappedArray result; + bool negative = i < 0; + // Note that if `i` is the most-negative value, negating it produces the same bit value. But + // since it's a signed integer, this is considered an overflow. We therefore must make it + // unsigned first, then negate it, to avoid ubsan complaining. + Unsigned u = i; + if (negative) u = -u; + uint8_t reverse[sizeof(T) * 3 + 1]; + uint8_t* p = reverse; + if (u == 0) { + *p++ = 0; + } else { + while (u > 0) { + *p++ = u % 10; + u /= 10; + } + } + + char* p2 = result.begin(); + if (negative) *p2++ = '-'; + while (p > reverse) { + *p2++ = '0' + *--p; + } + result.setSize(p2 - result.begin()); + return result; +} + +#define STRINGIFY_INT(type, unsigned) \ CappedArray Stringifier::operator*(type i) const { \ - CappedArray result; \ - result.setSize(sprintf(result.begin(), format, i)); \ - return result; \ + return stringifyImpl(i); \ } -STRINGIFY_INT(signed char, "%d"); -STRINGIFY_INT(unsigned char, "%u"); -STRINGIFY_INT(short, "%d"); -STRINGIFY_INT(unsigned short, "%u"); -STRINGIFY_INT(int, "%d"); -STRINGIFY_INT(unsigned int, "%u"); -STRINGIFY_INT(long, "%ld"); -STRINGIFY_INT(unsigned long, "%lu"); -STRINGIFY_INT(long long, "%lld"); -STRINGIFY_INT(unsigned long long, "%llu"); -STRINGIFY_INT(const void*, "%p"); +STRINGIFY_INT(signed char, uint); +STRINGIFY_INT(unsigned char, uint); +STRINGIFY_INT(short, uint); +STRINGIFY_INT(unsigned short, uint); +STRINGIFY_INT(int, uint); +STRINGIFY_INT(unsigned int, uint); +STRINGIFY_INT(long, unsigned long); +STRINGIFY_INT(unsigned long, unsigned long); +STRINGIFY_INT(long long, unsigned long long); +STRINGIFY_INT(unsigned long long, unsigned long long); #undef STRINGIFY_INT +CappedArray Stringifier::operator*(const void* i) const { \ + return hexImpl(reinterpret_cast(i)); +} + namespace { // ---------------------------------------------------------------------- @@ -417,6 +503,76 @@ char* FloatToBuffer(float value, char* buffer) { return buffer; } +// ---------------------------------------------------------------------- +// NoLocaleStrtod() +// This code will make you cry. +// ---------------------------------------------------------------------- + +namespace { + +// Returns a string identical to *input except that the character pointed to +// by radix_pos (which should be '.') is replaced with the locale-specific +// radix character. +kj::String LocalizeRadix(const char* input, const char* radix_pos) { + // Determine the locale-specific radix character by calling sprintf() to + // print the number 1.5, then stripping off the digits. As far as I can + // tell, this is the only portable, thread-safe way to get the C library + // to divuldge the locale's radix character. No, localeconv() is NOT + // thread-safe. + char temp[16]; + int size = snprintf(temp, sizeof(temp), "%.1f", 1.5); + KJ_ASSERT(temp[0] == '1'); + KJ_ASSERT(temp[size-1] == '5'); + KJ_ASSERT(size <= 6); + + // Now replace the '.' in the input with it. + return kj::str( + kj::arrayPtr(input, radix_pos), + kj::arrayPtr(temp + 1, size - 2), + kj::StringPtr(radix_pos + 1)); +} + +} // namespace + +double NoLocaleStrtod(const char* text, char** original_endptr) { + // We cannot simply set the locale to "C" temporarily with setlocale() + // as this is not thread-safe. Instead, we try to parse in the current + // locale first. If parsing stops at a '.' character, then this is a + // pretty good hint that we're actually in some other locale in which + // '.' is not the radix character. + + char* temp_endptr; + double result = strtod(text, &temp_endptr); + if (original_endptr != NULL) *original_endptr = temp_endptr; + if (*temp_endptr != '.') return result; + + // Parsing halted on a '.'. Perhaps we're in a different locale? Let's + // try to replace the '.' with a locale-specific radix character and + // try again. + kj::String localized = LocalizeRadix(text, temp_endptr); + const char* localized_cstr = localized.cStr(); + char* localized_endptr; + result = strtod(localized_cstr, &localized_endptr); + if ((localized_endptr - localized_cstr) > + (temp_endptr - text)) { + // This attempt got further, so replacing the decimal must have helped. + // Update original_endptr to point at the right location. + if (original_endptr != NULL) { + // size_diff is non-zero if the localized radix has multiple bytes. + int size_diff = localized.size() - strlen(text); + // const_cast is necessary to match the strtod() interface. + *original_endptr = const_cast( + text + (localized_endptr - localized_cstr - size_diff)); + } + } + + return result; +} + +// ---------------------------------------------------------------------- +// End of code copied from Protobuf +// ---------------------------------------------------------------------- + } // namespace CappedArray Stringifier::operator*(float f) const { @@ -431,5 +587,52 @@ CappedArray Stringifier::operator*(double f) const { return result; } +double parseDouble(const StringPtr& s) { + KJ_REQUIRE(s != nullptr, "String does not contain valid number", s) { return 0; } + char *endPtr; + errno = 0; + auto value = _::NoLocaleStrtod(s.begin(), &endPtr); + KJ_REQUIRE(endPtr == s.end(), "String does not contain valid floating number", s) { return 0; } +#if _WIN32 || __CYGWIN__ || __BIONIC__ + // When Windows' strtod() parses "nan", it returns a value with the sign bit set. But, our + // preferred canonical value for NaN does not have the sign bit set, and all other platforms + // return one without the sign bit set. So, on Windows, detect NaN and return our preferred + // version. + // + // Cygwin seemingly does not try to emulate Linux behavior here, but rather allows Windows' + // behavior to leak through. (Conversely, WINE actually produces the Linux behavior despite + // trying to behave like Win32...) + // + // Bionic (Android) failed the unit test and so I added it to the list without investigating + // further. + if (isNaN(value)) { + // NaN + return kj::nan(); + } +#endif + return value; +} + +Maybe tryParseDouble(const StringPtr& s) { + if(s == nullptr) { return nullptr; } + char *endPtr; + errno = 0; + auto value = _::NoLocaleStrtod(s.begin(), &endPtr); + if (endPtr != s.end()) { return nullptr; } +#if _WIN32 || __CYGWIN__ || __BIONIC__ + if (isNaN(value)) { + return kj::nan(); + } +#endif + return value; +} + } // namespace _ (private) + +template <> double StringPtr::parseAs() const { return _::parseDouble(*this); } +template <> float StringPtr::parseAs() const { return _::parseDouble(*this); } + +template <> Maybe StringPtr::tryParseAs() const { return _::tryParseDouble(*this); } +template <> Maybe StringPtr::tryParseAs() const { return _::tryParseDouble(*this); } + } // namespace kj diff --git a/c++/src/kj/string.h b/c++/src/kj/string.h index 038604badf..100c7deda5 100644 --- a/c++/src/kj/string.h +++ b/c++/src/kj/string.h @@ -19,23 +19,43 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_STRING_H_ -#define KJ_STRING_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include #include "array.h" +#include "kj/common.h" #include +KJ_BEGIN_HEADER + namespace kj { + class StringPtr; + class LiteralStringConst; + class String; + class ConstString; + + class StringTree; // string-tree.h +} + +constexpr kj::StringPtr operator "" _kj(const char* str, size_t n); +// You can append _kj to a string literal to make its type be StringPtr. There are a few cases +// where you must do this for correctness: +// - When you want to declare a constexpr StringPtr. Without _kj, this is a compile error. +// - When you want to initialize a static/global StringPtr from a string literal without forcing +// global constructor code to run at dynamic initialization time. +// - When you have a string literal that contains NUL characters. Without _kj, the string will +// be considered to end at the first NUL. +// - When you want to initialize an ArrayPtr from a string literal, without including +// the NUL terminator in the data. (Initializing an ArrayPtr from a regular string literal is +// a compile error specifically due to this ambiguity.) +// +// In other cases, there should be no difference between initializing a StringPtr from a regular +// string literal vs. one with _kj (assuming the compiler is able to optimize away strlen() on a +// string literal). -class StringPtr; -class String; +constexpr kj::LiteralStringConst operator "" _kjc(const char* str, size_t n); -class StringTree; // string-tree.h +namespace kj { // Our STL string SFINAE trick does not work with GCC 4.7, but it works with Clang and GCC 4.8, so // we'll just preprocess it out if not supported. @@ -54,12 +74,28 @@ class StringPtr { public: inline StringPtr(): content("", 1) {} inline StringPtr(decltype(nullptr)): content("", 1) {} - inline StringPtr(const char* value): content(value, strlen(value) + 1) {} - inline StringPtr(const char* value, size_t size): content(value, size + 1) { + inline StringPtr(const char* value KJ_LIFETIMEBOUND): content(value, strlen(value) + 1) {} + inline StringPtr(const char* value KJ_LIFETIMEBOUND, size_t size): content(value, size + 1) { KJ_IREQUIRE(value[size] == '\0', "StringPtr must be NUL-terminated."); } - inline StringPtr(const char* begin, const char* end): StringPtr(begin, end - begin) {} - inline StringPtr(const String& value); + inline StringPtr(const char* begin KJ_LIFETIMEBOUND, const char* end KJ_LIFETIMEBOUND): StringPtr(begin, end - begin) {} + inline StringPtr(String&& value KJ_LIFETIMEBOUND) : StringPtr(value) {} + inline StringPtr(const String& value KJ_LIFETIMEBOUND); + inline StringPtr(const ConstString& value KJ_LIFETIMEBOUND); + StringPtr& operator=(String&& value) = delete; + inline StringPtr& operator=(decltype(nullptr)) { + content = ArrayPtr("", 1); + return *this; + } + +#if __cpp_char8_t + inline StringPtr(const char8_t* value KJ_LIFETIMEBOUND): StringPtr(reinterpret_cast(value)) {} + inline StringPtr(const char8_t* value KJ_LIFETIMEBOUND, size_t size) + : StringPtr(reinterpret_cast(value), size) {} + inline StringPtr(const char8_t* begin KJ_LIFETIMEBOUND, const char8_t* end KJ_LIFETIMEBOUND) + : StringPtr(reinterpret_cast(begin), reinterpret_cast(end)) {} + // KJ strings are and always have been UTF-8, so screw this C++20 char8_t stuff. +#endif #if __cplusplus >= 202000L inline StringPtr(const char8_t* value): StringPtr(reinterpret_cast(value)) {} @@ -71,21 +107,26 @@ class StringPtr { #endif #if KJ_COMPILER_SUPPORTS_STL_STRING_INTEROP - template ().c_str())> - inline StringPtr(const T& t): StringPtr(t.c_str()) {} - // Allow implicit conversion from any class that has a c_str() method (namely, std::string). + template < + typename T, + typename = decltype(instance().c_str()), + typename = decltype(instance().size())> + inline StringPtr(const T& t KJ_LIFETIMEBOUND): StringPtr(t.c_str(), t.size()) {} + // Allow implicit conversion from any class that has a c_str() and a size() method (namely, std::string). // We use a template trick to detect std::string in order to avoid including the header for // those who don't want it. - - template ().c_str())> - inline operator T() const { return cStr(); } - // Allow implicit conversion to any class that has a c_str() method (namely, std::string). + template < + typename T, + typename = decltype(instance().c_str()), + typename = decltype(instance().size())> + inline operator T() const { return {cStr(), size()}; } + // Allow implicit conversion to any class that has a c_str() method and a size() method (namely, std::string). // We use a template trick to detect std::string in order to avoid including the header for // those who don't want it. #endif - inline operator ArrayPtr() const; - inline ArrayPtr asArray() const; + inline constexpr operator ArrayPtr() const; + inline constexpr ArrayPtr asArray() const; inline ArrayPtr asBytes() const { return asArray().asBytes(); } // Result does not include NUL terminator. @@ -97,14 +138,18 @@ class StringPtr { inline char operator[](size_t index) const { return content[index]; } - inline const char* begin() const { return content.begin(); } - inline const char* end() const { return content.end() - 1; } + inline constexpr const char* begin() const { return content.begin(); } + inline constexpr const char* end() const { return content.end() - 1; } - inline bool operator==(decltype(nullptr)) const { return content.size() <= 1; } - inline bool operator!=(decltype(nullptr)) const { return content.size() > 1; } + inline constexpr bool operator==(decltype(nullptr)) const { return content.size() <= 1; } +#if !__cpp_impl_three_way_comparison + inline constexpr bool operator!=(decltype(nullptr)) const { return content.size() > 1; } +#endif inline bool operator==(const StringPtr& other) const; +#if !__cpp_impl_three_way_comparison inline bool operator!=(const StringPtr& other) const { return !(*this == other); } +#endif inline bool operator< (const StringPtr& other) const; inline bool operator> (const StringPtr& other) const { return other < *this; } inline bool operator<=(const StringPtr& other) const { return !(other < *this); } @@ -115,11 +160,11 @@ class StringPtr { // A string slice is only NUL-terminated if it is a suffix, so slice() has a one-parameter // version that assumes end = size(). - inline bool startsWith(const StringPtr& other) const; - inline bool endsWith(const StringPtr& other) const; + inline bool startsWith(const StringPtr& other) const { return asArray().startsWith(other);} + inline bool endsWith(const StringPtr& other) const { return asArray().endsWith(other); } - inline Maybe findFirst(char c) const; - inline Maybe findLast(char c) const; + inline Maybe findFirst(char c) const { return asArray().findFirst(c); } + inline Maybe findLast(char c) const { return asArray().findLast(c); } template T parseAs() const; @@ -128,14 +173,26 @@ class StringPtr { // Integer numbers prefixed by "0" are parsed in base 10 (unlike strtoi with base 0). // Overflowed integer numbers throw exception. // Overflowed floating numbers return inf. + template + Maybe tryParseAs() const; + // Same as parseAs, but rather than throwing an exception we return NULL. + + template + ConstString attach(Attachments&&... attachments) const KJ_WARN_UNUSED_RESULT; + ConstString attach() const KJ_WARN_UNUSED_RESULT; + // Like ArrayPtr::attach(), but instead promotes a StringPtr into a ConstString. Generally the + // attachment should be an object that somehow owns the String that the StringPtr is pointing at. private: inline explicit constexpr StringPtr(ArrayPtr content): content(content) {} + friend constexpr StringPtr (::operator "" _kj)(const char* str, size_t n); + friend class LiteralStringConst; ArrayPtr content; + friend class SourceLocation; }; -#if __cplusplus < 202000L +#if !__cpp_impl_three_way_comparison inline bool operator==(const char* a, const StringPtr& b) { return b == a; } inline bool operator!=(const char* a, const StringPtr& b) { return b != a; } #endif @@ -154,6 +211,29 @@ template <> unsigned long long StringPtr::parseAs() const; template <> float StringPtr::parseAs() const; template <> double StringPtr::parseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; +template <> Maybe StringPtr::tryParseAs() const; + +class LiteralStringConst: public StringPtr { +public: + inline operator ConstString() const; + +private: + inline explicit constexpr LiteralStringConst(ArrayPtr content): StringPtr(content) {} + friend constexpr LiteralStringConst (::operator "" _kjc)(const char* str, size_t n); +}; + // ======================================================================================= // String -- A NUL-terminated Array containing UTF-8 text. // @@ -173,43 +253,52 @@ class String { inline explicit String(Array buffer); // Does not copy. Requires `buffer` ends with `\0`. - inline operator ArrayPtr(); - inline operator ArrayPtr() const; - inline ArrayPtr asArray(); - inline ArrayPtr asArray() const; - inline ArrayPtr asBytes() { return asArray().asBytes(); } - inline ArrayPtr asBytes() const { return asArray().asBytes(); } + inline operator ArrayPtr() KJ_LIFETIMEBOUND; + inline operator ArrayPtr() const KJ_LIFETIMEBOUND; + inline ArrayPtr asArray() KJ_LIFETIMEBOUND; + inline ArrayPtr asArray() const KJ_LIFETIMEBOUND; + inline ArrayPtr asBytes() KJ_LIFETIMEBOUND { return asArray().asBytes(); } + inline ArrayPtr asBytes() const KJ_LIFETIMEBOUND { return asArray().asBytes(); } // Result does not include NUL terminator. + inline StringPtr asPtr() const KJ_LIFETIMEBOUND { + // Convenience operator to return a StringPtr. + return StringPtr{*this}; + } + inline Array releaseArray() { return kj::mv(content); } // Disowns the backing array (which includes the NUL terminator) and returns it. The String value // is clobbered (as if moved away). - inline const char* cStr() const; + inline const char* cStr() const KJ_LIFETIMEBOUND; inline size_t size() const; // Result does not include NUL terminator. inline char operator[](size_t index) const; - inline char& operator[](size_t index); + inline char& operator[](size_t index) KJ_LIFETIMEBOUND; - inline char* begin(); - inline char* end(); - inline const char* begin() const; - inline const char* end() const; + inline char* begin() KJ_LIFETIMEBOUND; + inline char* end() KJ_LIFETIMEBOUND; + inline const char* begin() const KJ_LIFETIMEBOUND; + inline const char* end() const KJ_LIFETIMEBOUND; inline bool operator==(decltype(nullptr)) const { return content.size() <= 1; } inline bool operator!=(decltype(nullptr)) const { return content.size() > 1; } inline bool operator==(const StringPtr& other) const { return StringPtr(*this) == other; } +#if !__cpp_impl_three_way_comparison inline bool operator!=(const StringPtr& other) const { return StringPtr(*this) != other; } +#endif inline bool operator< (const StringPtr& other) const { return StringPtr(*this) < other; } inline bool operator> (const StringPtr& other) const { return StringPtr(*this) > other; } inline bool operator<=(const StringPtr& other) const { return StringPtr(*this) <= other; } inline bool operator>=(const StringPtr& other) const { return StringPtr(*this) >= other; } inline bool operator==(const String& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison inline bool operator!=(const String& other) const { return StringPtr(*this) != StringPtr(other); } +#endif inline bool operator< (const String& other) const { return StringPtr(*this) < StringPtr(other); } inline bool operator> (const String& other) const { return StringPtr(*this) > StringPtr(other); } inline bool operator<=(const String& other) const { return StringPtr(*this) <= StringPtr(other); } @@ -218,26 +307,140 @@ class String { // comparisons between two strings are ambiguous. (Clang turns this into a warning, // -Wambiguous-reversed-operator, due to the stupidity...) - inline bool startsWith(const StringPtr& other) const { return StringPtr(*this).startsWith(other);} - inline bool endsWith(const StringPtr& other) const { return StringPtr(*this).endsWith(other); } + inline bool operator==(const ConstString& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const ConstString& other) const { return StringPtr(*this) != StringPtr(other); } +#endif + inline bool operator< (const ConstString& other) const { return StringPtr(*this) < StringPtr(other); } + inline bool operator> (const ConstString& other) const { return StringPtr(*this) > StringPtr(other); } + inline bool operator<=(const ConstString& other) const { return StringPtr(*this) <= StringPtr(other); } + inline bool operator>=(const ConstString& other) const { return StringPtr(*this) >= StringPtr(other); } - inline StringPtr slice(size_t start) const { return StringPtr(*this).slice(start); } - inline ArrayPtr slice(size_t start, size_t end) const { + inline bool startsWith(const StringPtr& other) const { return asArray().startsWith(other);} + inline bool endsWith(const StringPtr& other) const { return asArray().endsWith(other); } + + inline StringPtr slice(size_t start) const KJ_LIFETIMEBOUND { + return StringPtr(*this).slice(start); + } + inline ArrayPtr slice(size_t start, size_t end) const KJ_LIFETIMEBOUND { return StringPtr(*this).slice(start, end); } - inline Maybe findFirst(char c) const { return StringPtr(*this).findFirst(c); } - inline Maybe findLast(char c) const { return StringPtr(*this).findLast(c); } + inline Maybe findFirst(char c) const { return asArray().findFirst(c); } + inline Maybe findLast(char c) const { return asArray().findLast(c); } template T parseAs() const { return StringPtr(*this).parseAs(); } // Parse as number + template + Maybe tryParseAs() const { return StringPtr(*this).tryParseAs(); } + private: Array content; }; -#if __cplusplus < 202000L +// ======================================================================================= +// ConstString -- Same as String, but the backing buffer is const. +// +// This has the useful property that it can reference a string literal without allocating +// a copy. Any String can also convert (by move) to ConstString, transferring ownership of +// the buffer. + +class ConstString { +public: + ConstString() = default; + inline ConstString(decltype(nullptr)): content(nullptr) {} + inline ConstString(const char* value, size_t size, const ArrayDisposer& disposer); + // Does not copy. `size` does not include NUL terminator, but `value` must be NUL-terminated. + inline explicit ConstString(Array buffer); + // Does not copy. Requires `buffer` ends with `\0`. + inline explicit ConstString(String&& string): content(string.releaseArray()) {} + // Does not copy. Ownership is transfered. + + inline operator ArrayPtr() const KJ_LIFETIMEBOUND; + inline ArrayPtr asArray() const KJ_LIFETIMEBOUND; + inline ArrayPtr asBytes() const KJ_LIFETIMEBOUND { return asArray().asBytes(); } + // Result does not include NUL terminator. + + inline StringPtr asPtr() const KJ_LIFETIMEBOUND { + // Convenience operator to return a StringPtr. + return StringPtr{*this}; + } + + inline Array releaseArray() { return kj::mv(content); } + // Disowns the backing array (which includes the NUL terminator) and returns it. The ConstString value + // is clobbered (as if moved away). + + inline const char* cStr() const KJ_LIFETIMEBOUND; + + inline size_t size() const; + // Result does not include NUL terminator. + + inline char operator[](size_t index) const; + inline char& operator[](size_t index) KJ_LIFETIMEBOUND; + + inline const char* begin() const KJ_LIFETIMEBOUND; + inline const char* end() const KJ_LIFETIMEBOUND; + + inline bool operator==(decltype(nullptr)) const { return content.size() <= 1; } + inline bool operator!=(decltype(nullptr)) const { return content.size() > 1; } + + inline bool operator==(const StringPtr& other) const { return StringPtr(*this) == other; } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const StringPtr& other) const { return StringPtr(*this) != other; } +#endif + inline bool operator< (const StringPtr& other) const { return StringPtr(*this) < other; } + inline bool operator> (const StringPtr& other) const { return StringPtr(*this) > other; } + inline bool operator<=(const StringPtr& other) const { return StringPtr(*this) <= other; } + inline bool operator>=(const StringPtr& other) const { return StringPtr(*this) >= other; } + + inline bool operator==(const String& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const String& other) const { return StringPtr(*this) != StringPtr(other); } +#endif + inline bool operator< (const String& other) const { return StringPtr(*this) < StringPtr(other); } + inline bool operator> (const String& other) const { return StringPtr(*this) > StringPtr(other); } + inline bool operator<=(const String& other) const { return StringPtr(*this) <= StringPtr(other); } + inline bool operator>=(const String& other) const { return StringPtr(*this) >= StringPtr(other); } + + inline bool operator==(const ConstString& other) const { return StringPtr(*this) == StringPtr(other); } +#if !__cpp_impl_three_way_comparison + inline bool operator!=(const ConstString& other) const { return StringPtr(*this) != StringPtr(other); } +#endif + inline bool operator< (const ConstString& other) const { return StringPtr(*this) < StringPtr(other); } + inline bool operator> (const ConstString& other) const { return StringPtr(*this) > StringPtr(other); } + inline bool operator<=(const ConstString& other) const { return StringPtr(*this) <= StringPtr(other); } + inline bool operator>=(const ConstString& other) const { return StringPtr(*this) >= StringPtr(other); } + // Note that if we don't overload for `const ConstString&` specifically, then C++20 will decide that + // comparisons between two strings are ambiguous. (Clang turns this into a warning, + // -Wambiguous-reversed-operator, due to the stupidity...) + + inline bool startsWith(const StringPtr& other) const { return asArray().startsWith(other);} + inline bool endsWith(const StringPtr& other) const { return asArray().endsWith(other); } + + inline StringPtr slice(size_t start) const KJ_LIFETIMEBOUND { + return StringPtr(*this).slice(start); + } + inline ArrayPtr slice(size_t start, size_t end) const KJ_LIFETIMEBOUND { + return StringPtr(*this).slice(start, end); + } + + inline Maybe findFirst(char c) const { return asArray().findFirst(c); } + inline Maybe findLast(char c) const { return asArray().findLast(c); } + + template + T parseAs() const { return StringPtr(*this).parseAs(); } + // Parse as number + + template + Maybe tryParseAs() const { return StringPtr(*this).tryParseAs(); } + +private: + Array content; +}; + +#if !__cpp_impl_three_way_comparison inline bool operator==(const char* a, const String& b) { return b == a; } inline bool operator!=(const char* a, const String& b) { return b != a; } #endif @@ -268,9 +471,12 @@ inline size_t sum(std::initializer_list nums) { } inline char* fill(char* ptr) { return ptr; } +inline char* fillLimited(char* ptr, char* limit) { return ptr; } template char* fill(char* __restrict__ target, const StringTree& first, Rest&&... rest); +template +char* fillLimited(char* __restrict__ target, char* limit, const StringTree& first, Rest&&... rest); // Make str() work with stringifiers that return StringTree by patching fill(). // // Defined in string-tree.h. @@ -299,6 +505,31 @@ inline String concat(String&& arr) { return kj::mv(arr); } +template +char* fillLimited(char* __restrict__ target, char* limit, const First& first, Rest&&... rest) { + auto i = first.begin(); + auto end = first.end(); + while (i != end) { + if (target == limit) return target; + *target++ = *i++; + } + return fillLimited(target, limit, kj::fwd(rest)...); +} + +template +class Delimited; +// Delimits a sequence of type T with a string delimiter. Implements kj::delimited(). + +template +char* fill(char* __restrict__ target, Delimited&& first, Rest&&... rest); +template +char* fillLimited(char* __restrict__ target, char* limit, Delimited&& first,Rest&&... rest); +template +char* fill(char* __restrict__ target, Delimited& first, Rest&&... rest); +template +char* fillLimited(char* __restrict__ target, char* limit, Delimited& first,Rest&&... rest); +// As with StringTree, we special-case Delimited. + struct Stringifier { // This is a dummy type with only one instance: STR (below). To make an arbitrary type // stringifiable, define `operator*(Stringifier, T)` to return an iterable container of `char`. @@ -314,20 +545,31 @@ struct Stringifier { inline ArrayPtr operator*(ArrayPtr s) const { return s; } inline ArrayPtr operator*(ArrayPtr s) const { return s; } - inline ArrayPtr operator*(const Array& s) const { return s; } - inline ArrayPtr operator*(const Array& s) const { return s; } + inline ArrayPtr operator*(const Array& s) const KJ_LIFETIMEBOUND { + return s; + } + inline ArrayPtr operator*(const Array& s) const KJ_LIFETIMEBOUND { return s; } template - inline ArrayPtr operator*(const CappedArray& s) const { return s; } + inline ArrayPtr operator*(const CappedArray& s) const KJ_LIFETIMEBOUND { + return s; + } template - inline ArrayPtr operator*(const FixedArray& s) const { return s; } - inline ArrayPtr operator*(const char* s) const { return arrayPtr(s, strlen(s)); } -#if __cplusplus >= 202000L - inline ArrayPtr operator*(const char8_t* s) const { + inline ArrayPtr operator*(const FixedArray& s) const KJ_LIFETIMEBOUND { + return s; + } + inline ArrayPtr operator*(const char* s) const KJ_LIFETIMEBOUND { + return arrayPtr(s, strlen(s)); + } +#if __cpp_char8_t + inline ArrayPtr operator*(const char8_t* s) const KJ_LIFETIMEBOUND { return operator*(reinterpret_cast(s)); } #endif - inline ArrayPtr operator*(const String& s) const { return s.asArray(); } + inline ArrayPtr operator*(const String& s) const KJ_LIFETIMEBOUND { + return s.asArray(); + } inline ArrayPtr operator*(const StringPtr& s) const { return s.asArray(); } + inline ArrayPtr operator*(const ConstString& s) const { return s.asArray(); } inline Range operator*(const Range& r) const { return r; } inline Repeat operator*(const Repeat& r) const { return r; } @@ -353,12 +595,7 @@ struct Stringifier { CappedArray operator*(unsigned long long i) const; CappedArray operator*(float f) const; CappedArray operator*(double f) const; - CappedArray operator*(const void* s) const; - - template - String operator*(ArrayPtr arr) const; - template - String operator*(const Array& arr) const; + CappedArray operator*(const void* s) const; #if KJ_COMPILER_SUPPORTS_STL_STRING_INTEROP // supports expression SFINAE? template ().toString())> @@ -402,6 +639,10 @@ String str(Params&&... params) { inline String str(String&& s) { return mv(s); } // Overload to prevent redundant allocation. +template +_::Delimited delimited(T&& arr, kj::StringPtr delim); +// Use to stringify an array. + template String strArray(T&& arr, const char* delim) { size_t delimLen = strlen(delim); @@ -425,25 +666,47 @@ String strArray(T&& arr, const char* delim) { return result; } -namespace _ { // private - -template -inline String Stringifier::operator*(ArrayPtr arr) const { - return strArray(arr, ", "); +template +StringPtr strPreallocated(ArrayPtr buffer, Params&&... params) { + // Like str() but writes into a preallocated buffer. If the buffer is not long enough, the result + // is truncated (but still NUL-terminated). + // + // This can be used like: + // + // char buffer[256]; + // StringPtr text = strPreallocated(buffer, params...); + // + // This is useful for optimization. It can also potentially be used safely in async signal + // handlers. HOWEVER, to use in an async signal handler, all of the stringifiers for the inputs + // must also be signal-safe. KJ guarantees signal safety when stringifying any built-in integer + // type (but NOT floating-points), basic char/byte sequences (ArrayPtr, String, etc.), as + // well as Array as long as T can also be stringified safely. To safely stringify a delimited + // array, you must use kj::delimited(arr, delim) rather than the deprecated + // kj::strArray(arr, delim). + + char* end = _::fillLimited(buffer.begin(), buffer.end() - 1, + toCharSequence(kj::fwd(params))...); + *end = '\0'; + return StringPtr(buffer.begin(), end); } -template -inline String Stringifier::operator*(const Array& arr) const { - return strArray(arr, ", "); +template ()))> +inline _::Delimited> operator*(const _::Stringifier&, ArrayPtr arr) { + return _::Delimited>(arr, ", "); } -} // namespace _ (private) +template ()))> +inline _::Delimited> operator*(const _::Stringifier&, const Array& arr) { + return _::Delimited>(arr, ", "); +} #define KJ_STRINGIFY(...) operator*(::kj::_::Stringifier, __VA_ARGS__) // Defines a stringifier for a custom type. Example: // // class Foo {...}; // inline StringPtr KJ_STRINGIFY(const Foo& foo) { return foo.name(); } +// // or perhaps +// inline String KJ_STRINGIFY(const Foo& foo) { return kj::str(foo.fld1(), ",", foo.fld2()); } // // This allows Foo to be passed to str(). // @@ -453,14 +716,15 @@ inline String Stringifier::operator*(const Array& arr) const { // ======================================================================================= // Inline implementation details. -inline StringPtr::StringPtr(const String& value): content(value.begin(), value.size() + 1) {} +inline StringPtr::StringPtr(const String& value): content(value.cStr(), value.size() + 1) {} +inline StringPtr::StringPtr(const ConstString& value): content(value.cStr(), value.size() + 1) {} -inline StringPtr::operator ArrayPtr() const { - return content.slice(0, content.size() - 1); +inline constexpr StringPtr::operator ArrayPtr() const { + return ArrayPtr(content.begin(), content.size() - 1); } -inline ArrayPtr StringPtr::asArray() const { - return content.slice(0, content.size() - 1); +inline constexpr ArrayPtr StringPtr::asArray() const { + return ArrayPtr(content.begin(), content.size() - 1); } inline bool StringPtr::operator==(const StringPtr& other) const { @@ -482,31 +746,18 @@ inline ArrayPtr StringPtr::slice(size_t start, size_t end) const { return content.slice(start, end); } -inline bool StringPtr::startsWith(const StringPtr& other) const { - return other.content.size() <= content.size() && - memcmp(content.begin(), other.content.begin(), other.size()) == 0; -} -inline bool StringPtr::endsWith(const StringPtr& other) const { - return other.content.size() <= content.size() && - memcmp(end() - other.size(), other.content.begin(), other.size()) == 0; +inline LiteralStringConst::operator ConstString() const { + return ConstString(begin(), size(), NullArrayDisposer::instance); } -inline Maybe StringPtr::findFirst(char c) const { - const char* pos = reinterpret_cast(memchr(content.begin(), c, size())); - if (pos == nullptr) { - return nullptr; - } else { - return pos - content.begin(); - } +inline ConstString StringPtr::attach() const { + // This is meant as a roundabout way to make a ConstString from a StringPtr + return ConstString(begin(), size(), NullArrayDisposer::instance); } -inline Maybe StringPtr::findLast(char c) const { - for (size_t i = size(); i > 0; --i) { - if (content[i-1] == c) { - return i-1; - } - } - return nullptr; +template +inline ConstString StringPtr::attach(Attachments&&... attachments) const { + return ConstString { .content = content.attach(kj::fwd(attachments)...) }; } inline String::operator ArrayPtr() { @@ -515,6 +766,9 @@ inline String::operator ArrayPtr() { inline String::operator ArrayPtr() const { return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); } +inline ConstString::operator ArrayPtr() const { + return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); +} inline ArrayPtr String::asArray() { return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); @@ -522,27 +776,42 @@ inline ArrayPtr String::asArray() { inline ArrayPtr String::asArray() const { return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); } +inline ArrayPtr ConstString::asArray() const { + return content == nullptr ? ArrayPtr(nullptr) : content.slice(0, content.size() - 1); +} inline const char* String::cStr() const { return content == nullptr ? "" : content.begin(); } +inline const char* ConstString::cStr() const { return content == nullptr ? "" : content.begin(); } inline size_t String::size() const { return content == nullptr ? 0 : content.size() - 1; } +inline size_t ConstString::size() const { return content == nullptr ? 0 : content.size() - 1; } inline char String::operator[](size_t index) const { return content[index]; } inline char& String::operator[](size_t index) { return content[index]; } +inline char ConstString::operator[](size_t index) const { return content[index]; } inline char* String::begin() { return content == nullptr ? nullptr : content.begin(); } inline char* String::end() { return content == nullptr ? nullptr : content.end() - 1; } inline const char* String::begin() const { return content == nullptr ? nullptr : content.begin(); } inline const char* String::end() const { return content == nullptr ? nullptr : content.end() - 1; } +inline const char* ConstString::begin() const { return content == nullptr ? nullptr : content.begin(); } +inline const char* ConstString::end() const { return content == nullptr ? nullptr : content.end() - 1; } inline String::String(char* value, size_t size, const ArrayDisposer& disposer) : content(value, size + 1, disposer) { KJ_IREQUIRE(value[size] == '\0', "String must be NUL-terminated."); } +inline ConstString::ConstString(const char* value, size_t size, const ArrayDisposer& disposer) + : content(value, size + 1, disposer) { + KJ_IREQUIRE(value[size] == '\0', "String must be NUL-terminated."); +} inline String::String(Array buffer): content(kj::mv(buffer)) { KJ_IREQUIRE(content.size() > 0 && content.back() == '\0', "String must be NUL-terminated."); } +inline ConstString::ConstString(Array buffer): content(kj::mv(buffer)) { + KJ_IREQUIRE(content.size() > 0 && content.back() == '\0', "String must be NUL-terminated."); +} inline String heapString(const char* value) { return heapString(value, strlen(value)); @@ -557,6 +826,119 @@ inline String heapString(ArrayPtr value) { return heapString(value.begin(), value.size()); } +namespace _ { // private + +template +class Delimited { +public: + Delimited(T array, kj::StringPtr delimiter) + : array(kj::fwd(array)), delimiter(delimiter) {} + + // TODO(someday): In theory we should support iteration as a character sequence, but the iterator + // will be pretty complicated. + + size_t size() { + ensureStringifiedInitialized(); + + size_t result = 0; + bool first = true; + for (auto& e: stringified) { + if (first) { + first = false; + } else { + result += delimiter.size(); + } + result += e.size(); + } + return result; + } + + char* flattenTo(char* __restrict__ target) { + ensureStringifiedInitialized(); + + bool first = true; + for (auto& elem: stringified) { + if (first) { + first = false; + } else { + target = fill(target, delimiter); + } + target = fill(target, elem); + } + return target; + } + + char* flattenTo(char* __restrict__ target, char* limit) { + // This is called in the strPreallocated(). We want to avoid allocation. size() will not have + // been called in this case, so hopefully `stringified` is still uninitialized. We will + // stringify each item and immediately use it. + bool first = true; + for (auto&& elem: array) { + if (target == limit) return target; + if (first) { + first = false; + } else { + target = fillLimited(target, limit, delimiter); + } + target = fillLimited(target, limit, kj::toCharSequence(elem)); + } + return target; + } + +private: + typedef decltype(toCharSequence(*instance().begin())) StringifiedItem; + T array; + kj::StringPtr delimiter; + Array stringified; + + void ensureStringifiedInitialized() { + if (array.size() > 0 && stringified.size() == 0) { + stringified = KJ_MAP(e, array) { return toCharSequence(e); }; + } + } +}; + +template +char* fill(char* __restrict__ target, Delimited&& first, Rest&&... rest) { + target = first.flattenTo(target); + return fill(target, kj::fwd(rest)...); +} +template +char* fillLimited(char* __restrict__ target, char* limit, Delimited&& first, Rest&&... rest) { + target = first.flattenTo(target, limit); + return fillLimited(target, limit, kj::fwd(rest)...); +} +template +char* fill(char* __restrict__ target, Delimited& first, Rest&&... rest) { + target = first.flattenTo(target); + return fill(target, kj::fwd(rest)...); +} +template +char* fillLimited(char* __restrict__ target, char* limit, Delimited& first, Rest&&... rest) { + target = first.flattenTo(target, limit); + return fillLimited(target, limit, kj::fwd(rest)...); +} + +template +inline Delimited&& KJ_STRINGIFY(Delimited&& delimited) { return kj::mv(delimited); } +template +inline const Delimited& KJ_STRINGIFY(const Delimited& delimited) { return delimited; } + +} // namespace _ (private) + +template +_::Delimited delimited(T&& arr, kj::StringPtr delim) { + return _::Delimited(kj::fwd(arr), delim); +} + } // namespace kj -#endif // KJ_STRING_H_ +constexpr kj::StringPtr operator "" _kj(const char* str, size_t n) { + return kj::StringPtr(kj::ArrayPtr(str, n + 1)); +}; + +constexpr kj::LiteralStringConst operator "" _kjc(const char* str, size_t n) { + return kj::LiteralStringConst(kj::ArrayPtr(str, n + 1)); +}; + +KJ_END_HEADER diff --git a/c++/src/kj/table-test.c++ b/c++/src/kj/table-test.c++ new file mode 100644 index 0000000000..708843a2bd --- /dev/null +++ b/c++/src/kj/table-test.c++ @@ -0,0 +1,1410 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "table.h" +#include +#include +#include +#include "hash.h" +#include "time.h" +#include + +namespace kj { +namespace _ { +namespace { + +#if defined(KJ_DEBUG) && !__OPTIMIZE__ +static constexpr uint MEDIUM_PRIME = 619; +static constexpr uint BIG_PRIME = 6143; +#else +static constexpr uint MEDIUM_PRIME = 6143; +static constexpr uint BIG_PRIME = 101363; +#endif +// Some of the tests build large tables. These numbers are used as the table sizes. We use primes +// to avoid any unintended aliasing affects -- this is probably just paranoia, but why not? +// +// We use smaller values for debug builds to keep runtime down. + +KJ_TEST("_::tryReserveSize() works") { + { + Vector vec; + tryReserveSize(vec, "foo"_kj); + KJ_EXPECT(vec.capacity() == 4); // Vectors always grow by powers of two. + } + { + Vector vec; + tryReserveSize(vec, 123); + KJ_EXPECT(vec.capacity() == 0); + } +} + +class StringHasher { +public: + StringPtr keyForRow(StringPtr s) const { return s; } + + bool matches(StringPtr a, StringPtr b) const { + return a == b; + } + uint hashCode(StringPtr str) const { + return kj::hashCode(str); + } +}; + +KJ_TEST("simple table") { + Table> table; + + KJ_EXPECT(table.find("foo") == nullptr); + + KJ_EXPECT(table.size() == 0); + KJ_EXPECT(table.insert("foo") == "foo"); + KJ_EXPECT(table.size() == 1); + KJ_EXPECT(table.insert("bar") == "bar"); + KJ_EXPECT(table.size() == 2); + + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("foo")) == "foo"); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("bar")) == "bar"); + KJ_EXPECT(table.find("fop") == nullptr); + KJ_EXPECT(table.find("baq") == nullptr); + + { + StringPtr& ref = table.insert("baz"); + KJ_EXPECT(ref == "baz"); + StringPtr& ref2 = KJ_ASSERT_NONNULL(table.find("baz")); + KJ_EXPECT(&ref == &ref2); + } + + KJ_EXPECT(table.size() == 3); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "foo"); + KJ_EXPECT(*iter++ == "bar"); + KJ_EXPECT(*iter++ == "baz"); + KJ_EXPECT(iter == table.end()); + } + + KJ_EXPECT(table.eraseMatch("foo")); + KJ_EXPECT(table.size() == 2); + KJ_EXPECT(table.find("foo") == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("bar")) == "bar"); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("baz")) == "baz"); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "baz"); + KJ_EXPECT(*iter++ == "bar"); + KJ_EXPECT(iter == table.end()); + } + + { + auto& row = table.upsert("qux", [&](StringPtr&, StringPtr&&) { + KJ_FAIL_ASSERT("shouldn't get here"); + }); + + auto copy = kj::str("qux"); + table.upsert(StringPtr(copy), [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(param.begin() == copy.begin()); + KJ_EXPECT(&existing == &row); + }); + + auto& found = KJ_ASSERT_NONNULL(table.find("qux")); + KJ_EXPECT(&found == &row); + } + + StringPtr STRS[] = { "corge"_kj, "grault"_kj, "garply"_kj }; + table.insertAll(ArrayPtr(STRS)); + KJ_EXPECT(table.size() == 6); + KJ_EXPECT(table.find("corge") != nullptr); + KJ_EXPECT(table.find("grault") != nullptr); + KJ_EXPECT(table.find("garply") != nullptr); + + KJ_EXPECT_THROW_MESSAGE("inserted row already exists in table", table.insert("bar")); + + KJ_EXPECT(table.size() == 6); + + KJ_EXPECT(table.insert("baa") == "baa"); + + KJ_EXPECT(table.eraseAll([](StringPtr s) { return s.startsWith("ba"); }) == 3); + KJ_EXPECT(table.size() == 4); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(*iter++ == "corge"); + KJ_EXPECT(iter == table.end()); + } + + auto& graultRow = table.begin()[1]; + kj::StringPtr origGrault = graultRow; + + KJ_EXPECT(&table.findOrCreate("grault", + [&]() -> kj::StringPtr { KJ_FAIL_ASSERT("shouldn't have called this"); }) == &graultRow); + KJ_EXPECT(graultRow.begin() == origGrault.begin()); + KJ_EXPECT(&KJ_ASSERT_NONNULL(table.find("grault")) == &graultRow); + KJ_EXPECT(table.find("waldo") == nullptr); + KJ_EXPECT(table.size() == 4); + + kj::String searchWaldo = kj::str("waldo"); + kj::String insertWaldo = kj::str("waldo"); + + auto& waldo = table.findOrCreate(searchWaldo, + [&]() -> kj::StringPtr { return insertWaldo; }); + KJ_EXPECT(waldo == "waldo"); + KJ_EXPECT(waldo.begin() == insertWaldo.begin()); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("grault")) == "grault"); + KJ_EXPECT(&KJ_ASSERT_NONNULL(table.find("waldo")) == &waldo); + KJ_EXPECT(table.size() == 5); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(*iter++ == "corge"); + KJ_EXPECT(*iter++ == "waldo"); + KJ_EXPECT(iter == table.end()); + } +} + +class BadHasher { + // String hash that always returns the same hash code. This should not affect correctness, only + // performance. +public: + StringPtr keyForRow(StringPtr s) const { return s; } + + bool matches(StringPtr a, StringPtr b) const { + return a == b; + } + uint hashCode(StringPtr str) const { + return 1234; + } +}; + +KJ_TEST("hash tables when hash is always same") { + Table> table; + + KJ_EXPECT(table.size() == 0); + KJ_EXPECT(table.insert("foo") == "foo"); + KJ_EXPECT(table.size() == 1); + KJ_EXPECT(table.insert("bar") == "bar"); + KJ_EXPECT(table.size() == 2); + + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("foo")) == "foo"); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("bar")) == "bar"); + KJ_EXPECT(table.find("fop") == nullptr); + KJ_EXPECT(table.find("baq") == nullptr); + + { + StringPtr& ref = table.insert("baz"); + KJ_EXPECT(ref == "baz"); + StringPtr& ref2 = KJ_ASSERT_NONNULL(table.find("baz")); + KJ_EXPECT(&ref == &ref2); + } + + KJ_EXPECT(table.size() == 3); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "foo"); + KJ_EXPECT(*iter++ == "bar"); + KJ_EXPECT(*iter++ == "baz"); + KJ_EXPECT(iter == table.end()); + } + + KJ_EXPECT(table.eraseMatch("foo")); + KJ_EXPECT(table.size() == 2); + KJ_EXPECT(table.find("foo") == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("bar")) == "bar"); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("baz")) == "baz"); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "baz"); + KJ_EXPECT(*iter++ == "bar"); + KJ_EXPECT(iter == table.end()); + } + + { + auto& row = table.upsert("qux", [&](StringPtr&, StringPtr&&) { + KJ_FAIL_ASSERT("shouldn't get here"); + }); + + auto copy = kj::str("qux"); + table.upsert(StringPtr(copy), [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(param.begin() == copy.begin()); + KJ_EXPECT(&existing == &row); + }); + + auto& found = KJ_ASSERT_NONNULL(table.find("qux")); + KJ_EXPECT(&found == &row); + } + + StringPtr STRS[] = { "corge"_kj, "grault"_kj, "garply"_kj }; + table.insertAll(ArrayPtr(STRS)); + KJ_EXPECT(table.size() == 6); + KJ_EXPECT(table.find("corge") != nullptr); + KJ_EXPECT(table.find("grault") != nullptr); + KJ_EXPECT(table.find("garply") != nullptr); + + KJ_EXPECT_THROW_MESSAGE("inserted row already exists in table", table.insert("bar")); +} + +class IntHasher { + // Dumb integer hasher that just returns the integer itself. +public: + uint keyForRow(uint i) const { return i; } + + bool matches(uint a, uint b) const { + return a == b; + } + uint hashCode(uint i) const { + return i; + } +}; + +KJ_TEST("HashIndex with many erasures doesn't keep growing") { + HashIndex index; + + kj::ArrayPtr rows = nullptr; + + for (uint i: kj::zeroTo(1000000)) { + KJ_ASSERT(index.insert(rows, 0, i) == nullptr); + index.erase(rows, 0, i); + } + + KJ_ASSERT(index.capacity() < 10); +} + +struct SiPair { + kj::StringPtr str; + uint i; + + inline bool operator==(SiPair other) const { + return str == other.str && i == other.i; + } +}; + +class SiPairStringHasher { +public: + StringPtr keyForRow(SiPair s) const { return s.str; } + + bool matches(SiPair a, StringPtr b) const { + return a.str == b; + } + uint hashCode(StringPtr str) const { + return inner.hashCode(str); + } + +private: + StringHasher inner; +}; + +class SiPairIntHasher { +public: + uint keyForRow(SiPair s) const { return s.i; } + + bool matches(SiPair a, uint b) const { + return a.i == b; + } + uint hashCode(uint i) const { + return i; + } +}; + +KJ_TEST("double-index table") { + Table, HashIndex> table; + + KJ_EXPECT(table.size() == 0); + KJ_EXPECT(table.insert({"foo", 123}) == (SiPair {"foo", 123})); + KJ_EXPECT(table.size() == 1); + KJ_EXPECT(table.insert({"bar", 456}) == (SiPair {"bar", 456})); + KJ_EXPECT(table.size() == 2); + + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find>("foo")) == + (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find>(123)) == + (SiPair {"foo", 123})); + + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("foo")) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(123)) == (SiPair {"foo", 123})); + + KJ_EXPECT_THROW_MESSAGE("inserted row already exists in table", table.insert({"foo", 111})); + KJ_EXPECT_THROW_MESSAGE("inserted row already exists in table", table.insert({"qux", 123})); + + KJ_EXPECT(table.size() == 2); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("foo")) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(123)) == (SiPair {"foo", 123})); + + KJ_EXPECT( + table.findOrCreate<0>("foo", + []() -> SiPair { KJ_FAIL_ASSERT("shouldn't have called this"); }) + == (SiPair {"foo", 123})); + KJ_EXPECT(table.size() == 2); + KJ_EXPECT_THROW_MESSAGE("inserted row already exists in table", + table.findOrCreate<0>("corge", []() -> SiPair { return {"corge", 123}; })); + + KJ_EXPECT(table.size() == 2); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("foo")) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(123)) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("bar")) == (SiPair {"bar", 456})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(456)) == (SiPair {"bar", 456})); + KJ_EXPECT(table.find<0>("corge") == nullptr); + + KJ_EXPECT( + table.findOrCreate<0>("corge", []() -> SiPair { return {"corge", 789}; }) + == (SiPair {"corge", 789})); + + KJ_EXPECT(table.size() == 3); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("foo")) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(123)) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("bar")) == (SiPair {"bar", 456})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(456)) == (SiPair {"bar", 456})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("corge")) == (SiPair {"corge", 789})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(789)) == (SiPair {"corge", 789})); + + KJ_EXPECT( + table.findOrCreate<1>(234, []() -> SiPair { return {"grault", 234}; }) + == (SiPair {"grault", 234})); + + KJ_EXPECT(table.size() == 4); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("foo")) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(123)) == (SiPair {"foo", 123})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("bar")) == (SiPair {"bar", 456})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(456)) == (SiPair {"bar", 456})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("corge")) == (SiPair {"corge", 789})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(789)) == (SiPair {"corge", 789})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<0>("grault")) == (SiPair {"grault", 234})); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find<1>(234)) == (SiPair {"grault", 234})); +} + +class UintHasher { +public: + uint keyForRow(uint i) const { return i; } + + bool matches(uint a, uint b) const { + return a == b; + } + uint hashCode(uint i) const { + return i; + } +}; + +KJ_TEST("benchmark: kj::Table") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + for (auto step: STEP) { + KJ_CONTEXT(step); + Table> table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(j * 5 + 123); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint value = KJ_ASSERT_NONNULL(table.find(i * 5 + 123)); + KJ_ASSERT(value == i * 5 + 123); + KJ_ASSERT(table.find(i * 5 + 122) == nullptr); + KJ_ASSERT(table.find(i * 5 + 124) == nullptr); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + table.erase(KJ_ASSERT_NONNULL(table.find(i * 5 + 123))); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(i * 5 + 123) == nullptr); + } else { + uint value = KJ_ASSERT_NONNULL(table.find(i * 5 + 123)); + KJ_ASSERT(value == i * 5 + 123); + } + } + } +} + +KJ_TEST("benchmark: std::unordered_set") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + for (auto step: STEP) { + KJ_CONTEXT(step); + std::unordered_set table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(j * 5 + 123); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + auto iter = table.find(i * 5 + 123); + KJ_ASSERT(iter != table.end()); + uint value = *iter; + KJ_ASSERT(value == i * 5 + 123); + KJ_ASSERT(table.find(i * 5 + 122) == table.end()); + KJ_ASSERT(table.find(i * 5 + 124) == table.end()); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + KJ_ASSERT(table.erase(i * 5 + 123) > 0); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(i * 5 + 123) == table.end()); + } else { + auto iter = table.find(i * 5 + 123); + KJ_ASSERT(iter != table.end()); + uint value = *iter; + KJ_ASSERT(value == i * 5 + 123); + } + } + } +} + +KJ_TEST("benchmark: kj::Table") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + kj::Vector strings(SOME_PRIME); + for (uint i: kj::zeroTo(SOME_PRIME)) { + strings.add(kj::str(i * 5 + 123)); + } + + for (auto step: STEP) { + KJ_CONTEXT(step); + Table> table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(strings[j]); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + StringPtr value = KJ_ASSERT_NONNULL(table.find(strings[i])); + KJ_ASSERT(value == strings[i]); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + table.erase(KJ_ASSERT_NONNULL(table.find(strings[i]))); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(strings[i]) == nullptr); + } else { + StringPtr value = KJ_ASSERT_NONNULL(table.find(strings[i])); + KJ_ASSERT(value == strings[i]); + } + } + } +} + +struct StlStringHash { + inline size_t operator()(StringPtr str) const { + return kj::hashCode(str); + } +}; + +KJ_TEST("benchmark: std::unordered_set") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + kj::Vector strings(SOME_PRIME); + for (uint i: kj::zeroTo(SOME_PRIME)) { + strings.add(kj::str(i * 5 + 123)); + } + + for (auto step: STEP) { + KJ_CONTEXT(step); + std::unordered_set table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(strings[j]); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + auto iter = table.find(strings[i]); + KJ_ASSERT(iter != table.end()); + StringPtr value = *iter; + KJ_ASSERT(value == strings[i]); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + KJ_ASSERT(table.erase(strings[i]) > 0); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(strings[i]) == table.end()); + } else { + auto iter = table.find(strings[i]); + KJ_ASSERT(iter != table.end()); + StringPtr value = *iter; + KJ_ASSERT(value == strings[i]); + } + } + } +} + +// ======================================================================================= + +KJ_TEST("B-tree internals") { + { + BTreeImpl::Leaf leaf; + memset(&leaf, 0, sizeof(leaf)); + + for (auto i: kj::indices(leaf.rows)) { + KJ_CONTEXT(i); + + KJ_EXPECT(leaf.size() == i); + + if (i < kj::size(leaf.rows) / 2) { +#ifdef KJ_DEBUG + KJ_EXPECT_THROW(FAILED, leaf.isHalfFull()); +#endif + KJ_EXPECT(!leaf.isMostlyFull()); + } + + if (i == kj::size(leaf.rows) / 2) { + KJ_EXPECT(leaf.isHalfFull()); + KJ_EXPECT(!leaf.isMostlyFull()); + } + + if (i > kj::size(leaf.rows) / 2) { + KJ_EXPECT(!leaf.isHalfFull()); + KJ_EXPECT(leaf.isMostlyFull()); + } + + if (i == kj::size(leaf.rows)) { + KJ_EXPECT(leaf.isFull()); + } else { + KJ_EXPECT(!leaf.isFull()); + } + + leaf.rows[i] = 1; + } + KJ_EXPECT(leaf.size() == kj::size(leaf.rows)); + } + + { + BTreeImpl::Parent parent; + memset(&parent, 0, sizeof(parent)); + + for (auto i: kj::indices(parent.keys)) { + KJ_CONTEXT(i); + + KJ_EXPECT(parent.keyCount() == i); + + if (i < kj::size(parent.keys) / 2) { +#ifdef KJ_DEBUG + KJ_EXPECT_THROW(FAILED, parent.isHalfFull()); +#endif + KJ_EXPECT(!parent.isMostlyFull()); + } + + if (i == kj::size(parent.keys) / 2) { + KJ_EXPECT(parent.isHalfFull()); + KJ_EXPECT(!parent.isMostlyFull()); + } + + if (i > kj::size(parent.keys) / 2) { + KJ_EXPECT(!parent.isHalfFull()); + KJ_EXPECT(parent.isMostlyFull()); + } + + if (i == kj::size(parent.keys)) { + KJ_EXPECT(parent.isFull()); + } else { + KJ_EXPECT(!parent.isFull()); + } + + parent.keys[i] = 1; + } + KJ_EXPECT(parent.keyCount() == kj::size(parent.keys)); + } +} + +class StringCompare { +public: + StringPtr keyForRow(StringPtr s) const { return s; } + + bool isBefore(StringPtr a, StringPtr b) const { + return a < b; + } + bool matches(StringPtr a, StringPtr b) const { + return a == b; + } +}; + +KJ_TEST("simple tree table") { + Table> table; + + KJ_EXPECT(table.find("foo") == nullptr); + + KJ_EXPECT(table.size() == 0); + KJ_EXPECT(table.insert("foo") == "foo"); + KJ_EXPECT(table.size() == 1); + KJ_EXPECT(table.insert("bar") == "bar"); + KJ_EXPECT(table.size() == 2); + + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("foo")) == "foo"); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("bar")) == "bar"); + KJ_EXPECT(table.find("fop") == nullptr); + KJ_EXPECT(table.find("baq") == nullptr); + + { + StringPtr& ref = table.insert("baz"); + KJ_EXPECT(ref == "baz"); + StringPtr& ref2 = KJ_ASSERT_NONNULL(table.find("baz")); + KJ_EXPECT(&ref == &ref2); + } + + KJ_EXPECT(table.size() == 3); + + { + auto range = table.ordered(); + auto iter = range.begin(); + KJ_EXPECT(*iter++ == "bar"); + KJ_EXPECT(*iter++ == "baz"); + KJ_EXPECT(*iter++ == "foo"); + KJ_EXPECT(iter == range.end()); + } + + KJ_EXPECT(table.eraseMatch("foo")); + KJ_EXPECT(table.size() == 2); + KJ_EXPECT(table.find("foo") == nullptr); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("bar")) == "bar"); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("baz")) == "baz"); + + { + auto range = table.ordered(); + auto iter = range.begin(); + KJ_EXPECT(*iter++ == "bar"); + KJ_EXPECT(*iter++ == "baz"); + KJ_EXPECT(iter == range.end()); + } + + { + auto& row = table.upsert("qux", [&](StringPtr&, StringPtr&&) { + KJ_FAIL_ASSERT("shouldn't get here"); + }); + + auto copy = kj::str("qux"); + table.upsert(StringPtr(copy), [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(param.begin() == copy.begin()); + KJ_EXPECT(&existing == &row); + }); + + auto& found = KJ_ASSERT_NONNULL(table.find("qux")); + KJ_EXPECT(&found == &row); + } + + StringPtr STRS[] = { "corge"_kj, "grault"_kj, "garply"_kj }; + table.insertAll(ArrayPtr(STRS)); + KJ_EXPECT(table.size() == 6); + KJ_EXPECT(table.find("corge") != nullptr); + KJ_EXPECT(table.find("grault") != nullptr); + KJ_EXPECT(table.find("garply") != nullptr); + + KJ_EXPECT_THROW_MESSAGE("inserted row already exists in table", table.insert("bar")); + + KJ_EXPECT(table.size() == 6); + + KJ_EXPECT(table.insert("baa") == "baa"); + + KJ_EXPECT(table.eraseAll([](StringPtr s) { return s.startsWith("ba"); }) == 3); + KJ_EXPECT(table.size() == 4); + + { + auto range = table.ordered(); + auto iter = range.begin(); + KJ_EXPECT(*iter++ == "corge"); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(iter == range.end()); + } + + { + auto range = table.range("foo", "har"); + auto iter = range.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(iter == range.end()); + } + + { + auto range = table.range("garply", "grault"); + auto iter = range.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(iter == range.end()); + } + + { + auto iter = table.seek("garply"); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(iter == table.ordered().end()); + } + + { + auto iter = table.seek("gorply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(iter == table.ordered().end()); + } + + auto& graultRow = table.begin()[1]; + kj::StringPtr origGrault = graultRow; + + KJ_EXPECT(&table.findOrCreate("grault", + [&]() -> kj::StringPtr { KJ_FAIL_ASSERT("shouldn't have called this"); }) == &graultRow); + KJ_EXPECT(graultRow.begin() == origGrault.begin()); + KJ_EXPECT(&KJ_ASSERT_NONNULL(table.find("grault")) == &graultRow); + KJ_EXPECT(table.find("waldo") == nullptr); + KJ_EXPECT(table.size() == 4); + + kj::String searchWaldo = kj::str("waldo"); + kj::String insertWaldo = kj::str("waldo"); + + auto& waldo = table.findOrCreate(searchWaldo, + [&]() -> kj::StringPtr { return insertWaldo; }); + KJ_EXPECT(waldo == "waldo"); + KJ_EXPECT(waldo.begin() == insertWaldo.begin()); + KJ_EXPECT(KJ_ASSERT_NONNULL(table.find("grault")) == "grault"); + KJ_EXPECT(&KJ_ASSERT_NONNULL(table.find("waldo")) == &waldo); + KJ_EXPECT(table.size() == 5); + + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(*iter++ == "corge"); + KJ_EXPECT(*iter++ == "waldo"); + KJ_EXPECT(iter == table.end()); + } + + // Verify that move constructor/assignment work. + Table> other(kj::mv(table)); + KJ_EXPECT(other.size() == 5); + KJ_EXPECT(table.size() == 0); + KJ_EXPECT(table.begin() == table.end()); + { + auto iter = other.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(*iter++ == "corge"); + KJ_EXPECT(*iter++ == "waldo"); + KJ_EXPECT(iter == other.end()); + } + + table = kj::mv(other); + KJ_EXPECT(other.size() == 0); + KJ_EXPECT(table.size() == 5); + { + auto iter = table.begin(); + KJ_EXPECT(*iter++ == "garply"); + KJ_EXPECT(*iter++ == "grault"); + KJ_EXPECT(*iter++ == "qux"); + KJ_EXPECT(*iter++ == "corge"); + KJ_EXPECT(*iter++ == "waldo"); + KJ_EXPECT(iter == table.end()); + } + KJ_EXPECT(other.begin() == other.end()); +} + +class UintCompare { +public: + uint keyForRow(uint i) const { return i; } + + bool isBefore(uint a, uint b) const { + return a < b; + } + bool matches(uint a, uint b) const { + return a == b; + } +}; + +KJ_TEST("large tree table") { + constexpr uint SOME_PRIME = MEDIUM_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + for (auto step: STEP) { + KJ_CONTEXT(step); + Table> table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(j * 5 + 123); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint value = KJ_ASSERT_NONNULL(table.find(i * 5 + 123)); + KJ_ASSERT(value == i * 5 + 123); + KJ_ASSERT(table.find(i * 5 + 122) == nullptr); + KJ_ASSERT(table.find(i * 5 + 124) == nullptr); + } + table.verify(); + + { + auto range = table.ordered(); + auto iter = range.begin(); + for (uint i: kj::zeroTo(SOME_PRIME)) { + KJ_ASSERT(*iter++ == i * 5 + 123); + } + KJ_ASSERT(iter == range.end()); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + KJ_CONTEXT(i); + if (i % 2 == 0 || i % 7 == 0) { + table.erase(KJ_ASSERT_NONNULL(table.find(i * 5 + 123), i)); + table.verify(); + } + } + + { + auto range = table.ordered(); + auto iter = range.begin(); + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(i * 5 + 123) == nullptr); + } else { + uint value = KJ_ASSERT_NONNULL(table.find(i * 5 + 123)); + KJ_ASSERT(value == i * 5 + 123); + KJ_ASSERT(*iter++ == i * 5 + 123); + } + } + KJ_ASSERT(iter == range.end()); + } + } +} + +KJ_TEST("TreeIndex fuzz test") { + // A test which randomly modifies a TreeIndex to try to discover buggy state changes. + + uint seed = (kj::systemPreciseCalendarClock().now() - kj::UNIX_EPOCH) / kj::NANOSECONDS; + KJ_CONTEXT(seed); // print the seed if the test fails + srand(seed); + + Table> table; + + auto randomInsert = [&]() { + table.upsert(rand(), [](auto&&, auto&&) {}); + }; + auto randomErase = [&]() { + if (table.size() > 0) { + auto& row = table.begin()[rand() % table.size()]; + table.erase(row); + } + }; + auto randomLookup = [&]() { + if (table.size() > 0) { + auto& row = table.begin()[rand() % table.size()]; + auto& found = KJ_ASSERT_NONNULL(table.find(row)); + KJ_ASSERT(&found == &row); + } + }; + + // First pass: focus on insertions, aim to do 2x as many insertions as deletions. + for (auto i KJ_UNUSED: kj::zeroTo(1000)) { + switch (rand() % 4) { + case 0: + case 1: + randomInsert(); + break; + case 2: + randomErase(); + break; + case 3: + randomLookup(); + break; + } + + table.verify(); + } + + // Second pass: focus on deletions, aim to do 2x as many deletions as insertions. + for (auto i KJ_UNUSED: kj::zeroTo(1000)) { + switch (rand() % 4) { + case 0: + randomInsert(); + break; + case 1: + case 2: + randomErase(); + break; + case 3: + randomLookup(); + break; + } + + table.verify(); + } +} + +KJ_TEST("TreeIndex clear() leaves tree in valid state") { + // A test which ensures that calling clear() does not break the internal state of a TreeIndex. + // It used to be the case that clearing a non-empty tree would leave it thinking that it had room + // for one more node than it really did, causing it to write and read beyond the end of its + // internal array of nodes. + Table> table; + + // Insert at least one value to allocate an initial set of tree nodes. + table.upsert(1, [](auto&&, auto&&) {}); + KJ_EXPECT(table.find(1) != nullptr); + table.clear(); + + // Insert enough values to force writes/reads beyond the end of the tree's internal node array. + for (uint i = 0; i < 29; ++i) { + table.upsert(i, [](auto&&, auto&&) {}); + } + for (uint i = 0; i < 29; ++i) { + KJ_EXPECT(table.find(i) != nullptr); + } +} + +KJ_TEST("benchmark: kj::Table") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + for (auto step: STEP) { + KJ_CONTEXT(step); + Table> table; + table.reserve(SOME_PRIME); + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(j * 5 + 123); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint value = KJ_ASSERT_NONNULL(table.find(i * 5 + 123)); + KJ_ASSERT(value == i * 5 + 123); + KJ_ASSERT(table.find(i * 5 + 122) == nullptr); + KJ_ASSERT(table.find(i * 5 + 124) == nullptr); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + table.erase(KJ_ASSERT_NONNULL(table.find(i * 5 + 123))); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(i * 5 + 123) == nullptr); + } else { + uint value = KJ_ASSERT_NONNULL(table.find(i * 5 + 123)); + KJ_ASSERT(value == i * 5 + 123); + } + } + } +} + +KJ_TEST("benchmark: std::set") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + for (auto step: STEP) { + KJ_CONTEXT(step); + std::set table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(j * 5 + 123); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + auto iter = table.find(i * 5 + 123); + KJ_ASSERT(iter != table.end()); + uint value = *iter; + KJ_ASSERT(value == i * 5 + 123); + KJ_ASSERT(table.find(i * 5 + 122) == table.end()); + KJ_ASSERT(table.find(i * 5 + 124) == table.end()); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + KJ_ASSERT(table.erase(i * 5 + 123) > 0); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(i * 5 + 123) == table.end()); + } else { + auto iter = table.find(i * 5 + 123); + KJ_ASSERT(iter != table.end()); + uint value = *iter; + KJ_ASSERT(value == i * 5 + 123); + } + } + } +} + +KJ_TEST("benchmark: kj::Table") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + kj::Vector strings(SOME_PRIME); + for (uint i: kj::zeroTo(SOME_PRIME)) { + strings.add(kj::str(i * 5 + 123)); + } + + for (auto step: STEP) { + KJ_CONTEXT(step); + Table> table; + table.reserve(SOME_PRIME); + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(strings[j]); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + StringPtr value = KJ_ASSERT_NONNULL(table.find(strings[i])); + KJ_ASSERT(value == strings[i]); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + table.erase(KJ_ASSERT_NONNULL(table.find(strings[i]))); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(strings[i]) == nullptr); + } else { + auto& value = KJ_ASSERT_NONNULL(table.find(strings[i])); + KJ_ASSERT(value == strings[i]); + } + } + } +} + +KJ_TEST("benchmark: std::set") { + constexpr uint SOME_PRIME = BIG_PRIME; + constexpr uint STEP[] = {1, 2, 4, 7, 43, 127}; + + kj::Vector strings(SOME_PRIME); + for (uint i: kj::zeroTo(SOME_PRIME)) { + strings.add(kj::str(i * 5 + 123)); + } + + for (auto step: STEP) { + KJ_CONTEXT(step); + std::set table; + for (uint i: kj::zeroTo(SOME_PRIME)) { + uint j = (i * step) % SOME_PRIME; + table.insert(strings[j]); + } + for (uint i: kj::zeroTo(SOME_PRIME)) { + auto iter = table.find(strings[i]); + KJ_ASSERT(iter != table.end()); + StringPtr value = *iter; + KJ_ASSERT(value == strings[i]); + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + KJ_ASSERT(table.erase(strings[i]) > 0); + } + } + + for (uint i: kj::zeroTo(SOME_PRIME)) { + if (i % 2 == 0 || i % 7 == 0) { + // erased + KJ_ASSERT(table.find(strings[i]) == table.end()); + } else { + auto iter = table.find(strings[i]); + KJ_ASSERT(iter != table.end()); + StringPtr value = *iter; + KJ_ASSERT(value == strings[i]); + } + } + } +} + +// ======================================================================================= + +KJ_TEST("insertion order index") { + Table table; + + { + auto range = table.ordered(); + KJ_EXPECT(range.begin() == range.end()); + } + + table.insert(12); + table.insert(34); + table.insert(56); + table.insert(78); + + { + auto range = table.ordered(); + auto iter = range.begin(); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 12); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 34); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 56); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 78); + KJ_EXPECT(iter == range.end()); + KJ_EXPECT(*--iter == 78); + KJ_EXPECT(*--iter == 56); + KJ_EXPECT(*--iter == 34); + KJ_EXPECT(*--iter == 12); + KJ_EXPECT(iter == range.begin()); + } + + table.erase(table.begin()[1]); + + { + auto range = table.ordered(); + auto iter = range.begin(); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 12); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 56); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 78); + KJ_EXPECT(iter == range.end()); + KJ_EXPECT(*--iter == 78); + KJ_EXPECT(*--iter == 56); + KJ_EXPECT(*--iter == 12); + KJ_EXPECT(iter == range.begin()); + } + + // Allocate enough more elements to cause a resize. + table.insert(111); + table.insert(222); + table.insert(333); + table.insert(444); + table.insert(555); + table.insert(666); + table.insert(777); + table.insert(888); + table.insert(999); + + { + auto range = table.ordered(); + auto iter = range.begin(); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 12); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 56); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 78); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 111); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 222); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 333); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 444); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 555); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 666); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 777); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 888); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 999); + KJ_EXPECT(iter == range.end()); + } + + // Remove everything. + while (table.size() > 0) { + table.erase(*table.begin()); + } + + { + auto range = table.ordered(); + KJ_EXPECT(range.begin() == range.end()); + } +} + +KJ_TEST("insertion order index is movable") { + using UintTable = Table; + + kj::Maybe myTable; + + { + UintTable yourTable; + + yourTable.insert(12); + yourTable.insert(34); + yourTable.insert(56); + yourTable.insert(78); + yourTable.insert(111); + yourTable.insert(222); + yourTable.insert(333); + yourTable.insert(444); + yourTable.insert(555); + yourTable.insert(666); + yourTable.insert(777); + yourTable.insert(888); + yourTable.insert(999); + + myTable = kj::mv(yourTable); + } + + auto& table = KJ_ASSERT_NONNULL(myTable); + + // At one time the following induced a segfault/double-free, due to incorrect memory management in + // InsertionOrderIndex's move ctor and dtor. + auto range = table.ordered(); + auto iter = range.begin(); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 12); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 34); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 56); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 78); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 111); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 222); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 333); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 444); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 555); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 666); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 777); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 888); + KJ_ASSERT(iter != range.end()); + KJ_EXPECT(*iter++ == 999); + KJ_EXPECT(iter == range.end()); +} + +// ======================================================================================= +// Test bug where insertion failure on a later index in the table would not be rolled back +// correctly if a previous index was TreeIndex. + +class StringLengthCompare { + // Considers two strings equal if they have the same length. +public: + inline size_t keyForRow(StringPtr entry) const { + return entry.size(); + } + + inline bool matches(StringPtr e, size_t key) const { + return e.size() == key; + } + + inline bool isBefore(StringPtr e, size_t key) const { + return e.size() < key; + } + + uint hashCode(size_t size) const { + return size; + } +}; + +KJ_TEST("HashIndex rollback on insertion failure") { + // Test that when an insertion produces a duplicate on a later index, changes to previous indexes + // are properly rolled back. + + Table, HashIndex> table; + table.insert("a"_kj); + table.insert("ab"_kj); + table.insert("abc"_kj); + + { + // We use upsert() so that we don't throw an exception from the duplicate, but this exercises + // the same logic as a duplicate insert() other than throwing. + kj::StringPtr& found = table.upsert("xyz"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "abc"); + KJ_EXPECT(param == "xyz"); + }); + KJ_EXPECT(found == "abc"); + + table.erase(found); + } + + table.insert("xyz"_kj); + + { + kj::StringPtr& found = table.upsert("tuv"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "xyz"); + KJ_EXPECT(param == "tuv"); + }); + KJ_EXPECT(found == "xyz"); + } +} + +KJ_TEST("TreeIndex rollback on insertion failure") { + // Test that when an insertion produces a duplicate on a later index, changes to previous indexes + // are properly rolled back. + + Table, TreeIndex> table; + table.insert("a"_kj); + table.insert("ab"_kj); + table.insert("abc"_kj); + + { + // We use upsert() so that we don't throw an exception from the duplicate, but this exercises + // the same logic as a duplicate insert() other than throwing. + kj::StringPtr& found = table.upsert("xyz"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "abc"); + KJ_EXPECT(param == "xyz"); + }); + KJ_EXPECT(found == "abc"); + + table.erase(found); + } + + table.insert("xyz"_kj); + + { + kj::StringPtr& found = table.upsert("tuv"_kj, [&](StringPtr& existing, StringPtr&& param) { + KJ_EXPECT(existing == "xyz"); + KJ_EXPECT(param == "tuv"); + }); + KJ_EXPECT(found == "xyz"); + } +} + +} // namespace +} // namespace _ +} // namespace kj diff --git a/c++/src/kj/table.c++ b/c++/src/kj/table.c++ new file mode 100644 index 0000000000..4b0e028524 --- /dev/null +++ b/c++/src/kj/table.c++ @@ -0,0 +1,965 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "table.h" +#include "debug.h" +#include + +#if KJ_DEBUG_TABLE_IMPL +#undef KJ_DASSERT +#define KJ_DASSERT KJ_ASSERT +#endif + +namespace kj { +namespace _ { + +static inline uint lg(uint value) { + // Compute floor(log2(value)). + // + // Undefined for value = 0. +#if _MSC_VER && !defined(__clang__) + unsigned long i; + auto found = _BitScanReverse(&i, value); + KJ_DASSERT(found); // !found means value = 0 + return i; +#else + return sizeof(uint) * 8 - 1 - __builtin_clz(value); +#endif +} + +void throwDuplicateTableRow() { + KJ_FAIL_REQUIRE("inserted row already exists in table"); +} + +void logHashTableInconsistency() { + KJ_LOG(ERROR, + "HashIndex detected hash table inconsistency. This can happen if you create a kj::Table " + "with a hash index and you modify the rows in the table post-indexing in a way that would " + "change their hash. This is a serious bug which will lead to undefined behavior." + "\nstack: ", kj::getStackTrace()); +} + +// List of primes where each element is roughly double the previous. Obtained +// from: +// http://planetmath.org/goodhashtableprimes +// Primes < 53 were added to ensure that small tables don't allocate excessive memory. +static const size_t PRIMES[] = { + 1, // 2^ 0 = 1 + 3, // 2^ 1 = 2 + 5, // 2^ 2 = 4 + 11, // 2^ 3 = 8 + 23, // 2^ 4 = 16 + 53, // 2^ 5 = 32 + 97, // 2^ 6 = 64 + 193, // 2^ 7 = 128 + 389, // 2^ 8 = 256 + 769, // 2^ 9 = 512 + 1543, // 2^10 = 1024 + 3079, // 2^11 = 2048 + 6151, // 2^12 = 4096 + 12289, // 2^13 = 8192 + 24593, // 2^14 = 16384 + 49157, // 2^15 = 32768 + 98317, // 2^16 = 65536 + 196613, // 2^17 = 131072 + 393241, // 2^18 = 262144 + 786433, // 2^19 = 524288 + 1572869, // 2^20 = 1048576 + 3145739, // 2^21 = 2097152 + 6291469, // 2^22 = 4194304 + 12582917, // 2^23 = 8388608 + 25165843, // 2^24 = 16777216 + 50331653, // 2^25 = 33554432 + 100663319, // 2^26 = 67108864 + 201326611, // 2^27 = 134217728 + 402653189, // 2^28 = 268435456 + 805306457, // 2^29 = 536870912 + 1610612741, // 2^30 = 1073741824 +}; + +uint chooseBucket(uint hash, uint count) { + // Integer modulus is really, really slow. It turns out that the compiler can generate much + // faster code if the denominator is a constant. Since we have a fixed set of possible + // denominators, a big old switch() statement is a win. + + // TODO(perf): Consider using power-of-two bucket sizes. We can safely do so as long as we demand + // high-quality hash functions -- kj::hashCode() needs good diffusion even for integers, can't + // just be a cast. Also be sure to implement Robin Hood hashing to avoid extremely bad negative + // lookup time when elements have sequential hashes (otherwise, it could be necessary to scan + // the entire list to determine that an element isn't present). + + switch (count) { +#define HANDLE(i) case i##u: return hash % i##u + HANDLE( 1); + HANDLE( 3); + HANDLE( 5); + HANDLE( 11); + HANDLE( 23); + HANDLE( 53); + HANDLE( 97); + HANDLE( 193); + HANDLE( 389); + HANDLE( 769); + HANDLE( 1543); + HANDLE( 3079); + HANDLE( 6151); + HANDLE( 12289); + HANDLE( 24593); + HANDLE( 49157); + HANDLE( 98317); + HANDLE( 196613); + HANDLE( 393241); + HANDLE( 786433); + HANDLE( 1572869); + HANDLE( 3145739); + HANDLE( 6291469); + HANDLE( 12582917); + HANDLE( 25165843); + HANDLE( 50331653); + HANDLE( 100663319); + HANDLE( 201326611); + HANDLE( 402653189); + HANDLE( 805306457); + HANDLE(1610612741); +#undef HANDLE + default: return hash % count; + } +} + +size_t chooseHashTableSize(uint size) { + if (size == 0) return 0; + + // Add 1 to compensate for the floor() above, then look up the best prime bucket size for that + // target size. + return PRIMES[lg(size) + 1]; +} + +kj::Array rehash(kj::ArrayPtr oldBuckets, size_t targetSize) { + // Rehash the whole table. + + KJ_REQUIRE(targetSize < (1 << 30), "hash table has reached maximum size"); + + size_t size = chooseHashTableSize(targetSize); + + if (size < oldBuckets.size()) { + size = oldBuckets.size(); + } + + auto newBuckets = kj::heapArray(size); + memset(newBuckets.begin(), 0, sizeof(HashBucket) * size); + + uint entryCount = 0; + uint collisionCount = 0; + + for (auto& oldBucket: oldBuckets) { + if (oldBucket.isOccupied()) { + ++entryCount; + for (uint i = oldBucket.hash % newBuckets.size();; i = probeHash(newBuckets, i)) { + auto& newBucket = newBuckets[i]; + if (newBucket.isEmpty()) { + newBucket = oldBucket; + break; + } + ++collisionCount; + } + } + } + + if (collisionCount > 16 + entryCount * 4) { + static bool warned = false; + if (!warned) { + KJ_LOG(WARNING, "detected excessive collisions in hash table; is your hash function OK?", + entryCount, collisionCount, kj::getStackTrace()); + warned = true; + } + } + + return newBuckets; +} + +// ======================================================================================= +// BTree + +#if _WIN32 +#define aligned_free _aligned_free +#else +#define aligned_free ::free +#endif + +BTreeImpl::BTreeImpl() + : tree(const_cast(&EMPTY_NODE)), + treeCapacity(1), + height(0), + freelistHead(1), + freelistSize(0), + beginLeaf(0), + endLeaf(0) {} + +BTreeImpl::~BTreeImpl() noexcept(false) { + if (tree != &EMPTY_NODE) { + aligned_free(tree); + } +} + +BTreeImpl::BTreeImpl(BTreeImpl&& other) + : BTreeImpl() { + *this = kj::mv(other); +} + +BTreeImpl& BTreeImpl::operator=(BTreeImpl&& other) { + KJ_DASSERT(&other != this); + + if (tree != &EMPTY_NODE) { + aligned_free(tree); + } + tree = other.tree; + treeCapacity = other.treeCapacity; + height = other.height; + freelistHead = other.freelistHead; + freelistSize = other.freelistSize; + beginLeaf = other.beginLeaf; + endLeaf = other.endLeaf; + + other.tree = const_cast(&EMPTY_NODE); + other.treeCapacity = 1; + other.height = 0; + other.freelistHead = 1; + other.freelistSize = 0; + other.beginLeaf = 0; + other.endLeaf = 0; + + return *this; +} + +const BTreeImpl::NodeUnion BTreeImpl::EMPTY_NODE = {{{0, {0}}}}; + +void BTreeImpl::verify(size_t size, FunctionParam f) { + KJ_ASSERT(verifyNode(size, f, 0, height, nullptr) == size); +} +size_t BTreeImpl::verifyNode(size_t size, FunctionParam& f, + uint pos, uint height, MaybeUint maxRow) { + if (height > 0) { + auto& parent = tree[pos].parent; + + auto n = parent.keyCount(); + size_t total = 0; + for (auto i: kj::zeroTo(n)) { + KJ_ASSERT(*parent.keys[i] < size, n, i); + total += verifyNode(size, f, parent.children[i], height - 1, parent.keys[i]); + if (i > 0) { + KJ_ASSERT(f(*parent.keys[i - 1], *parent.keys[i]), + n, i, parent.keys[i - 1], parent.keys[i]); + } + } + total += verifyNode(size, f, parent.children[n], height - 1, maxRow); + if (maxRow != nullptr) { + KJ_ASSERT(f(*parent.keys[n-1], *maxRow), n, parent.keys[n-1], maxRow); + } + return total; + } else { + auto& leaf = tree[pos].leaf; + auto n = leaf.size(); + for (auto i: kj::zeroTo(n)) { + KJ_ASSERT(*leaf.rows[i] < size, n, i); + if (i > 0) { + KJ_ASSERT(f(*leaf.rows[i - 1], *leaf.rows[i]), + n, i, leaf.rows[i - 1], leaf.rows[i]); + } + } + if (maxRow != nullptr) { + KJ_ASSERT(leaf.rows[n-1] == maxRow, n); + } + return n; + } +} + +kj::String BTreeImpl::MaybeUint::toString() const { + return i == 0 ? kj::str("(null)") : kj::str(i - 1); +} + +void BTreeImpl::logInconsistency() const { + KJ_LOG(ERROR, + "BTreeIndex detected tree state inconsistency. This can happen if you create a kj::Table " + "with a b-tree index and you modify the rows in the table post-indexing in a way that would " + "change their ordering. This is a serious bug which will lead to undefined behavior." + "\nstack: ", kj::getStackTrace()); +} + +void BTreeImpl::reserve(size_t size) { + KJ_REQUIRE(size < (1u << 31), "b-tree has reached maximum size"); + + // Calculate the worst-case number of leaves to cover the size, given that a leaf is always at + // least half-full. (Note that it's correct for this calculation to round down, not up: The + // remainder will necessarily be distributed among the non-full leaves, rather than creating a + // new leaf, because if it went into a new leaf, that leaf would be less than half-full.) + uint leaves = size / (Leaf::NROWS / 2); + + // Calculate the worst-case number of parents to cover the leaves, given that a parent is always + // at least half-full. Since the parents form a tree with branching factor B, the size of the + // tree is N/B + N/B^2 + N/B^3 + N/B^4 + ... = N / (B - 1). Math. + constexpr uint branchingFactor = Parent::NCHILDREN / 2; + uint parents = leaves / (branchingFactor - 1); + + // Height is log-base-branching-factor of leaves, plus 1 for the root node. + uint height = lg(leaves | 1) / lg(branchingFactor) + 1; + + size_t newSize = leaves + + parents + 1 + // + 1 for the root + height + 2; // minimum freelist size needed by insert() + + if (treeCapacity < newSize) { + growTree(newSize); + } +} + +void BTreeImpl::clear() { + if (tree != &EMPTY_NODE) { + azero(tree, treeCapacity); + height = 0; + freelistHead = 1; + freelistSize = treeCapacity - 1; // subtract one to account for the root node + beginLeaf = 0; + endLeaf = 0; + } +} + +void BTreeImpl::growTree(uint minCapacity) { + uint newCapacity = kj::max(kj::max(minCapacity, treeCapacity * 2), 4); + freelistSize += newCapacity - treeCapacity; + + // Allocate some aligned memory! In theory this should be as simple as calling the C11 standard + // aligned_alloc() function. Unfortunately, many platforms don't implement it. Luckily, there + // are usually alternatives. + +#if _WIN32 + // Windows lacks aligned_alloc() but has its own _aligned_malloc() (which requires freeing using + // _aligned_free()). + // WATCH OUT: The argument order for _aligned_malloc() is opposite of aligned_alloc()! + NodeUnion* newTree = reinterpret_cast( + _aligned_malloc(newCapacity * sizeof(BTreeImpl::NodeUnion), sizeof(BTreeImpl::NodeUnion))); + KJ_ASSERT(newTree != nullptr, "memory allocation failed", newCapacity); +#else + // macOS, OpenBSD, and Android lack aligned_alloc(), but have posix_memalign(). Fine. + void* allocPtr; + int error = posix_memalign(&allocPtr, + sizeof(BTreeImpl::NodeUnion), newCapacity * sizeof(BTreeImpl::NodeUnion)); + if (error != 0) { + KJ_FAIL_SYSCALL("posix_memalign", error); + } + NodeUnion* newTree = reinterpret_cast(allocPtr); +#endif + + // Note: C11 introduces aligned_alloc() as a standard, but it's still missing on many platforms, + // so we don't use it. But if you wanted to use it, you'd do this: +// NodeUnion* newTree = reinterpret_cast( +// aligned_alloc(sizeof(BTreeImpl::NodeUnion), newCapacity * sizeof(BTreeImpl::NodeUnion))); +// KJ_ASSERT(newTree != nullptr, "memory allocation failed", newCapacity); + + acopy(newTree, tree, treeCapacity); + azero(newTree + treeCapacity, newCapacity - treeCapacity); + if (tree != &EMPTY_NODE) aligned_free(tree); + tree = newTree; + treeCapacity = newCapacity; +} + +BTreeImpl::Iterator BTreeImpl::search(const SearchKey& searchKey) const { + // Find the "first" row number (in sorted order) for which searchKey.isAfter(rowNumber) returns + // false. + + uint pos = 0; + + for (auto i KJ_UNUSED: zeroTo(height)) { + auto& parent = tree[pos].parent; + pos = parent.children[searchKey.search(parent)]; + } + + auto& leaf = tree[pos].leaf; + return { tree, &leaf, searchKey.search(leaf) }; +} + +template +struct BTreeImpl::AllocResult { + uint index; + T& node; +}; + +template +inline BTreeImpl::AllocResult BTreeImpl::alloc() { + // Allocate a new item from the freelist. Guaranteed to be zero'd except for the first member. + uint i = freelistHead; + NodeUnion* ptr = &tree[i]; + freelistHead = i + 1 + ptr->freelist.nextOffset; + --freelistSize; + return { i, *ptr }; +} + +inline void BTreeImpl::free(uint pos) { + // Add the given node to the freelist. + + // HACK: This is typically called on a node immediately after copying its contents away, but the + // pointer used to copy it away may be a different pointer pointing to a different union member + // which the compiler may not recgonize as aliasing with this object. Just to be extra-safe, + // insert a compiler barrier. + compilerBarrier(); + + auto& node = tree[pos]; + node.freelist.nextOffset = freelistHead - pos - 1; + azero(node.freelist.zero, kj::size(node.freelist.zero)); + freelistHead = pos; + ++freelistSize; +} + +BTreeImpl::Iterator BTreeImpl::insert(const SearchKey& searchKey) { + // Like search() but ensures that there is room in the leaf node to insert a new row. + + // If we split the root node it will generate two new nodes. If we split any other node in the + // path it will generate one new node. `height` doesn't count leaf nodes, but we can equivalently + // think of it as not counting the root node, so in the worst case we may allocate height + 2 + // new nodes. + // + // (Also note that if the tree is currently empty, then `tree` points to a dummy root node in + // read-only memory. We definitely need to allocate a real tree node array in this case, and + // we'll start out allocating space for four nodes, which will be all we need up to 28 rows.) + if (freelistSize < height + 2) { + if (height > 0 && !tree[0].parent.isFull() && freelistSize >= height) { + // Slight optimization: The root node is not full, so we're definitely not going to split it. + // That means that the maximum allocations we might do is equal to `height`, not + // `height + 2`, and we have that much space, so no need to grow yet. + // + // This optimization is particularly important for small trees, e.g. when treeCapacity is 4 + // and the tree so far consists of a root and two children, we definitely don't need to grow + // the tree yet. + } else { + growTree(); + + if (freelistHead == 0) { + // We have no root yet. Allocate one. + KJ_ASSERT(alloc().index == 0); + } + } + } + + uint pos = 0; + + // Track grandparent node and child index within grandparent. + Parent* parent = nullptr; + uint indexInParent = 0; + + for (auto i KJ_UNUSED: zeroTo(height)) { + Parent& node = insertHelper(searchKey, tree[pos].parent, parent, indexInParent, pos); + + parent = &node; + indexInParent = searchKey.search(node); + pos = node.children[indexInParent]; + } + + Leaf& leaf = insertHelper(searchKey, tree[pos].leaf, parent, indexInParent, pos); + + // Fun fact: Unlike erase(), there's no need to climb back up the tree modifying keys, because + // either the newly-inserted node will not be the last in the leaf (and thus parent keys aren't + // modified), or the leaf is the last leaf in the tree (and thus there's no parent key to + // modify). + + return { tree, &leaf, searchKey.search(leaf) }; +} + +template +Node& BTreeImpl::insertHelper(const SearchKey& searchKey, + Node& node, Parent* parent, uint indexInParent, uint pos) { + if (node.isFull()) { + // This node is full. Need to split. + if (parent == nullptr) { + // This is the root node. We need to split into two nodes and create a new root. + auto n1 = alloc(); + auto n2 = alloc(); + + uint pivot = split(n2.node, n2.index, node, pos); + move(n1.node, n1.index, node); + + // Rewrite root to have the two children. + tree[0].parent.initRoot(pivot, n1.index, n2.index); + + // Increased height. + ++height; + + // Decide which new branch has our search key. + if (searchKey.isAfter(pivot)) { + // the right one + return n2.node; + } else { + // the left one + return n1.node; + } + } else { + // This is a non-root parent node. We need to split it into two and insert the new node + // into the grandparent. + auto n = alloc(); + uint pivot = split(n.node, n.index, node, pos); + + // Insert new child into grandparent. + parent->insertAfter(indexInParent, pivot, n.index); + + // Decide which new branch has our search key. + if (searchKey.isAfter(pivot)) { + // the new one, which is right of the original + return n.node; + } else { + // the original one, which is left of the new one + return node; + } + } + } else { + // No split needed. + return node; + } +} + +void BTreeImpl::erase(uint row, const SearchKey& searchKey) { + // Erase the given row number from the tree. predicate() returns true for the given row and all + // rows after it. + + uint pos = 0; + + // Track grandparent node and child index within grandparent. + Parent* parent = nullptr; + uint indexInParent = 0; + + MaybeUint* fixup = nullptr; + + for (auto i KJ_UNUSED: zeroTo(height)) { + Parent& node = eraseHelper(tree[pos].parent, parent, indexInParent, pos, fixup); + + parent = &node; + indexInParent = searchKey.search(node); + pos = node.children[indexInParent]; + + if (indexInParent < kj::size(node.keys) && node.keys[indexInParent] == row) { + // Oh look, the row is a key in this node! We'll need to come back and fix this up later. + // Note that any particular row can only appear as *one* key value anywhere in the tree, so + // we only need one fixup pointer, which is nice. + MaybeUint* newFixup = &node.keys[indexInParent]; + if (fixup == newFixup) { + // The fixup pointer was already set while processing a parent node, and then a merge or + // rotate caused it to be moved, but the fixup pointer was updated... so it's already set + // to point at the slot we wanted it to point to, so nothing to see here. + } else { + KJ_DASSERT(fixup == nullptr); + fixup = newFixup; + } + } + } + + Leaf& leaf = eraseHelper(tree[pos].leaf, parent, indexInParent, pos, fixup); + + uint r = searchKey.search(leaf); + if (leaf.rows[r] == row) { + leaf.erase(r); + + if (fixup != nullptr) { + // There's a key in a parent node that needs fixup. This is only possible if the removed + // node is the last in its leaf. + KJ_DASSERT(leaf.rows[r] == nullptr); + KJ_DASSERT(r > 0); // non-root nodes must be at least half full so this can't be item 0 + KJ_DASSERT(*fixup == row); + *fixup = leaf.rows[r - 1]; + } + } else { + logInconsistency(); + } +} + +template +Node& BTreeImpl::eraseHelper( + Node& node, Parent* parent, uint indexInParent, uint pos, MaybeUint*& fixup) { + if (parent != nullptr && !node.isMostlyFull()) { + // This is not the root, but it's only half-full. Rebalance. + KJ_DASSERT(node.isHalfFull()); + + if (indexInParent > 0) { + // There's a sibling to the left. + uint sibPos = parent->children[indexInParent - 1]; + Node& sib = tree[sibPos]; + if (sib.isMostlyFull()) { + // Left sibling is more than half full. Steal one member. + rotateRight(sib, node, *parent, indexInParent - 1); + return node; + } else { + // Left sibling is half full, too. Merge. + KJ_ASSERT(sib.isHalfFull()); + merge(sib, sibPos, *parent->keys[indexInParent - 1], node); + parent->eraseAfter(indexInParent - 1); + free(pos); + if (fixup == &parent->keys[indexInParent]) --fixup; + + if (parent->keys[0] == nullptr) { + // Oh hah, the parent has no keys left. It must be the root. We can eliminate it. + KJ_DASSERT(parent == &tree->parent); + compilerBarrier(); // don't reorder any writes to parent below here + move(tree[0], 0, sib); + free(sibPos); + --height; + return tree[0]; + } else { + return sib; + } + } + } else if (indexInParent < Parent::NKEYS && parent->keys[indexInParent] != nullptr) { + // There's a sibling to the right. + uint sibPos = parent->children[indexInParent + 1]; + Node& sib = tree[sibPos]; + if (sib.isMostlyFull()) { + // Right sibling is more than half full. Steal one member. + rotateLeft(node, sib, *parent, indexInParent, fixup); + return node; + } else { + // Right sibling is half full, too. Merge. + KJ_ASSERT(sib.isHalfFull()); + merge(node, pos, *parent->keys[indexInParent], sib); + parent->eraseAfter(indexInParent); + free(sibPos); + if (fixup == &parent->keys[indexInParent]) fixup = nullptr; + + if (parent->keys[0] == nullptr) { + // Oh hah, the parent has no keys left. It must be the root. We can eliminate it. + KJ_DASSERT(parent == &tree->parent); + compilerBarrier(); // don't reorder any writes to parent below here + move(tree[0], 0, node); + free(pos); + --height; + return tree[0]; + } else { + return node; + } + } + } else { + KJ_FAIL_ASSERT("inconsistent b-tree"); + } + } + + return node; +} + +void BTreeImpl::renumber(uint oldRow, uint newRow, const SearchKey& searchKey) { + // Renumber the given row from oldRow to newRow. predicate() returns true for oldRow and all + // rows after it. (It will not be called on newRow.) + + uint pos = 0; + + for (auto i KJ_UNUSED: zeroTo(height)) { + auto& node = tree[pos].parent; + uint indexInParent = searchKey.search(node); + pos = node.children[indexInParent]; + if (indexInParent < kj::size(node.keys) && node.keys[indexInParent] == oldRow) { + node.keys[indexInParent] = newRow; + } + KJ_DASSERT(pos != 0); + } + + auto& leaf = tree[pos].leaf; + uint r = searchKey.search(leaf); + if (leaf.rows[r] == oldRow) { + leaf.rows[r] = newRow; + } else { + logInconsistency(); + } +} + +uint BTreeImpl::split(Parent& dst, uint dstPos, Parent& src, uint srcPos) { + constexpr size_t mid = Parent::NKEYS / 2; + uint pivot = *src.keys[mid]; + acopy(dst.keys, src.keys + mid + 1, Parent::NKEYS - mid - 1); + azero(src.keys + mid, Parent::NKEYS - mid); + acopy(dst.children, src.children + mid + 1, Parent::NCHILDREN - mid - 1); + azero(src.children + mid + 1, Parent::NCHILDREN - mid - 1); + return pivot; +} + +uint BTreeImpl::split(Leaf& dst, uint dstPos, Leaf& src, uint srcPos) { + constexpr size_t mid = Leaf::NROWS / 2; + uint pivot = *src.rows[mid - 1]; + acopy(dst.rows, src.rows + mid, Leaf::NROWS - mid); + azero(src.rows + mid, Leaf::NROWS - mid); + + if (src.next == 0) { + endLeaf = dstPos; + } else { + tree[src.next].leaf.prev = dstPos; + } + dst.next = src.next; + dst.prev = srcPos; + src.next = dstPos; + + return pivot; +} + +void BTreeImpl::merge(Parent& dst, uint dstPos, uint pivot, Parent& src) { + // merge() is only legal if both nodes are half-empty. Meanwhile, B-tree invariants + // guarantee that the node can't be more than half-empty, or we would have merged it sooner. + // (The root can be more than half-empty, but it is never merged with anything.) + KJ_DASSERT(src.isHalfFull()); + KJ_DASSERT(dst.isHalfFull()); + + constexpr size_t mid = Parent::NKEYS/2; + dst.keys[mid] = pivot; + acopy(dst.keys + mid + 1, src.keys, mid); + acopy(dst.children + mid + 1, src.children, mid + 1); +} + +void BTreeImpl::merge(Leaf& dst, uint dstPos, uint pivot, Leaf& src) { + // merge() is only legal if both nodes are half-empty. Meanwhile, B-tree invariants + // guarantee that the node can't be more than half-empty, or we would have merged it sooner. + // (The root can be more than half-empty, but it is never merged with anything.) + KJ_DASSERT(src.isHalfFull()); + KJ_DASSERT(dst.isHalfFull()); + + constexpr size_t mid = Leaf::NROWS/2; + KJ_DASSERT(dst.rows[mid-1] == pivot); + acopy(dst.rows + mid, src.rows, mid); + + dst.next = src.next; + if (dst.next == 0) { + endLeaf = dstPos; + } else { + tree[dst.next].leaf.prev = dstPos; + } +} + +void BTreeImpl::move(Parent& dst, uint dstPos, Parent& src) { + dst = src; +} + +void BTreeImpl::move(Leaf& dst, uint dstPos, Leaf& src) { + dst = src; + if (src.next == 0) { + endLeaf = dstPos; + } else { + tree[src.next].leaf.prev = dstPos; + } + if (src.prev == 0) { + beginLeaf = dstPos; + } else { + tree[src.prev].leaf.next = dstPos; + } +} + +void BTreeImpl::rotateLeft( + Parent& left, Parent& right, Parent& parent, uint indexInParent, MaybeUint*& fixup) { + // Steal one item from the right node and move it to the left node. + + // Like merge(), this is only called on an exactly-half-empty node. + KJ_DASSERT(left.isHalfFull()); + KJ_DASSERT(right.isMostlyFull()); + + constexpr size_t mid = Parent::NKEYS/2; + left.keys[mid] = parent.keys[indexInParent]; + if (fixup == &parent.keys[indexInParent]) fixup = &left.keys[mid]; + parent.keys[indexInParent] = right.keys[0]; + left.children[mid + 1] = right.children[0]; + amove(right.keys, right.keys + 1, Parent::NKEYS - 1); + right.keys[Parent::NKEYS - 1] = nullptr; + amove(right.children, right.children + 1, Parent::NCHILDREN - 1); + right.children[Parent::NCHILDREN - 1] = 0; +} + +void BTreeImpl::rotateLeft( + Leaf& left, Leaf& right, Parent& parent, uint indexInParent, MaybeUint*& fixup) { + // Steal one item from the right node and move it to the left node. + + // Like merge(), this is only called on an exactly-half-empty node. + KJ_DASSERT(left.isHalfFull()); + KJ_DASSERT(right.isMostlyFull()); + + constexpr size_t mid = Leaf::NROWS/2; + parent.keys[indexInParent] = left.rows[mid] = right.rows[0]; + if (fixup == &parent.keys[indexInParent]) fixup = nullptr; + amove(right.rows, right.rows + 1, Leaf::NROWS - 1); + right.rows[Leaf::NROWS - 1] = nullptr; +} + +void BTreeImpl::rotateRight(Parent& left, Parent& right, Parent& parent, uint indexInParent) { + // Steal one item from the left node and move it to the right node. + + // Like merge(), this is only called on an exactly-half-empty node. + KJ_DASSERT(right.isHalfFull()); + KJ_DASSERT(left.isMostlyFull()); + + constexpr size_t mid = Parent::NKEYS/2; + amove(right.keys + 1, right.keys, mid); + amove(right.children + 1, right.children, mid + 1); + + uint back = left.keyCount() - 1; + + right.keys[0] = parent.keys[indexInParent]; + parent.keys[indexInParent] = left.keys[back]; + right.children[0] = left.children[back + 1]; + left.keys[back] = nullptr; + left.children[back + 1] = 0; +} + +void BTreeImpl::rotateRight(Leaf& left, Leaf& right, Parent& parent, uint indexInParent) { + // Steal one item from the left node and move it to the right node. + + // Like mergeFrom(), this is only called on an exactly-half-empty node. + KJ_DASSERT(right.isHalfFull()); + KJ_DASSERT(left.isMostlyFull()); + + constexpr size_t mid = Leaf::NROWS/2; + amove(right.rows + 1, right.rows, mid); + + uint back = left.size() - 1; + + right.rows[0] = left.rows[back]; + parent.keys[indexInParent] = left.rows[back - 1]; + left.rows[back] = nullptr; +} + +void BTreeImpl::Parent::initRoot(uint key, uint leftChild, uint rightChild) { + // HACK: This is typically called on the root node immediately after copying its contents away, + // but the pointer used to copy it away may be a different pointer pointing to a different + // union member which the compiler may not recgonize as aliasing with this object. Just to + // be extra-safe, insert a compiler barrier. + compilerBarrier(); + + keys[0] = key; + children[0] = leftChild; + children[1] = rightChild; + azero(keys + 1, Parent::NKEYS - 1); + azero(children + 2, Parent::NCHILDREN - 2); +} + +void BTreeImpl::Parent::insertAfter(uint i, uint splitKey, uint child) { + KJ_IREQUIRE(children[Parent::NCHILDREN - 1] == 0); // check not full + + amove(keys + i + 1, keys + i, Parent::NKEYS - (i + 1)); + keys[i] = splitKey; + + amove(children + i + 2, children + i + 1, Parent::NCHILDREN - (i + 2)); + children[i + 1] = child; +} + +void BTreeImpl::Parent::eraseAfter(uint i) { + amove(keys + i, keys + i + 1, Parent::NKEYS - (i + 1)); + keys[Parent::NKEYS - 1] = nullptr; + amove(children + i + 1, children + i + 2, Parent::NCHILDREN - (i + 2)); + children[Parent::NCHILDREN - 1] = 0; +} + +} // namespace _ + +// ======================================================================================= +// Insertion order + +const InsertionOrderIndex::Link InsertionOrderIndex::EMPTY_LINK = { 0, 0 }; + +InsertionOrderIndex::InsertionOrderIndex(): capacity(0), links(const_cast(&EMPTY_LINK)) {} +InsertionOrderIndex::InsertionOrderIndex(InsertionOrderIndex&& other) + : capacity(other.capacity), links(other.links) { + other.capacity = 0; + other.links = const_cast(&EMPTY_LINK); +} +InsertionOrderIndex& InsertionOrderIndex::operator=(InsertionOrderIndex&& other) { + KJ_DASSERT(&other != this); + capacity = other.capacity; + links = other.links; + other.capacity = 0; + other.links = const_cast(&EMPTY_LINK); + return *this; +} +InsertionOrderIndex::~InsertionOrderIndex() noexcept(false) { + if (links != &EMPTY_LINK) delete[] links; +} + +void InsertionOrderIndex::reserve(size_t size) { + KJ_ASSERT(size < (1u << 31), "Table too big for InsertionOrderIndex"); + + if (size > capacity) { + // Need to grow. + // Note that `size` and `capacity` do not include the special link[0]. + + // Round up to the next power of 2. + size_t allocation = 1u << (_::lg(size) + 1); + KJ_DASSERT(allocation > size); + KJ_DASSERT(allocation <= size * 2); + + // Round first allocation up to 8. + allocation = kj::max(allocation, 8); + + Link* newLinks = new Link[allocation]; +#ifdef KJ_DEBUG + // To catch bugs, fill unused links with 0xff. + memset(newLinks, 0xff, allocation * sizeof(Link)); +#endif + _::acopy(newLinks, links, capacity + 1); + if (links != &EMPTY_LINK) delete[] links; + links = newLinks; + capacity = allocation - 1; + } +} + +void InsertionOrderIndex::clear() { + links[0] = Link { 0, 0 }; + +#ifdef KJ_DEBUG + // To catch bugs, fill unused links with 0xff. + memset(links + 1, 0xff, capacity * sizeof(Link)); +#endif +} + +kj::Maybe InsertionOrderIndex::insertImpl(size_t pos) { + if (pos >= capacity) { + reserve(pos + 1); + } + + links[pos + 1].prev = links[0].prev; + links[pos + 1].next = 0; + links[links[0].prev].next = pos + 1; + links[0].prev = pos + 1; + + return nullptr; +} + +void InsertionOrderIndex::eraseImpl(size_t pos) { + Link& link = links[pos + 1]; + links[link.next].prev = link.prev; + links[link.prev].next = link.next; + +#ifdef KJ_DEBUG + memset(&link, 0xff, sizeof(Link)); +#endif +} + +void InsertionOrderIndex::moveImpl(size_t oldPos, size_t newPos) { + Link& link = links[oldPos + 1]; + Link& newLink = links[newPos + 1]; + + newLink = link; + + KJ_DASSERT(links[link.next].prev == oldPos + 1); + KJ_DASSERT(links[link.prev].next == oldPos + 1); + links[link.next].prev = newPos + 1; + links[link.prev].next = newPos + 1; + +#ifdef KJ_DEBUG + memset(&link, 0xff, sizeof(Link)); +#endif +} + +} // namespace kj diff --git a/c++/src/kj/table.h b/c++/src/kj/table.h new file mode 100644 index 0000000000..d5d1b41371 --- /dev/null +++ b/c++/src/kj/table.h @@ -0,0 +1,1652 @@ +// Copyright (c) 2018 Kenton Varda and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include "common.h" +#include "tuple.h" +#include "vector.h" +#include "function.h" + +#if _MSC_VER +// Need _ReadWriteBarrier +#if _MSC_VER < 1910 +#include +#else +#include +#endif +#endif + +#if KJ_DEBUG_TABLE_IMPL +#include "debug.h" +#define KJ_TABLE_IREQUIRE KJ_REQUIRE +#define KJ_TABLE_IASSERT KJ_ASSERT +#else +#define KJ_TABLE_IREQUIRE KJ_IREQUIRE +#define KJ_TABLE_IASSERT KJ_IASSERT +#endif + +KJ_BEGIN_HEADER + +namespace kj { + +class String; + +namespace _ { // private + +template +class TableMapping; +template +using TableIterable = MappedIterable>; +template +using TableIterator = MappedIterator>; + +} // namespace _ (private) + +template +class Table { + // A table with one or more indexes. This is the KJ alternative to map, set, unordered_map, and + // unordered_set. + // + // Unlike a traditional map, which explicitly stores key/value pairs, a Table simply stores + // "rows" of arbitrary type, and then lets the application specify how these should be indexed. + // Rows could be indexed on a specific struct field, or they could be indexed based on a computed + // property. An index could be hash-based or tree-based. Multiple indexes are supported, making + // it easy to construct a "bimap". + // + // The table has deterministic iteration order based on the sequence of insertions and deletions. + // In the case of only insertions, the iteration order is the order of insertion. If deletions + // occur, then the current last row is moved to occupy the deleted slot. This determinism is + // intended to be reliable for the purpose of testing, etc. + // + // Each index is a class that looks like: + // + // class Index { + // public: + // void reserve(size_t size); + // // Called when Table::reserve() is called. + // + // SearchParam& keyForRow(const Row& row) const; + // // Given a row, return a value appropriate to pass as SearchParams to the other functions. + // + // // In all function calls below, `SearchPrams` refers to whatever parameters the index + // // supports for looking up a row in the table. + // + // template + // kj::Maybe insert(kj::ArrayPtr table, size_t pos, SearchParams&&...); + // // Called to indicate that we're about to insert a new row which will match the given + // // search parameters, and will be located at the given position. If this index disallows + // // duplicates and some other matching row already exists, then insert() returns the index + // // of that row without modifying the index. If the row does not exist, then insert() + // // updates the index to note that the new row is located at `pos`. Note that `table[pos]` + // // may not be valid yet at the time of this call; the index must go on the search params + // // alone. + // // + // // Insert may throw an exception, in which case the table will roll back insertion. + // + // template + // void erase(kj::ArrayPtr table, size_t pos, SearchParams&&...); + // // Called to indicate that the index must remove references to row number `pos`. The + // // index must not attempt to access table[pos] directly -- in fact, `pos` may be equal to + // // `table.size()`, i.e., may be out-of-bounds (this happens when rolling back a failed + // // insertion). Instead, the index can use the search params to search for the row -- they + // // will either be the same as the params passed to insert(), or will be a single value of + // // type `Row&`. + // // + // // erase() called immediately after a successful insert() must not throw an exception, as + // // it may be called during unwind. + // + // template + // void move(kj::ArrayPtr table, size_t oldPos, size_t newPos, SearchParams&&...); + // // Called when a row is about to be moved from `oldPos` to `newPos` in the table. The + // // index should update it to the new location. Neither `table[oldPos]` nor `table[newPos]` + // // is valid during the call -- use the search params to find the row. Before this call + // // `oldPos` is indexed and `newPos` is not -- after the call, the opposite is true. + // // + // // This should never throw; if it does the table may be corrupted. + // + // class Iterator; // Behaves like a C++ iterator over size_t values. + // class Iterable; // Has begin() and end() methods returning iterators. + // + // template + // Maybe find(kj::ArrayPtr table, SearchParams&&...) const; + // // Optional. Implements Table::find(...). + // + // template + // Iterator seek(kj::ArrayPtr table, SearchParams&&...) const; + // // Optional. Implements Table::seek() and Table::range(...). + // + // Iterator begin() const; + // Iterator end() const; + // // Optional. Implements Table::ordered(). + // }; + +public: + Table(); + Table(Indexes&&... indexes); + + void reserve(size_t size); + // Pre-allocates space for a table of the given size. Normally a Table grows by re-allocating + // its backing array whenever more space is needed. Reserving in advance avoids redundantly + // re-allocating as the table grows. + + size_t size() const; + size_t capacity() const; + + void clear(); + + Row* begin(); + Row* end(); + const Row* begin() const; + const Row* end() const; + + Row& insert(Row&& row); + Row& insert(const Row& row); + // Inserts a new row. Throws an exception if this would violate the uniqueness constraints of any + // of the indexes. + + template + void insertAll(Collection&& collection); + template + void insertAll(Collection& collection); + // Given an iterable collection of Rows, inserts all of them into this table. If the input is + // an rvalue, the rows will be moved rather than copied. + // + // If an insertion throws (e.g. because it violates a uniqueness constraint of some index), + // subsequent insertions do not occur, but previous insertions remain inserted. + + template + Row& upsert(Row&& row, UpdateFunc&& update); + template + Row& upsert(const Row& row, UpdateFunc&& update); + // Tries to insert a new row. However, if a duplicate already exists (according to some index), + // then update(Row& existingRow, Row&& newRow) is called to modify the existing row. + + template + kj::Maybe find(Params&&... params); + template + kj::Maybe find(Params&&... params) const; + // Using the given index, search for a matching row. What parameters are accepted depends on the + // index. Not all indexes support this method -- "multimap" indexes may support only range(). + + template + Row& findOrCreate(Params&&... params, Func&& createFunc); + // Like find(), but if the row doesn't exist, call a function to create it. createFunc() must + // return `Row` or something that implicitly converts to `Row`. + // + // NOTE: C++ doesn't actually properly support inferring types of a parameter pack at the + // beginning of an argument list, but we define a hack to support it below. Don't worry about + // it. + + template + auto range(BeginKey&& begin, EndKey&& end); + template + auto range(BeginKey&& begin, EndKey&& end) const; + // Using the given index, look up a range of values, returning an iterable. What parameters are + // accepted depends on the index. Not all indexes support this method (in particular, unordered + // indexes normally don't). + + template + _::TableIterable ordered(); + template + _::TableIterable ordered() const; + // Returns an iterable over the whole table ordered using the given index. Not all indexes + // support this method. + + template + auto seek(Params&&... params); + template + auto seek(Params&&... params) const; + // Takes same parameters as find(), but returns an iterator at the position where the search + // key should go. That is, this returns an iterator that points to the matching entry or, if + // there is no matching entry, points at the next entry after the key, in order. Or, if there + // is no such entry, the returned iterator is the same as ordered().end(). + // + // seek() is only supported by indexes that support ordered(). It returns the same kind of + // iterator that ordered() uses. + + template + bool eraseMatch(Params&&... params); + // Erase the row that would be matched by `find(params)`. Returns true if there was a + // match. + + template + size_t eraseRange(BeginKey&& begin, EndKey&& end); + // Erase the row that would be matched by `range(params)`. Returns the number of + // elements erased. + + void erase(Row& row); + // Erase the given row. + // + // WARNING: This invalidates all iterators, so you can't iterate over rows and erase them this + // way. Use `eraseAll()` for that. + + Row release(Row& row); + // Remove the given row from the table and return it in one operation. + // + // WARNING: This invalidates all iterators, so you can't iterate over rows and release them this + // way. + + template ()(instance()))> + size_t eraseAll(Predicate&& predicate); + // Erase all rows for which predicate(row) returns true. This scans over the entire table. + + template ().begin()), bool = true> + size_t eraseAll(Collection&& collection); + // Erase all rows in the given iterable collection of rows. This carefully marks rows for + // deletion in a first pass then deletes them in a second. + + template + kj::Maybe find(Params&&... params); + template + kj::Maybe find(Params&&... params) const; + template + Row& findOrCreate(Params&&... params, Func&& createFunc); + template + auto range(BeginKey&& begin, EndKey&& end); + template + auto range(BeginKey&& begin, EndKey&& end) const; + template + _::TableIterable>&> ordered(); + template + _::TableIterable>&> ordered() const; + template + auto seek(Params&&... params); + template + auto seek(Params&&... params) const; + template + bool eraseMatch(Params&&... params); + template + size_t eraseRange(BeginKey&& begin, EndKey&& end); + // Methods which take an index type as a template parameter can also take an index number. This + // is useful particularly when you have multiple indexes of the same type but different runtime + // properties. Additionally, you can omit the template parameter altogether to use the first + // index. + + template + void verify(); + // Checks the integrity of indexes, throwing an exception if there are any problems. This is + // intended to be called within the unit test for an index. + + template + Row& findOrCreate(First&& first, Rest&&... rest); + template + Row& findOrCreate(First&& first, Rest&&... rest); + // HACK: A parameter pack can only be inferred if it lives at the end of the argument list, so + // the findOrCreate() definitions from earlier won't actually work. These ones will, but we + // have to do some annoying things inside to regroup the arguments. + +private: + Vector rows; + Tuple indexes; + + template = sizeof...(Indexes))> + class Impl; + template + class FindOrCreateImpl; + + template + struct FindOrCreateHack; + + void eraseImpl(size_t pos); + template + size_t eraseAllImpl(Collection&& collection); +}; + +template +class HashIndex; +// A Table index based on a hash table. +// +// This implementation: +// * Is based on linear probing, not chaining. It is important to use a high-quality hash function. +// Use the KJ hashing library if possible. +// * Is limited to tables of 2^30 rows or less, mainly to allow for tighter packing with 32-bit +// integers instead of 64-bit. +// * Caches hash codes so that each table row need only be hashed once, and never checks equality +// unless hash codes have already been determined to be equal. +// +// The `Callbacks` type defines how to compute hash codes and equality. It should be defined like: +// +// class Callbacks { +// public: +// // In this interface, `SearchParams...` means whatever parameters you want to support in +// // a call to table.find(...). By overloading the calls to support various inputs, you can +// // affect what table.find(...) accepts. +// +// SearchParam& keyForRow(const Row& row); +// // Given a row of the table, return the SearchParams that might be passed to the other +// // methods to match this row. +// +// bool matches(const Row&, SearchParams&&...) const; +// // Returns true if the row on the left matches the search params on the right. +// +// uint hashCode(SearchParams&&...) const; +// // Computes the hash code of the given search params. Matching rows (as determined by +// // matches()) must have the same hash code. Non-matching rows should have different hash +// // codes, to the maximum extent possible. Non-matching rows with the same hash code hurt +// // performance. +// }; +// +// If your `Callbacks` type has dynamic state, you may pass its constructor parameters as the +// constructor parameters to `HashIndex`. + +template +class TreeIndex; +// A Table index based on a B-tree. +// +// This allows sorted iteration over rows. +// +// The `Callbacks` type defines how to compare rows. It should be defined like: +// +// class Callbacks { +// public: +// // In this interface, `SearchParams...` means whatever parameters you want to support in +// // a call to table.find(...). By overloading the calls to support various inputs, you can +// // affect what table.find(...) accepts. +// +// SearchParam& keyForRow(const Row& row); +// // Given a row of the table, return the SearchParams that might be passed to the other +// // methods to match this row. +// +// bool isBefore(const Row&, SearchParams&&...) const; +// // Returns true if the row on the left comes before the search params on the right. +// +// bool matches(const Row&, SearchParams&&...) const; +// // Returns true if the row "matches" the search params. +// }; + +// ======================================================================================= +// inline implementation details + +namespace _ { // private + +KJ_NORETURN(void throwDuplicateTableRow()); + +template ().size())> +inline void tryReserveSize(Dst& dst, Src&& src) { dst.reserve(dst.size() + src.size()); } +template +inline void tryReserveSize(Params&&...) {} +// If `src` has a `.size()` method, call dst.reserve(dst.size() + src.size()). +// Otherwise, do nothing. + +template +class TableMapping { +public: + TableMapping(Row* table): table(table) {} + Row& map(size_t i) const { return table[i]; } + +private: + Row* table; +}; + +template +class TableUnmapping { +public: + TableUnmapping(Row* table): table(table) {} + size_t map(Row& row) const { return &row - table; } + size_t map(Row* row) const { return row - table; } + +private: + Row* table; +}; + +template +class IterRange { +public: + inline IterRange(Iterator b, Iterator e): b(b), e(e) {} + + inline Iterator begin() const { return b; } + inline Iterator end() const { return e; } +private: + Iterator b; + Iterator e; +}; + +template +inline IterRange> iterRange(Iterator b, Iterator e) { + return { b, e }; +} + +} // namespace _ (private) + +template +template +class Table::Impl { +public: + static void reserve(Table& table, size_t size) { + get(table.indexes).reserve(size); + Impl::reserve(table, size); + } + + static void clear(Table& table) { + get(table.indexes).clear(); + Impl::clear(table); + } + + static kj::Maybe insert(Table& table, size_t pos, Row& row, uint skip) { + if (skip == index) { + return Impl::insert(table, pos, row, skip); + } + auto& indexObj = get(table.indexes); + KJ_IF_MAYBE(existing, indexObj.insert(table.rows.asPtr(), pos, indexObj.keyForRow(row))) { + return *existing; + } + + bool success = false; + KJ_DEFER(if (!success) { + indexObj.erase(table.rows.asPtr(), pos, indexObj.keyForRow(row)); + }); + auto result = Impl::insert(table, pos, row, skip); + success = result == nullptr; + return result; + } + + static void erase(Table& table, size_t pos, Row& row) { + auto& indexObj = get(table.indexes); + indexObj.erase(table.rows.asPtr(), pos, indexObj.keyForRow(row)); + Impl::erase(table, pos, row); + } + + static void move(Table& table, size_t oldPos, size_t newPos, Row& row) { + auto& indexObj = get(table.indexes); + indexObj.move(table.rows.asPtr(), oldPos, newPos, indexObj.keyForRow(row)); + Impl::move(table, oldPos, newPos, row); + } +}; + +template +template +class Table::Impl { +public: + static void reserve(Table& table, size_t size) {} + static void clear(Table& table) {} + static kj::Maybe insert(Table& table, size_t pos, Row& row, uint skip) { + return nullptr; + } + static void erase(Table& table, size_t pos, Row& row) {} + static void move(Table& table, size_t oldPos, size_t newPos, Row& row) {} +}; + +template +Table::Table() {} + +template +Table::Table(Indexes&&... indexes) + : indexes(tuple(kj::fwd(indexes)...)) {} + +template +void Table::reserve(size_t size) { + rows.reserve(size); + Impl<>::reserve(*this, size); +} + +template +size_t Table::size() const { + return rows.size(); +} +template +void Table::clear() { + Impl<>::clear(*this); + rows.clear(); +} +template +size_t Table::capacity() const { + return rows.capacity(); +} + +template +Row* Table::begin() { + return rows.begin(); +} +template +Row* Table::end() { + return rows.end(); +} +template +const Row* Table::begin() const { + return rows.begin(); +} +template +const Row* Table::end() const { + return rows.end(); +} + +template +Row& Table::insert(Row&& row) { + KJ_IF_MAYBE(existing, Impl<>::insert(*this, rows.size(), row, kj::maxValue)) { + _::throwDuplicateTableRow(); + } else { + return rows.add(kj::mv(row)); + } +} +template +Row& Table::insert(const Row& row) { + return insert(kj::cp(row)); +} + +template +template +void Table::insertAll(Collection&& collection) { + _::tryReserveSize(*this, collection); + for (auto& row: collection) { + insert(kj::mv(row)); + } +} + +template +template +void Table::insertAll(Collection& collection) { + _::tryReserveSize(*this, collection); + for (auto& row: collection) { + insert(row); + } +} + +template +template +Row& Table::upsert(Row&& row, UpdateFunc&& update) { + KJ_IF_MAYBE(existing, Impl<>::insert(*this, rows.size(), row, kj::maxValue)) { + update(rows[*existing], kj::mv(row)); + return rows[*existing]; + } else { + return rows.add(kj::mv(row)); + } +} +template +template +Row& Table::upsert(const Row& row, UpdateFunc&& update) { + return upsert(kj::cp(row), kj::fwd(update)); +} + +template +template +kj::Maybe Table::find(Params&&... params) { + return find>()>(kj::fwd(params)...); +} +template +template +kj::Maybe Table::find(Params&&... params) { + KJ_IF_MAYBE(pos, get(indexes).find(rows.asPtr(), kj::fwd(params)...)) { + return rows[*pos]; + } else { + return nullptr; + } +} +template +template +kj::Maybe Table::find(Params&&... params) const { + return find>()>(kj::fwd(params)...); +} +template +template +kj::Maybe Table::find(Params&&... params) const { + KJ_IF_MAYBE(pos, get(indexes).find(rows.asPtr(), kj::fwd(params)...)) { + return rows[*pos]; + } else { + return nullptr; + } +} + +template +template +class Table::FindOrCreateImpl { +public: + template + static Row& apply(Table& table, Params&&... params, Func&& createFunc) { + auto pos = table.rows.size(); + KJ_IF_MAYBE(existing, get(table.indexes).insert(table.rows.asPtr(), pos, params...)) { + return table.rows[*existing]; + } else { + bool success = false; + KJ_DEFER({ + if (!success) { + get(table.indexes).erase(table.rows.asPtr(), pos, params...); + } + }); + auto& newRow = table.rows.add(createFunc()); + KJ_DEFER({ + if (!success) { + table.rows.removeLast(); + } + }); + if (Table::template Impl<>::insert(table, pos, newRow, index) == nullptr) { + success = true; + } else { + _::throwDuplicateTableRow(); + } + return newRow; + } + } +}; + +template +template +struct Table::FindOrCreateHack<_::Tuple, U, V, W...> + : public FindOrCreateHack<_::Tuple, V, W...> {}; +template +template +struct Table::FindOrCreateHack<_::Tuple, U> + : public FindOrCreateImpl {}; +// This awful hack works around C++'s lack of support for parameter packs anywhere other than at +// the end of an argument list. We accumulate all of the types except for the last one into a +// Tuple, then forward to FindOrCreateImpl with the last parameter as the Func. + +template +template +Row& Table::findOrCreate(First&& first, Rest&&... rest) { + return findOrCreate>()>( + kj::fwd(first), kj::fwd(rest)...); +} +template +template +Row& Table::findOrCreate(First&& first, Rest&&... rest) { + return FindOrCreateHack<_::Tuple<>, First, Rest...>::template apply( + *this, kj::fwd(first), kj::fwd(rest)...); +} + +template +template +auto Table::range(BeginKey&& begin, EndKey&& end) { + return range>()>( + kj::fwd(begin), kj::fwd(end)); +} +template +template +auto Table::range(BeginKey&& begin, EndKey&& end) { + auto inner = _::iterRange(get(indexes).seek(rows.asPtr(), kj::fwd(begin)), + get(indexes).seek(rows.asPtr(), kj::fwd(end))); + return _::TableIterable(kj::mv(inner), rows.begin()); +} +template +template +auto Table::range(BeginKey&& begin, EndKey&& end) const { + return range>()>( + kj::fwd(begin), kj::fwd(end)); +} +template +template +auto Table::range(BeginKey&& begin, EndKey&& end) const { + auto inner = _::iterRange(get(indexes).seek(rows.asPtr(), kj::fwd(begin)), + get(indexes).seek(rows.asPtr(), kj::fwd(end))); + return _::TableIterable(kj::mv(inner), rows.begin()); +} + +template +template +_::TableIterable Table::ordered() { + return ordered>()>(); +} +template +template +_::TableIterable>&> Table::ordered() { + return { get(indexes), rows.begin() }; +} +template +template +_::TableIterable Table::ordered() const { + return ordered>()>(); +} +template +template +_::TableIterable>&> +Table::ordered() const { + return { get(indexes), rows.begin() }; +} + +template +template +auto Table::seek(Params&&... params) { + return seek>()>(kj::fwd(params)...); +} +template +template +auto Table::seek(Params&&... params) { + auto inner = get(indexes).seek(rows.asPtr(), kj::fwd(params)...); + return _::TableIterator(kj::mv(inner), rows.begin()); +} +template +template +auto Table::seek(Params&&... params) const { + return seek>()>(kj::fwd(params)...); +} +template +template +auto Table::seek(Params&&... params) const { + auto inner = get(indexes).seek(rows.asPtr(), kj::fwd(params)...); + return _::TableIterator(kj::mv(inner), rows.begin()); +} + +template +template +bool Table::eraseMatch(Params&&... params) { + return eraseMatch>()>(kj::fwd(params)...); +} +template +template +bool Table::eraseMatch(Params&&... params) { + KJ_IF_MAYBE(pos, get(indexes).find(rows.asPtr(), kj::fwd(params)...)) { + eraseImpl(*pos); + return true; + } else { + return false; + } +} + +template +template +size_t Table::eraseRange(BeginKey&& begin, EndKey&& end) { + return eraseRange>()>( + kj::fwd(begin), kj::fwd(end)); +} +template +template +size_t Table::eraseRange(BeginKey&& begin, EndKey&& end) { + auto inner = _::iterRange(get(indexes).seek(rows.asPtr(), kj::fwd(begin)), + get(indexes).seek(rows.asPtr(), kj::fwd(end))); + return eraseAllImpl(inner); +} + +template +template +void Table::verify() { + get(indexes).verify(rows.asPtr()); +} + +template +void Table::erase(Row& row) { + KJ_TABLE_IREQUIRE(&row >= rows.begin() && &row < rows.end(), "row is not a member of this table"); + eraseImpl(&row - rows.begin()); +} +template +void Table::eraseImpl(size_t pos) { + Impl<>::erase(*this, pos, rows[pos]); + size_t back = rows.size() - 1; + if (pos != back) { + Impl<>::move(*this, back, pos, rows[back]); + rows[pos] = kj::mv(rows[back]); + } + rows.removeLast(); +} + +template +Row Table::release(Row& row) { + KJ_TABLE_IREQUIRE(&row >= rows.begin() && &row < rows.end(), "row is not a member of this table"); + size_t pos = &row - rows.begin(); + Impl<>::erase(*this, pos, row); + Row result = kj::mv(row); + size_t back = rows.size() - 1; + if (pos != back) { + Impl<>::move(*this, back, pos, rows[back]); + row = kj::mv(rows[back]); + } + rows.removeLast(); + return result; +} + +template +template +size_t Table::eraseAll(Predicate&& predicate) { + size_t count = 0; + for (size_t i = 0; i < rows.size();) { + if (predicate(rows[i])) { + eraseImpl(i); + ++count; + // eraseImpl() replaces the erased row with the last row, so don't increment i here; repeat + // with the same i. + } else { + ++i; + } + } + return count; +} + +template +template +size_t Table::eraseAll(Collection&& collection) { + return eraseAllImpl(MappedIterable>( + collection, rows.begin())); +} + +template +template +size_t Table::eraseAllImpl(Collection&& collection) { + // We need to transform the collection of row numbers into a sequence of erasures, accounting + // for the fact that each erasure re-positions the last row into its slot. + Vector erased; + _::tryReserveSize(erased, collection); + for (size_t pos: collection) { + while (pos >= rows.size() - erased.size()) { + // Oops, the next item to be erased is already scheduled to be moved to a different location + // due to a previous erasure. Figure out where it will be at this point. + size_t erasureNumber = rows.size() - pos - 1; + pos = erased[erasureNumber]; + } + erased.add(pos); + } + + // Now we can execute the sequence of erasures. + for (size_t pos: erased) { + eraseImpl(pos); + } + + return erased.size(); +} + +// ----------------------------------------------------------------------------- +// Hash table index + +namespace _ { // private + +void logHashTableInconsistency(); + +struct HashBucket { + uint hash; + uint value; + + HashBucket() = default; + HashBucket(uint hash, uint pos) + : hash(hash), value(pos + 2) {} + + inline bool isEmpty() const { return value == 0; } + inline bool isErased() const { return value == 1; } + inline bool isOccupied() const { return value >= 2; } + template + inline Row& getRow(ArrayPtr table) const { return table[getPos()]; } + template + inline const Row& getRow(ArrayPtr table) const { return table[getPos()]; } + inline bool isPos(uint pos) const { return pos + 2 == value; } + inline uint getPos() const { + KJ_TABLE_IASSERT(value >= 2); + return value - 2; + } + inline void setEmpty() { value = 0; } + inline void setErased() { value = 1; } + inline void setPos(uint pos) { value = pos + 2; } +}; + +inline size_t probeHash(const kj::Array& buckets, size_t i) { + // TODO(perf): Is linear probing OK or should we do something fancier? + if (++i == buckets.size()) { + return 0; + } else { + return i; + } +} + +kj::Array rehash(kj::ArrayPtr oldBuckets, size_t targetSize); + +uint chooseBucket(uint hash, uint count); + +} // namespace _ (private) + +template +class HashIndex { +public: + HashIndex() = default; + template + HashIndex(Params&&... params): cb(kj::fwd(params)...) {} + + size_t capacity() { + // This method is for testing. + return buckets.size(); + } + + void reserve(size_t size) { + if (buckets.size() < size * 2) { + rehash(size); + } + } + + void clear() { + erasedCount = 0; + if (buckets.size() > 0) memset(buckets.begin(), 0, buckets.asBytes().size()); + } + + template + decltype(auto) keyForRow(Row&& row) const { + return cb.keyForRow(kj::fwd(row)); + } + + template + kj::Maybe insert(kj::ArrayPtr table, size_t pos, Params&&... params) { + if (buckets.size() * 2 < (table.size() + 1 + erasedCount) * 3) { + // Load factor is more than 2/3, let's rehash so that it's 1/3, i.e. double the buckets. + // Note that rehashing also cleans up erased entries, so we may not actually be doubling if + // there are a lot of erasures. Nevertheless, this gives us amortized constant time -- it + // would take at least O(table.size()) more insertions (whether or not erasures occur) + // before another rehash is needed. + rehash((table.size() + 1) * 3); + } + + uint hashCode = cb.hashCode(params...); + Maybe<_::HashBucket&> erasedSlot; + for (uint i = _::chooseBucket(hashCode, buckets.size());; i = _::probeHash(buckets, i)) { + auto& bucket = buckets[i]; + if (bucket.isEmpty()) { + // no duplicates found + KJ_IF_MAYBE(s, erasedSlot) { + --erasedCount; + *s = { hashCode, uint(pos) }; + } else { + bucket = { hashCode, uint(pos) }; + } + return nullptr; + } else if (bucket.isErased()) { + // We can fill in the erased slot. However, we have to keep searching to make sure there + // are no duplicates before we do that. + if (erasedSlot == nullptr) { + erasedSlot = bucket; + } + } else if (bucket.hash == hashCode && + cb.matches(bucket.getRow(table), params...)) { + // duplicate row + return size_t(bucket.getPos()); + } + } + } + + template + void erase(kj::ArrayPtr table, size_t pos, Params&&... params) { + uint hashCode = cb.hashCode(params...); + for (uint i = _::chooseBucket(hashCode, buckets.size());; i = _::probeHash(buckets, i)) { + auto& bucket = buckets[i]; + if (bucket.isPos(pos)) { + // found it + ++erasedCount; + bucket.setErased(); + return; + } else if (bucket.isEmpty()) { + // can't find the bucket, something is very wrong + _::logHashTableInconsistency(); + return; + } + } + } + + template + void move(kj::ArrayPtr table, size_t oldPos, size_t newPos, Params&&... params) { + uint hashCode = cb.hashCode(params...); + for (uint i = _::chooseBucket(hashCode, buckets.size());; i = _::probeHash(buckets, i)) { + auto& bucket = buckets[i]; + if (bucket.isPos(oldPos)) { + // found it + bucket.setPos(newPos); + return; + } else if (bucket.isEmpty()) { + // can't find the bucket, something is very wrong + _::logHashTableInconsistency(); + return; + } + } + } + + template + Maybe find(kj::ArrayPtr table, Params&&... params) const { + if (buckets.size() == 0) return nullptr; + + uint hashCode = cb.hashCode(params...); + for (uint i = _::chooseBucket(hashCode, buckets.size());; i = _::probeHash(buckets, i)) { + auto& bucket = buckets[i]; + if (bucket.isEmpty()) { + // not found. + return nullptr; + } else if (bucket.isErased()) { + // skip, keep searching + } else if (bucket.hash == hashCode && + cb.matches(bucket.getRow(table), params...)) { + // found + return size_t(bucket.getPos()); + } + } + } + + // No begin() nor end() because hash tables are not usefully ordered. + +private: + Callbacks cb; + size_t erasedCount = 0; + Array<_::HashBucket> buckets; + + void rehash(size_t targetSize) { + buckets = _::rehash(buckets, targetSize); + erasedCount = 0; + } +}; + +// ----------------------------------------------------------------------------- +// BTree index + +namespace _ { // private + +KJ_ALWAYS_INLINE(void compilerBarrier()); +void compilerBarrier() { + // Make sure that reads occurring before this call cannot be re-ordered to happen after + // writes that occur after this call. We need this in a couple places below to prevent C++ + // strict aliasing rules from breaking things. +#if _MSC_VER + _ReadWriteBarrier(); +#else + __asm__ __volatile__("": : :"memory"); +#endif +} + +template +inline void acopy(T* to, T* from, size_t size) { memcpy(to, from, size * sizeof(T)); } +template +inline void amove(T* to, T* from, size_t size) { memmove(to, from, size * sizeof(T)); } +template +inline void azero(T* ptr, size_t size) { memset(ptr, 0, size * sizeof(T)); } +// memcpy/memmove/memset variants that count size in elements, not bytes. +// +// TODO(cleanup): These are generally useful, put them somewhere. + +class BTreeImpl { +public: + class Iterator; + class MaybeUint; + struct NodeUnion; + struct Leaf; + struct Parent; + struct Freelisted; + + class SearchKey { + // Passed to methods that need to search the tree. This class allows most of the B-tree + // implementation to be kept out of templates, avoiding code bloat, at the cost of some + // performance trade-off. In order to lessen the performance cost of virtual calls, we design + // this interface so that it only needs to be called once per tree node, rather than once per + // comparison. + + public: + virtual uint search(const Parent& parent) const = 0; + virtual uint search(const Leaf& leaf) const = 0; + // Binary search for the first key/row in the parent/leaf that is equal to or comes after the + // search key. + + virtual bool isAfter(uint rowIndex) const = 0; + // Returns true if the key comes after the value in the given row. + }; + + BTreeImpl(); + ~BTreeImpl() noexcept(false); + + KJ_DISALLOW_COPY(BTreeImpl); + BTreeImpl(BTreeImpl&& other); + BTreeImpl& operator=(BTreeImpl&& other); + + void logInconsistency() const; + + void reserve(size_t size); + + void clear(); + + Iterator begin() const; + Iterator end() const; + + Iterator search(const SearchKey& searchKey) const; + // Find the "first" row (in sorted order) for which searchKey.isAfter(rowNumber) returns true. + + Iterator insert(const SearchKey& searchKey); + // Like search() but ensures that there is room in the leaf node to insert a new row. + + void erase(uint row, const SearchKey& searchKey); + // Erase the given row number from the tree. searchKey.isAfter() returns true for the given row + // and all rows after it. + + void renumber(uint oldRow, uint newRow, const SearchKey& searchKey); + // Renumber the given row from oldRow to newRow. searchKey.isAfter() returns true for oldRow and + // all rows after it. (It will not be called on newRow.) + + void verify(size_t size, FunctionParam); + +private: + NodeUnion* tree; // allocated with aligned_alloc aligned to cache lines + uint treeCapacity; + uint height; // height of *parent* tree -- does not include the leaf level + uint freelistHead; + uint freelistSize; + uint beginLeaf; + uint endLeaf; + void growTree(uint minCapacity = 0); + + template + struct AllocResult; + + template + inline AllocResult alloc(); + inline void free(uint pos); + + inline uint split(Parent& src, uint srcPos, Parent& dst, uint dstPos); + inline uint split(Leaf& dst, uint dstPos, Leaf& src, uint srcPos); + inline void merge(Parent& dst, uint dstPos, uint pivot, Parent& src); + inline void merge(Leaf& dst, uint dstPos, uint pivot, Leaf& src); + inline void move(Parent& dst, uint dstPos, Parent& src); + inline void move(Leaf& dst, uint dstPos, Leaf& src); + inline void rotateLeft( + Parent& left, Parent& right, Parent& parent, uint indexInParent, MaybeUint*& fixup); + inline void rotateLeft( + Leaf& left, Leaf& right, Parent& parent, uint indexInParent, MaybeUint*& fixup); + inline void rotateRight(Parent& left, Parent& right, Parent& parent, uint indexInParent); + inline void rotateRight(Leaf& left, Leaf& right, Parent& parent, uint indexInParent); + + template + inline Node& insertHelper(const SearchKey& searchKey, + Node& node, Parent* parent, uint indexInParent, uint pos); + + template + inline Node& eraseHelper( + Node& node, Parent* parent, uint indexInParent, uint pos, MaybeUint*& fixup); + + size_t verifyNode(size_t size, FunctionParam&, + uint pos, uint height, MaybeUint maxRow); + + static const NodeUnion EMPTY_NODE; +}; + +class BTreeImpl::MaybeUint { + // A nullable uint, using the value zero to mean null and shifting all other values up by 1. +public: + MaybeUint() = default; + inline MaybeUint(uint i): i(i - 1) {} + inline MaybeUint(decltype(nullptr)): i(0) {} + + inline bool operator==(decltype(nullptr)) const { return i == 0; } + inline bool operator==(uint j) const { return i == j + 1; } + inline bool operator==(const MaybeUint& other) const { return i == other.i; } + inline bool operator!=(decltype(nullptr)) const { return i != 0; } + inline bool operator!=(uint j) const { return i != j + 1; } + inline bool operator!=(const MaybeUint& other) const { return i != other.i; } + + inline MaybeUint& operator=(decltype(nullptr)) { i = 0; return *this; } + inline MaybeUint& operator=(uint j) { i = j + 1; return *this; } + + inline uint operator*() const { KJ_TABLE_IREQUIRE(i != 0); return i - 1; } + + template + inline bool check(Func& func) const { return i != 0 && func(i - 1); } + // Equivalent to *this != nullptr && func(**this) + + kj::String toString() const; + +private: + uint i; +}; + +struct BTreeImpl::Leaf { + uint next; + uint prev; + // Pointers to next and previous nodes at the same level, used for fast iteration. + + static constexpr size_t NROWS = 14; + MaybeUint rows[NROWS]; + // Pointers to table rows, offset by 1 so that 0 is an empty value. + + inline bool isFull() const; + inline bool isMostlyFull() const; + inline bool isHalfFull() const; + + inline void insert(uint i, uint newRow) { + KJ_TABLE_IREQUIRE(rows[Leaf::NROWS - 1] == nullptr); // check not full + + amove(rows + i + 1, rows + i, Leaf::NROWS - (i + 1)); + rows[i] = newRow; + } + + inline void erase(uint i) { + KJ_TABLE_IREQUIRE(rows[0] != nullptr); // check not empty + + amove(rows + i, rows + i + 1, Leaf::NROWS - (i + 1)); + rows[Leaf::NROWS - 1] = nullptr; + } + + inline uint size() const { + static_assert(Leaf::NROWS == 14, "logic here needs updating"); + + // Binary search for first empty element in `rows`, or return 14 if no empty elements. We do + // this in a branch-free manner. Since there are 15 possible results (0 through 14, inclusive), + // this isn't a perfectly balanced binary search. We carefully choose the split points so that + // there's no way we'll try to dereference row[14] or later (which would be a buffer overflow). + uint i = (rows[6] != nullptr) * 7; + i += (rows[i + 3] != nullptr) * 4; + i += (rows[i + 1] != nullptr) * 2; + i += (rows[i ] != nullptr); + return i; + } + + template + inline uint binarySearch(Func& predicate) const { + // Binary search to find first row for which predicate(row) is false. + + static_assert(Leaf::NROWS == 14, "logic here needs updating"); + + // See comments in size(). + uint i = (rows[6].check(predicate)) * 7; + i += (rows[i + 3].check(predicate)) * 4; + i += (rows[i + 1].check(predicate)) * 2; + if (i != 6) { // don't redundantly check row 6 + i += (rows[i ].check(predicate)); + } + return i; + } +}; + +struct BTreeImpl::Parent { + uint unused; + // Not used. May be arbitrarily non-zero due to overlap with Freelisted::nextOffset. + + static constexpr size_t NKEYS = 7; + MaybeUint keys[NKEYS]; + // Pointers to table rows, offset by 1 so that 0 is an empty value. + // + // Each keys[i] specifies the table row which is the "last" row found under children[i]. + // + // Note that `keys` has size 7 but `children` has size 8. `children[8]`'s "last row" is not + // recorded here, because the Parent's Parent records it instead. (Or maybe the Parent's Parent's + // Parent, if this Parent is `children[8]` of its own Parent. And so on.) + + static constexpr size_t NCHILDREN = NKEYS + 1; + uint children[NCHILDREN]; + // Pointers to children. Not offset because the root is always at position 0, and a pointer + // to the root would be nonsensical. + + inline bool isFull() const; + inline bool isMostlyFull() const; + inline bool isHalfFull() const; + inline void initRoot(uint key, uint leftChild, uint rightChild); + inline void insertAfter(uint i, uint splitKey, uint child); + inline void eraseAfter(uint i); + + inline uint keyCount() const { + static_assert(Parent::NKEYS == 7, "logic here needs updating"); + + // Binary search for first empty element in `keys`, or return 7 if no empty elements. We do + // this in a branch-free manner. Since there are 8 possible results (0 through 7, inclusive), + // this is a perfectly balanced binary search. + uint i = (keys[3] != nullptr) * 4; + i += (keys[i + 1] != nullptr) * 2; + i += (keys[i ] != nullptr); + return i; + } + + template + inline uint binarySearch(Func& predicate) const { + // Binary search to find first key for which predicate(key) is false. + + static_assert(Parent::NKEYS == 7, "logic here needs updating"); + + // See comments in size(). + uint i = (keys[3].check(predicate)) * 4; + i += (keys[i + 1].check(predicate)) * 2; + i += (keys[i ].check(predicate)); + return i; + } +}; + +struct BTreeImpl::Freelisted { + int nextOffset; + // The next node in the freelist is at: this + 1 + nextOffset + // + // Hence, newly-allocated space can initialize this to zero. + + uint zero[15]; + // Freelisted entries are always zero'd. +}; + +struct BTreeImpl::NodeUnion { + union { + Freelisted freelist; + // If this node is in the freelist. + + Leaf leaf; + // If this node is a leaf. + + Parent parent; + // If this node is not a leaf. + }; + + inline operator Leaf&() { return leaf; } + inline operator Parent&() { return parent; } + inline operator const Leaf&() const { return leaf; } + inline operator const Parent&() const { return parent; } +}; + +static_assert(sizeof(BTreeImpl::Parent) == 64, + "BTreeImpl::Parent should be optimized to fit a cache line"); +static_assert(sizeof(BTreeImpl::Leaf) == 64, + "BTreeImpl::Leaf should be optimized to fit a cache line"); +static_assert(sizeof(BTreeImpl::Freelisted) == 64, + "BTreeImpl::Freelisted should be optimized to fit a cache line"); +static_assert(sizeof(BTreeImpl::NodeUnion) == 64, + "BTreeImpl::NodeUnion should be optimized to fit a cache line"); + +bool BTreeImpl::Leaf::isFull() const { + return rows[Leaf::NROWS - 1] != nullptr; +} +bool BTreeImpl::Leaf::isMostlyFull() const { + return rows[Leaf::NROWS / 2] != nullptr; +} +bool BTreeImpl::Leaf::isHalfFull() const { + KJ_TABLE_IASSERT(rows[Leaf::NROWS / 2 - 1] != nullptr); + return rows[Leaf::NROWS / 2] == nullptr; +} + +bool BTreeImpl::Parent::isFull() const { + return keys[Parent::NKEYS - 1] != nullptr; +} +bool BTreeImpl::Parent::isMostlyFull() const { + return keys[Parent::NKEYS / 2] != nullptr; +} +bool BTreeImpl::Parent::isHalfFull() const { + KJ_TABLE_IASSERT(keys[Parent::NKEYS / 2 - 1] != nullptr); + return keys[Parent::NKEYS / 2] == nullptr; +} + +class BTreeImpl::Iterator { +public: + Iterator(const NodeUnion* tree, const Leaf* leaf, uint row) + : tree(tree), leaf(leaf), row(row) {} + + size_t operator*() const { + KJ_TABLE_IREQUIRE(row < Leaf::NROWS && leaf->rows[row] != nullptr, + "tried to dereference end() iterator"); + return *leaf->rows[row]; + } + + inline Iterator& operator++() { + KJ_TABLE_IREQUIRE(leaf->rows[row] != nullptr, "B-tree iterator overflow"); + ++row; + if (row >= Leaf::NROWS || leaf->rows[row] == nullptr) { + if (leaf->next == 0) { + // at end; stay on current leaf + } else { + leaf = &tree[leaf->next].leaf; + row = 0; + } + } + return *this; + } + inline Iterator operator++(int) { + Iterator other = *this; + ++*this; + return other; + } + + inline Iterator& operator--() { + if (row == 0) { + KJ_TABLE_IREQUIRE(leaf->prev != 0, "B-tree iterator underflow"); + leaf = &tree[leaf->prev].leaf; + row = leaf->size() - 1; + } else { + --row; + } + return *this; + } + inline Iterator operator--(int) { + Iterator other = *this; + --*this; + return other; + } + + inline bool operator==(const Iterator& other) const { + return leaf == other.leaf && row == other.row; + } + inline bool operator!=(const Iterator& other) const { + return leaf != other.leaf || row != other.row; + } + + bool isEnd() { + return row == Leaf::NROWS || leaf->rows[row] == nullptr; + } + + void insert(BTreeImpl& impl, uint newRow) { + KJ_TABLE_IASSERT(impl.tree == tree); + const_cast(leaf)->insert(row, newRow); + } + + void erase(BTreeImpl& impl) { + KJ_TABLE_IASSERT(impl.tree == tree); + const_cast(leaf)->erase(row); + } + + void replace(BTreeImpl& impl, uint newRow) { + KJ_TABLE_IASSERT(impl.tree == tree); + const_cast(leaf)->rows[row] = newRow; + } + +private: + const NodeUnion* tree; + const Leaf* leaf; + uint row; +}; + +inline BTreeImpl::Iterator BTreeImpl::begin() const { + return { tree, &tree[beginLeaf].leaf, 0 }; +} +inline BTreeImpl::Iterator BTreeImpl::end() const { + auto& leaf = tree[endLeaf].leaf; + return { tree, &leaf, leaf.size() }; +} + +} // namespace _ (private) + +template +class TreeIndex { +public: + TreeIndex() = default; + template + TreeIndex(Params&&... params): cb(kj::fwd(params)...) {} + + template + void verify(kj::ArrayPtr table) { + impl.verify(table.size(), [&](uint i, uint j) { + return cb.isBefore(table[i], table[j]); + }); + } + + inline void reserve(size_t size) { impl.reserve(size); } + inline void clear() { impl.clear(); } + inline auto begin() const { return impl.begin(); } + inline auto end() const { return impl.end(); } + + template + decltype(auto) keyForRow(Row&& row) const { + return cb.keyForRow(kj::fwd(row)); + } + + template + kj::Maybe insert(kj::ArrayPtr table, size_t pos, Params&&... params) { + auto iter = impl.insert(searchKey(table, params...)); + + if (!iter.isEnd() && cb.matches(table[*iter], params...)) { + return *iter; + } else { + iter.insert(impl, pos); + return nullptr; + } + } + + template + void erase(kj::ArrayPtr table, size_t pos, Params&&... params) { + impl.erase(pos, searchKeyForErase(table, pos, params...)); + } + + template + void move(kj::ArrayPtr table, size_t oldPos, size_t newPos, Params&&... params) { + impl.renumber(oldPos, newPos, searchKey(table, params...)); + } + + template + Maybe find(kj::ArrayPtr table, Params&&... params) const { + auto iter = impl.search(searchKey(table, params...)); + + if (!iter.isEnd() && cb.matches(table[*iter], params...)) { + return size_t(*iter); + } else { + return nullptr; + } + } + + template + _::BTreeImpl::Iterator seek(kj::ArrayPtr table, Params&&... params) const { + return impl.search(searchKey(table, params...)); + } + +private: + Callbacks cb; + _::BTreeImpl impl; + + template + class SearchKeyImpl: public _::BTreeImpl::SearchKey { + public: + SearchKeyImpl(Predicate&& predicate) + : predicate(kj::mv(predicate)) {} + + uint search(const _::BTreeImpl::Parent& parent) const override { + return parent.binarySearch(predicate); + } + uint search(const _::BTreeImpl::Leaf& leaf) const override { + return leaf.binarySearch(predicate); + } + bool isAfter(uint rowIndex) const override { + return predicate(rowIndex); + } + + private: + Predicate predicate; + }; + + template + inline auto searchKey(kj::ArrayPtr& table, Params&... params) const { + auto predicate = [&](uint i) { return cb.isBefore(table[i], params...); }; + return SearchKeyImpl(kj::mv(predicate)); + } + + template + inline auto searchKeyForErase(kj::ArrayPtr& table, uint pos, Params&... params) const { + // When erasing, the table entry for the erased row may already be invalid, so we must avoid + // accessing it. + auto predicate = [&,pos](uint i) { + return i != pos && cb.isBefore(table[i], params...); + }; + return SearchKeyImpl(kj::mv(predicate)); + } +}; + +// ----------------------------------------------------------------------------- +// Insertion order index + +class InsertionOrderIndex { + // Table index which allows iterating over elements in order of insertion. This index cannot + // be used for Table::find(), but can be used for Table::ordered(). + + struct Link; +public: + InsertionOrderIndex(); + InsertionOrderIndex(const InsertionOrderIndex&) = delete; + InsertionOrderIndex& operator=(const InsertionOrderIndex&) = delete; + InsertionOrderIndex(InsertionOrderIndex&& other); + InsertionOrderIndex& operator=(InsertionOrderIndex&& other); + ~InsertionOrderIndex() noexcept(false); + + class Iterator { + public: + Iterator(const Link* links, uint pos) + : links(links), pos(pos) {} + + inline size_t operator*() const { + KJ_TABLE_IREQUIRE(pos != 0, "can't dereference end() iterator"); + return pos - 1; + }; + + inline Iterator& operator++() { + pos = links[pos].next; + return *this; + } + inline Iterator operator++(int) { + Iterator result = *this; + ++*this; + return result; + } + inline Iterator& operator--() { + pos = links[pos].prev; + return *this; + } + inline Iterator operator--(int) { + Iterator result = *this; + --*this; + return result; + } + + inline bool operator==(const Iterator& other) const { + return pos == other.pos; + } + inline bool operator!=(const Iterator& other) const { + return pos != other.pos; + } + + private: + const Link* links; + uint pos; + }; + + template + Row& keyForRow(Row& row) const { return row; } + + void reserve(size_t size); + void clear(); + inline Iterator begin() const { return Iterator(links, links[0].next); } + inline Iterator end() const { return Iterator(links, 0); } + + template + kj::Maybe insert(kj::ArrayPtr table, size_t pos, const Row& row) { + return insertImpl(pos); + } + + template + void erase(kj::ArrayPtr table, size_t pos, const Row& row) { + eraseImpl(pos); + } + + template + void move(kj::ArrayPtr table, size_t oldPos, size_t newPos, const Row& row) { + return moveImpl(oldPos, newPos); + } + +private: + struct Link { + uint next; + uint prev; + }; + + uint capacity; + Link* links; + // links[0] is special: links[0].next points to the first link, links[0].prev points to the last. + // links[n+1] corresponds to row n. + + kj::Maybe insertImpl(size_t pos); + void eraseImpl(size_t pos); + void moveImpl(size_t oldPos, size_t newPos); + + static const Link EMPTY_LINK; +}; + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/test-helpers.c++ b/c++/src/kj/test-helpers.c++ index 1cc5650017..6ae8cd3259 100644 --- a/c++/src/kj/test-helpers.c++ +++ b/c++/src/kj/test-helpers.c++ @@ -19,7 +19,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "test.h" + +#include #ifndef _WIN32 #include #include @@ -32,13 +38,26 @@ namespace kj { namespace _ { // private bool hasSubstring(StringPtr haystack, StringPtr needle) { - // TODO(perf): This is not the best algorithm for substring matching. if (needle.size() <= haystack.size()) { + // Boyer Moore Horspool wins https://quick-bench.com/q/RiKdKduhdLb6x_DfS1fHaksqwdQ + // https://quick-bench.com/q/KV8irwXrkvsNMbNpP8ENR_tBEPY but libc++ only has default_searcher + // which performs *drastically worse* than the naiive algorithm (seriously - why even bother?). + // Hell, doing a query for an embedded null & dispatching to strstr is still cheaper & only + // marginally slower than the purely naiive implementation. + +#if !defined(_WIN32) + return memmem(haystack.begin(), haystack.size(), needle.begin(), needle.size()) != nullptr; +#else + // TODO(perf): This is not the best algorithm for substring matching. strstr can't be used + // because this is supposed to be safe to call on strings with embedded nulls. + // Amusingly this naiive algorithm some times outperforms std::default_searcher, even if we need + // to double-check first if the needle has an embedded null (indicating std::search ). for (size_t i = 0; i <= haystack.size() - needle.size(); i++) { if (haystack.slice(i).startsWith(needle)) { return true; } } +#endif } return false; } @@ -128,7 +147,75 @@ bool expectFatalThrow(kj::Maybe type, kj::Maybe mess KJ_FAIL_EXPECT("subprocess crashed without throwing exception", WTERMSIG(status)); return false; } else { - KJ_FAIL_EXPECT("subprocess neiter excited nor crashed?", status); + KJ_FAIL_EXPECT("subprocess neither excited nor crashed?", status); + return false; + } +#endif +} + +bool expectExit(Maybe statusCode, FunctionParam code) noexcept { +#if _WIN32 + // We don't support death tests on Windows due to lack of efficient fork. + return true; +#else + pid_t child; + KJ_SYSCALL(child = fork()); + if (child == 0) { + code(); + _exit(0); + } + + int status; + KJ_SYSCALL(waitpid(child, &status, 0)); + + if (WIFEXITED(status)) { + KJ_IF_MAYBE(s, statusCode) { + KJ_EXPECT(WEXITSTATUS(status) == *s); + return WEXITSTATUS(status) == *s; + } else { + KJ_EXPECT(WEXITSTATUS(status) != 0); + return WEXITSTATUS(status) != 0; + } + } else { + if (WIFSIGNALED(status)) { + KJ_FAIL_EXPECT("subprocess didn't exit but triggered a signal", strsignal(WTERMSIG(status))); + } else { + KJ_FAIL_EXPECT("subprocess didn't exit and didn't trigger a signal", status); + } + return false; + } +#endif +} + + +bool expectSignal(Maybe signal, FunctionParam code) noexcept { +#if _WIN32 + // We don't support death tests on Windows due to lack of efficient fork. + return true; +#else + pid_t child; + KJ_SYSCALL(child = fork()); + if (child == 0) { + resetCrashHandlers(); + code(); + _exit(0); + } + + int status; + KJ_SYSCALL(waitpid(child, &status, 0)); + + if (WIFSIGNALED(status)) { + KJ_IF_MAYBE(s, signal) { + KJ_EXPECT(WTERMSIG(status) == *s); + return WTERMSIG(status) == *s; + } + return true; + } else { + if (WIFEXITED(status)) { + KJ_FAIL_EXPECT("subprocess didn't trigger a signal but exited", WEXITSTATUS(status)); + } else { + KJ_FAIL_EXPECT("subprocess didn't exit and didn't trigger a signal", status); + } return false; } #endif diff --git a/c++/src/kj/test-test.c++ b/c++/src/kj/test-test.c++ index 7d02027094..b69eaf6ab7 100644 --- a/c++/src/kj/test-test.c++ +++ b/c++/src/kj/test-test.c++ @@ -21,6 +21,13 @@ #include "common.h" #include "test.h" +#include +#include +#include + +#ifndef _WIN32 +#include +#endif namespace kj { namespace _ { @@ -78,6 +85,26 @@ KJ_TEST("GlobFilter") { } } +KJ_TEST("expect exit from exit") { + KJ_EXPECT_EXIT(42, _exit(42)); + KJ_EXPECT_EXIT(nullptr, _exit(42)); +} + +#if !KJ_NO_EXCEPTIONS +KJ_TEST("expect exit from thrown exception") { + KJ_EXPECT_EXIT(1, throw std::logic_error("test error")); +} +#endif + +KJ_TEST("expect signal from abort") { + KJ_EXPECT_SIGNAL(SIGABRT, abort()); +} + +KJ_TEST("expect signal from sigint") { + KJ_EXPECT_SIGNAL(SIGINT, raise(SIGINT)); + KJ_EXPECT_SIGNAL(nullptr, raise(SIGINT)); +} + } // namespace } // namespace _ } // namespace kj diff --git a/c++/src/kj/test.c++ b/c++/src/kj/test.c++ index 33f1bdb793..e310f20f05 100644 --- a/c++/src/kj/test.c++ +++ b/c++/src/kj/test.c++ @@ -19,6 +19,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + #include "test.h" #include "main.h" #include "io.h" @@ -26,6 +30,7 @@ #include #include #include +#include "time.h" #ifndef _WIN32 #include #endif @@ -37,6 +42,8 @@ namespace { TestCase* testCasesHead = nullptr; TestCase** testCasesTail = &testCasesHead; +size_t benchmarkIterCount = 1; + } // namespace TestCase::TestCase(const char* file, uint line, const char* description) @@ -55,6 +62,10 @@ TestCase::~TestCase() { } } +size_t TestCase::iterCount() { + return benchmarkIterCount; +} + // ======================================================================================= namespace _ { // private @@ -167,7 +178,7 @@ public: if (severity == LogSeverity::ERROR || severity == LogSeverity::FATAL) { sawError = true; - context.error(kj::str(text, "\nstack: ", strArray(trace, " "), stringifyStackTrace(trace))); + context.error(kj::str(text, "\nstack: ", stringifyStackTraceAddresses(trace), stringifyStackTrace(trace))); } else { context.warning(text); } @@ -178,6 +189,10 @@ private: bool sawError = false; }; +TimePoint readClock() { + return systemPreciseMonotonicClock().now(); +} + } // namespace class TestRunner { @@ -196,6 +211,9 @@ public: .addOption({'l', "list"}, KJ_BIND_METHOD(*this, setList), "List all test cases that would run, but don't run them. If --filter is specified " "then only the match tests will be listed.") + .addOptionWithArg({'b', "benchmark"}, KJ_BIND_METHOD(*this, setBenchmarkIters), "", + "Specifies that any benchmarks in the tests should run for iterations. " + "If not specified, then count is 1, which simply tests that the benchmarks function.") .callAfterParsing(KJ_BIND_METHOD(*this, run)) .build(); } @@ -254,6 +272,15 @@ public: return true; } + MainBuilder::Validity setBenchmarkIters(StringPtr param) { + KJ_IF_MAYBE(i, param.tryParseAs()) { + benchmarkIterCount = *i; + return true; + } else { + return "expected an integer"; + } + } + MainBuilder::Validity run() { if (testCasesHead == nullptr) { return "no tests were declared"; @@ -287,6 +314,7 @@ public: if (!listOnly) { bool currentFailed = true; + auto start = readClock(); KJ_IF_MAYBE(exception, runCatchingExceptions([&]() { TestExceptionCallback exceptionCallback(context); testCase->run(); @@ -294,12 +322,15 @@ public: })) { context.error(kj::str(*exception)); } + auto end = readClock(); + + auto message = kj::str(name, " (", (end - start) / kj::MICROSECONDS, " μs)"); if (currentFailed) { - write(RED, "[ FAIL ]", name); + write(RED, "[ FAIL ]", message); ++failCount; } else { - write(GREEN, "[ PASS ]", name); + write(GREEN, "[ PASS ]", message); ++passCount; } } diff --git a/c++/src/kj/test.h b/c++/src/kj/test.h index 69e1c80840..5acbb00d40 100644 --- a/c++/src/kj/test.h +++ b/c++/src/kj/test.h @@ -19,16 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_TEST_H_ -#define KJ_TEST_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "debug.h" #include "vector.h" #include "function.h" +#include "windows-sanity.h" // work-around macro conflict with `ERROR` + +KJ_BEGIN_HEADER namespace kj { @@ -41,6 +39,21 @@ class TestCase { virtual void run() = 0; +protected: + template + void doBenchmark(Func&& func) { + // Perform a benchmark with configurable iterations. func() will be called N times, where N + // is set by the --benchmark CLI flag. This defaults to 1, so that when --benchmark is not + // specified, we only test that the benchmark works. + // + // In the future, this could adaptively choose iteration count by running a few iterations to + // find out how fast the benchmark is, then scaling. + + for (size_t i = iterCount(); i-- > 0;) { + func(); + } + } + private: const char* file; uint line; @@ -49,6 +62,8 @@ class TestCase { TestCase** prev; bool matchedFilter; + static size_t iterCount(); + friend class TestRunner; }; @@ -62,45 +77,69 @@ class TestCase { } KJ_UNIQUE_NAME(testCase); \ void KJ_UNIQUE_NAME(TestCase)::run() -#if _MSC_VER +#if KJ_MSVC_TRADITIONAL_CPP #define KJ_INDIRECT_EXPAND(m, vargs) m vargs #define KJ_FAIL_EXPECT(...) \ KJ_INDIRECT_EXPAND(KJ_LOG, (ERROR , __VA_ARGS__)); #define KJ_EXPECT(cond, ...) \ - if (cond); else KJ_INDIRECT_EXPAND(KJ_FAIL_EXPECT, ("failed: expected " #cond , __VA_ARGS__)) + if (auto _kjCondition = ::kj::_::MAGIC_ASSERT << cond); \ + else KJ_INDIRECT_EXPAND(KJ_FAIL_EXPECT, ("failed: expected " #cond , _kjCondition, __VA_ARGS__)) #else #define KJ_FAIL_EXPECT(...) \ KJ_LOG(ERROR, ##__VA_ARGS__); #define KJ_EXPECT(cond, ...) \ - if (cond); else KJ_FAIL_EXPECT("failed: expected " #cond, ##__VA_ARGS__) + if (auto _kjCondition = ::kj::_::MAGIC_ASSERT << cond); \ + else KJ_FAIL_EXPECT("failed: expected " #cond, _kjCondition, ##__VA_ARGS__) #endif -#define KJ_EXPECT_THROW_RECOVERABLE(type, code) \ +#if _MSC_VER && !defined(__clang__) +#define KJ_EXPECT_THROW_RECOVERABLE(type, code, ...) \ + do { \ + KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ + KJ_INDIRECT_EXPAND(KJ_EXPECT, (e->getType() == ::kj::Exception::Type::type, \ + "code threw wrong exception type: " #code, *e, __VA_ARGS__)); \ + } else { \ + KJ_INDIRECT_EXPAND(KJ_FAIL_EXPECT, ("code did not throw: " #code, __VA_ARGS__)); \ + } \ + } while (false) + +#define KJ_EXPECT_THROW_RECOVERABLE_MESSAGE(message, code, ...) \ + do { \ + KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ + KJ_INDIRECT_EXPAND(KJ_EXPECT, (::kj::_::hasSubstring(e->getDescription(), message), \ + "exception description didn't contain expected substring", *e, __VA_ARGS__)); \ + } else { \ + KJ_INDIRECT_EXPAND(KJ_FAIL_EXPECT, ("code did not throw: " #code, __VA_ARGS__)); \ + } \ + } while (false) +#else +#define KJ_EXPECT_THROW_RECOVERABLE(type, code, ...) \ do { \ KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ KJ_EXPECT(e->getType() == ::kj::Exception::Type::type, \ - "code threw wrong exception type: " #code, e->getType()); \ + "code threw wrong exception type: " #code, *e, ##__VA_ARGS__); \ } else { \ - KJ_FAIL_EXPECT("code did not throw: " #code); \ + KJ_FAIL_EXPECT("code did not throw: " #code, ##__VA_ARGS__); \ } \ } while (false) -#define KJ_EXPECT_THROW_RECOVERABLE_MESSAGE(message, code) \ +#define KJ_EXPECT_THROW_RECOVERABLE_MESSAGE(message, code, ...) \ do { \ KJ_IF_MAYBE(e, ::kj::runCatchingExceptions([&]() { code; })) { \ KJ_EXPECT(::kj::_::hasSubstring(e->getDescription(), message), \ - "exception description didn't contain expected substring", e->getDescription()); \ + "exception description didn't contain expected substring", *e, ##__VA_ARGS__); \ } else { \ - KJ_FAIL_EXPECT("code did not throw: " #code); \ + KJ_FAIL_EXPECT("code did not throw: " #code, ##__VA_ARGS__); \ } \ } while (false) +#endif #if KJ_NO_EXCEPTIONS -#define KJ_EXPECT_THROW(type, code) \ +#define KJ_EXPECT_THROW(type, code, ...) \ do { \ - KJ_EXPECT(::kj::_::expectFatalThrow(type, nullptr, [&]() { code; })); \ + KJ_EXPECT(::kj::_::expectFatalThrow(::kj::Exception::Type::type, nullptr, [&]() { code; })); \ } while (false) -#define KJ_EXPECT_THROW_MESSAGE(message, code) \ +#define KJ_EXPECT_THROW_MESSAGE(message, code, ...) \ do { \ KJ_EXPECT(::kj::_::expectFatalThrow(nullptr, kj::StringPtr(message), [&]() { code; })); \ } while (false) @@ -109,6 +148,19 @@ class TestCase { #define KJ_EXPECT_THROW_MESSAGE KJ_EXPECT_THROW_RECOVERABLE_MESSAGE #endif +#define KJ_EXPECT_EXIT(statusCode, code) \ + do { \ + KJ_EXPECT(::kj::_::expectExit(statusCode, [&]() { code; })); \ + } while (false) +// Forks the code and expects it to exit with a given code. + +#define KJ_EXPECT_SIGNAL(signal, code) \ + do { \ + KJ_EXPECT(::kj::_::expectSignal(signal, [&]() { code; })); \ + } while (false) +// Forks the code and expects it to trigger a signal. +// In the child resets all signal handlers as printStackTraceOnCrash sets. + #define KJ_EXPECT_LOG(level, substring) \ ::kj::_::LogExpectation KJ_UNIQUE_NAME(_kjLogExpectation)(::kj::LogSeverity::level, substring) // Expects that a log message with the given level and substring text will be printed within @@ -128,6 +180,17 @@ bool expectFatalThrow(Maybe type, Maybe message, // fork() is not available, this always returns true. #endif +bool expectExit(Maybe statusCode, FunctionParam code) noexcept; +// Expects that the given code will exit with a given statusCode. +// The test will fork() and run in a subprocess. On Windows, where fork() is not available, +// this always returns true. + +bool expectSignal(Maybe signal, FunctionParam code) noexcept; +// Expects that the given code will trigger a signal. +// The test will fork() and run in a subprocess. On Windows, where fork() is not available, +// this always returns true. +// Resets signal handlers to default prior to running the code in the child process. + class LogExpectation: public ExceptionCallback { public: LogExpectation(LogSeverity severity, StringPtr substring); @@ -164,4 +227,4 @@ class GlobFilter { } // namespace _ (private) } // namespace kj -#endif // KJ_TEST_H_ +KJ_END_HEADER diff --git a/c++/src/kj/thread-test.c++ b/c++/src/kj/thread-test.c++ index 028fd68338..043e7a6191 100644 --- a/c++/src/kj/thread-test.c++ +++ b/c++/src/kj/thread-test.c++ @@ -86,5 +86,39 @@ KJ_TEST("detaching thread doesn't delete function") { } } +class CapturingExceptionCallback final: public ExceptionCallback { +public: + CapturingExceptionCallback(String& target): target(target) {} + + void logMessage(LogSeverity severity, const char* file, int line, int contextDepth, + String&& text) { + target = kj::mv(text); + } + +private: + String& target; +}; + +class ThreadedExceptionCallback final: public ExceptionCallback { +public: + Function)> getThreadInitializer() override { + return [this](Function func) { + CapturingExceptionCallback context(captured); + func(); + }; + } + + String captured; +}; + +KJ_TEST("threads pick up exception callback initializer") { + ThreadedExceptionCallback context; + KJ_EXPECT(context.captured != "foobar"); + Thread([]() { + KJ_LOG(ERROR, "foobar"); + }); + KJ_EXPECT(context.captured == "foobar", context.captured); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/thread.c++ b/c++/src/kj/thread.c++ index dcdf6f12ba..e013c07c5e 100644 --- a/c++/src/kj/thread.c++ +++ b/c++/src/kj/thread.c++ @@ -34,7 +34,7 @@ namespace kj { #if _WIN32 -Thread::Thread(Function func): state(new ThreadState { kj::mv(func), nullptr, 2 }) { +Thread::Thread(Function func): state(new ThreadState(kj::mv(func))) { threadHandle = CreateThread(nullptr, 0, &runThread, state, 0, nullptr); if (threadHandle == nullptr) { state->unref(); @@ -61,20 +61,9 @@ void Thread::detach() { detached = true; } -DWORD Thread::runThread(void* ptr) { - ThreadState* state = reinterpret_cast(ptr); - KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { - state->func(); - })) { - state->exception = kj::mv(*exception); - } - state->unref(); - return 0; -} - #else // _WIN32 -Thread::Thread(Function func): state(new ThreadState { kj::mv(func), nullptr, 2 }) { +Thread::Thread(Function func): state(new ThreadState(kj::mv(func))) { static_assert(sizeof(threadId) >= sizeof(pthread_t), "pthread_t is larger than a long long on your platform. Please port."); @@ -119,21 +108,16 @@ void Thread::detach() { state->unref(); } -void* Thread::runThread(void* ptr) { - ThreadState* state = reinterpret_cast(ptr); - KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { - state->func(); - })) { - state->exception = kj::mv(*exception); - } - state->unref(); - return nullptr; -} - #endif // _WIN32, else +Thread::ThreadState::ThreadState(Function func) + : func(kj::mv(func)), + initializer(getExceptionCallback().getThreadInitializer()), + exception(nullptr), + refcount(2) {} + void Thread::ThreadState::unref() { -#if _MSC_VER +#if _MSC_VER && !defined(__clang__) if (_InterlockedDecrement(&refcount) == 0) { #else if (__atomic_sub_fetch(&refcount, 1, __ATOMIC_RELEASE) == 0) { @@ -141,11 +125,33 @@ void Thread::ThreadState::unref() { #endif KJ_IF_MAYBE(e, exception) { - KJ_LOG(ERROR, "uncaught exception thrown by detached thread", *e); + // If the exception is still present in ThreadState, this must be a detached thread, so + // the exception will never be rethrown. We should at least log it. + // + // We need to run the thread initializer again before we log anything because the main + // purpose of the thread initializer is to set up a logging callback. + initializer([&]() { + KJ_LOG(ERROR, "uncaught exception thrown by detached thread", *e); + }); } delete this; } } +#if _WIN32 +DWORD Thread::runThread(void* ptr) { +#else +void* Thread::runThread(void* ptr) { +#endif + ThreadState* state = reinterpret_cast(ptr); + KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { + state->initializer(kj::mv(state->func)); + })) { + state->exception = kj::mv(*exception); + } + state->unref(); + return 0; +} + } // namespace kj diff --git a/c++/src/kj/thread.h b/c++/src/kj/thread.h index b17b88c520..2261ab12c9 100644 --- a/c++/src/kj/thread.h +++ b/c++/src/kj/thread.h @@ -19,17 +19,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_THREAD_H_ -#define KJ_THREAD_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "common.h" #include "function.h" #include "exception.h" +KJ_BEGIN_HEADER + namespace kj { class Thread { @@ -39,7 +36,7 @@ class Thread { public: explicit Thread(Function func); - KJ_DISALLOW_COPY(Thread); + KJ_DISALLOW_COPY_AND_MOVE(Thread); ~Thread() noexcept(false); @@ -53,7 +50,10 @@ class Thread { private: struct ThreadState { + ThreadState(Function func); + Function func; + Function)> initializer; kj::Maybe exception; unsigned int refcount; @@ -79,4 +79,4 @@ class Thread { } // namespace kj -#endif // KJ_THREAD_H_ +KJ_END_HEADER diff --git a/c++/src/kj/threadlocal.h b/c++/src/kj/threadlocal.h index 67d0db60ef..613b96e788 100644 --- a/c++/src/kj/threadlocal.h +++ b/c++/src/kj/threadlocal.h @@ -20,12 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_THREADLOCAL_H_ -#define KJ_THREADLOCAL_H_ +#pragma once -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif // This file declares a macro `KJ_THREADLOCAL_PTR` for declaring thread-local pointer-typed // variables. Use like: // KJ_THREADLOCAL_PTR(MyType) foo = nullptr; @@ -47,83 +43,17 @@ #include "common.h" -#if !defined(KJ_USE_PTHREAD_THREADLOCAL) && defined(__APPLE__) -#include "TargetConditionals.h" -#if TARGET_OS_IPHONE -// iOS apparently does not support __thread (nor C++11 thread_local). -#define KJ_USE_PTHREAD_TLS 1 -#endif -#endif - -#if KJ_USE_PTHREAD_TLS -#include -#endif +KJ_BEGIN_HEADER namespace kj { -#if KJ_USE_PTHREAD_TLS -// If __thread is unavailable, we'll fall back to pthreads. - -#define KJ_THREADLOCAL_PTR(type) \ - namespace { struct KJ_UNIQUE_NAME(_kj_TlpTag); } \ - static ::kj::_::ThreadLocalPtr< type, KJ_UNIQUE_NAME(_kj_TlpTag)> -// Hack: In order to ensure each thread-local results in a unique template instance, we declare -// a one-off dummy type to use as the second type parameter. - -namespace _ { // private - -template -class ThreadLocalPtr { - // Hacky type to emulate __thread T*. We need a separate instance of the ThreadLocalPtr template - // for every thread-local variable, because we don't want to require a global constructor, and in - // order to initialize the TLS on first use we need to use a local static variable (in getKey()). - // Each template instance will get a separate such local static variable, fulfilling our need. - -public: - ThreadLocalPtr() = default; - constexpr ThreadLocalPtr(decltype(nullptr)) {} - // Allow initialization to nullptr without a global constructor. - - inline ThreadLocalPtr& operator=(T* val) { - pthread_setspecific(getKey(), val); - return *this; - } - - inline operator T*() const { - return get(); - } - - inline T& operator*() const { - return *get(); - } - - inline T* operator->() const { - return get(); - } - -private: - inline T* get() const { - return reinterpret_cast(pthread_getspecific(getKey())); - } - - inline static pthread_key_t getKey() { - static pthread_key_t key = createKey(); - return key; - } - - static pthread_key_t createKey() { - pthread_key_t key; - pthread_key_create(&key, 0); - return key; - } -}; - -} // namespace _ (private) - -#elif __GNUC__ +#if __GNUC__ #define KJ_THREADLOCAL_PTR(type) static __thread type* // GCC's __thread is lighter-weight than thread_local and is good enough for our purposes. +// +// TODO(cleanup): The above comment was written many years ago. Is it still true? Shouldn't the +// compiler be smart enough to optimize a thread_local of POD type? #else @@ -133,4 +63,4 @@ class ThreadLocalPtr { } // namespace kj -#endif // KJ_THREADLOCAL_H_ +KJ_END_HEADER diff --git a/c++/src/kj/time-test.c++ b/c++/src/kj/time-test.c++ new file mode 100644 index 0000000000..8bbdf9366b --- /dev/null +++ b/c++/src/kj/time-test.c++ @@ -0,0 +1,150 @@ +// Copyright (c) 2019 Cloudflare, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#if _WIN32 +#include "win32-api-version.h" +#endif + +#include "time.h" +#include "debug.h" +#include +#include + +#if _WIN32 +#include +#include "windows-sanity.h" +#else +#include +#endif + +namespace kj { +namespace { + +KJ_TEST("stringify times") { + KJ_EXPECT(kj::str(50 * kj::SECONDS) == "50s"); + KJ_EXPECT(kj::str(5 * kj::SECONDS + 2 * kj::MILLISECONDS) == "5.002s"); + KJ_EXPECT(kj::str(256 * kj::MILLISECONDS) == "256ms"); + KJ_EXPECT(kj::str(5 * kj::MILLISECONDS + 2 * kj::NANOSECONDS) == "5.000002ms"); + KJ_EXPECT(kj::str(50 * kj::MICROSECONDS) == "50μs"); + KJ_EXPECT(kj::str(5 * kj::MICROSECONDS + 300 * kj::NANOSECONDS) == "5.3μs"); + KJ_EXPECT(kj::str(50 * kj::NANOSECONDS) == "50ns"); + KJ_EXPECT(kj::str(-256 * kj::MILLISECONDS) == "-256ms"); + KJ_EXPECT(kj::str(-50 * kj::NANOSECONDS) == "-50ns"); + KJ_EXPECT(kj::str((int64_t)kj::maxValue * kj::NANOSECONDS) == "9223372036.854775807s"); + KJ_EXPECT(kj::str((int64_t)kj::minValue * kj::NANOSECONDS) == "-9223372036.854775808s"); +} + +#if _WIN32 +void delay(kj::Duration d) { + Sleep(d / kj::MILLISECONDS); +} +#else +void delay(kj::Duration d) { + usleep(d / kj::MICROSECONDS); +} +#endif + +KJ_TEST("calendar clocks matches unix time") { + // Check that the times returned by the calendar clock are within 1s of what time() returns. + + auto& coarse = systemCoarseCalendarClock(); + auto& precise = systemPreciseCalendarClock(); + + Date p = precise.now(); + Date c = coarse.now(); + time_t t = time(nullptr); + + int64_t pi = (p - UNIX_EPOCH) / kj::SECONDS; + int64_t ci = (c - UNIX_EPOCH) / kj::SECONDS; + + KJ_EXPECT(pi >= t - 1); + KJ_EXPECT(pi <= t + 1); + KJ_EXPECT(ci >= t - 1); + KJ_EXPECT(ci <= t + 1); +} + +KJ_TEST("monotonic clocks match each other") { + // Check that the monotonic clocks return comparable times. + + auto& coarse = systemCoarseMonotonicClock(); + auto& precise = systemPreciseMonotonicClock(); + + TimePoint p = precise.now(); + TimePoint c = coarse.now(); + + // 40ms tolerance due to Windows timeslices being quite long, especially on GitHub Actions where + // Windows is drunk and has completely lost track of time. + KJ_EXPECT(p < c + 40 * kj::MILLISECONDS, p - c); + KJ_EXPECT(p > c - 40 * kj::MILLISECONDS, c - p); +} + +KJ_TEST("all clocks advance in real time") { + Duration coarseCalDiff; + Duration preciseCalDiff; + Duration coarseMonoDiff; + Duration preciseMonoDiff; + + for (uint retryCount KJ_UNUSED: kj::zeroTo(20)) { + auto& coarseCal = systemCoarseCalendarClock(); + auto& preciseCal = systemPreciseCalendarClock(); + auto& coarseMono = systemCoarseMonotonicClock(); + auto& preciseMono = systemPreciseMonotonicClock(); + + Date coarseCalBefore = coarseCal.now(); + Date preciseCalBefore = preciseCal.now(); + TimePoint coarseMonoBefore = coarseMono.now(); + TimePoint preciseMonoBefore = preciseMono.now(); + + Duration delayTime = 150 * kj::MILLISECONDS; + delay(delayTime); + + Date coarseCalAfter = coarseCal.now(); + Date preciseCalAfter = preciseCal.now(); + TimePoint coarseMonoAfter = coarseMono.now(); + TimePoint preciseMonoAfter = preciseMono.now(); + + coarseCalDiff = coarseCalAfter - coarseCalBefore; + preciseCalDiff = preciseCalAfter - preciseCalBefore; + coarseMonoDiff = coarseMonoAfter - coarseMonoBefore; + preciseMonoDiff = preciseMonoAfter - preciseMonoBefore; + + // 20ms tolerance due to Windows timeslices being quite long (and Windows sleeps being only + // accurate to the timeslice). + if (coarseCalDiff > delayTime - 20 * kj::MILLISECONDS && + coarseCalDiff < delayTime + 20 * kj::MILLISECONDS && + preciseCalDiff > delayTime - 20 * kj::MILLISECONDS && + preciseCalDiff < delayTime + 20 * kj::MILLISECONDS && + coarseMonoDiff > delayTime - 20 * kj::MILLISECONDS && + coarseMonoDiff < delayTime + 20 * kj::MILLISECONDS && + preciseMonoDiff > delayTime - 20 * kj::MILLISECONDS && + preciseMonoDiff < delayTime + 20 * kj::MILLISECONDS) { + // success + return; + } + } + + KJ_FAIL_EXPECT("clocks seem inaccurate even after 20 tries", + coarseCalDiff / kj::MICROSECONDS, preciseCalDiff / kj::MICROSECONDS, + coarseMonoDiff / kj::MICROSECONDS, preciseMonoDiff / kj::MICROSECONDS); +} + +} // namespace +} // namespace kj diff --git a/c++/src/kj/time.c++ b/c++/src/kj/time.c++ index 5b3b433456..98ebbb1f7b 100644 --- a/c++/src/kj/time.c++ +++ b/c++/src/kj/time.c++ @@ -20,115 +20,302 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#if _WIN32 +#include "win32-api-version.h" +#endif + #include "time.h" #include "debug.h" #include -namespace kj { +#if _WIN32 +#include +#else +#include +#endif -kj::Exception Timer::makeTimeoutException() { - return KJ_EXCEPTION(OVERLOADED, "operation timed out"); -} +namespace kj { -Clock& nullClock() { +const Clock& nullClock() { class NullClock final: public Clock { public: - Date now() override { return UNIX_EPOCH; } + Date now() const override { return UNIX_EPOCH; } }; - static NullClock NULL_CLOCK; + static KJ_CONSTEXPR(const) NullClock NULL_CLOCK = NullClock(); return NULL_CLOCK; } -struct TimerImpl::Impl { - struct TimerBefore { - bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs); - }; - using Timers = std::multiset; - Timers timers; -}; +#if _WIN32 + +namespace { -class TimerImpl::TimerPromiseAdapter { +static constexpr int64_t WIN32_EPOCH_OFFSET = 116444736000000000ull; +// Number of 100ns intervals from Jan 1, 1601 to Jan 1, 1970. + +static Date toKjDate(FILETIME t) { + int64_t value = (static_cast(t.dwHighDateTime) << 32) | t.dwLowDateTime; + return (value - WIN32_EPOCH_OFFSET) * (100 * kj::NANOSECONDS) + UNIX_EPOCH; +} + +class Win32CoarseClock: public Clock { public: - TimerPromiseAdapter(PromiseFulfiller& fulfiller, TimerImpl::Impl& impl, TimePoint time) - : time(time), fulfiller(fulfiller), impl(impl) { - pos = impl.timers.insert(this); + Date now() const override { + FILETIME ft; + GetSystemTimeAsFileTime(&ft); + return toKjDate(ft); } +}; - ~TimerPromiseAdapter() { - if (pos != impl.timers.end()) { - impl.timers.erase(pos); +class Win32PreciseClock: public Clock { + typedef VOID WINAPI GetSystemTimePreciseAsFileTimeFunc(LPFILETIME); +public: + Date now() const override { + static GetSystemTimePreciseAsFileTimeFunc* const getSystemTimePreciseAsFileTimePtr = + getGetSystemTimePreciseAsFileTime(); + FILETIME ft; + if (getSystemTimePreciseAsFileTimePtr == nullptr) { + // We can't use QueryPerformanceCounter() to get any more precision because we have no way + // of knowing when the calendar clock jumps. So I guess we're stuck. + GetSystemTimeAsFileTime(&ft); + } else { + getSystemTimePreciseAsFileTimePtr(&ft); } + return toKjDate(ft); } - void fulfill() { - fulfiller.fulfill(); - impl.timers.erase(pos); - pos = impl.timers.end(); +private: + static GetSystemTimePreciseAsFileTimeFunc* getGetSystemTimePreciseAsFileTime() { + // Dynamically look up the function GetSystemTimePreciseAsFileTimeFunc(). This was only + // introduced as of Windows 8, so it might be missing. +#if __GNUC__ && !__clang__ && __GNUC__ >= 8 +// GCC 8 warns that our reinterpret_cast of a function pointer below is casting between +// incompatible types. Yes, GCC, we know that. This is the nature of GetProcAddress(); it returns +// everything as `long long int (*)()` and we have to cast to the actual type. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wcast-function-type" +#endif + return reinterpret_cast(GetProcAddress( + GetModuleHandleA("kernel32.dll"), + "GetSystemTimePreciseAsFileTime")); + } +}; + +class Win32CoarseMonotonicClock: public MonotonicClock { +public: + TimePoint now() const override { + return kj::origin() + GetTickCount64() * kj::MILLISECONDS; } +}; - const TimePoint time; +class Win32PreciseMonotonicClock: public MonotonicClock { + // Precise clock implemented using QueryPerformanceCounter(). + // + // TODO(someday): Windows 10 has QueryUnbiasedInterruptTime() and + // QueryUnbiasedInterruptTimePrecise(), a new API for monotonic timing that isn't as difficult. + // Is there any benefit to dynamically checking for these and using them if available? + +public: + TimePoint now() const override { + static const QpcProperties props; + + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + uint64_t adjusted = now.QuadPart - props.origin; + uint64_t ns = mulDiv64(adjusted, 1'000'000'000, props.frequency); + return kj::origin() + ns * kj::NANOSECONDS; + } private: - PromiseFulfiller& fulfiller; - TimerImpl::Impl& impl; - Impl::Timers::const_iterator pos; + struct QpcProperties { + uint64_t origin; + // What QueryPerformanceCounter() would have returned at the time when GetTickCount64() returned + // zero. Used to ensure that the coarse and precise timers return similar values. + + uint64_t frequency; + // From QueryPerformanceFrequency(). + + QpcProperties() { + LARGE_INTEGER now, freqLi; + uint64_t ticks = GetTickCount64(); + QueryPerformanceCounter(&now); + + QueryPerformanceFrequency(&freqLi); + frequency = freqLi.QuadPart; + + // Convert the millisecond tick count into performance counter ticks. + uint64_t ticksAsQpc = mulDiv64(ticks, freqLi.QuadPart, 1000); + + origin = now.QuadPart - ticksAsQpc; + } + }; + + static inline uint64_t mulDiv64(uint64_t value, uint64_t numer, uint64_t denom) { + // Inspired by: + // https://github.com/rust-lang/rust/pull/22788/files#diff-24f054cd23f65af3b574c6ce8aa5a837R54 + // Computes (value*numer)/denom without overflow, as long as both + // (numer*denom) and the overall result fit into 64 bits. + uint64_t q = value / denom; + uint64_t r = value % denom; + return q * numer + r * numer / denom; + } }; -inline bool TimerImpl::Impl::TimerBefore::operator()( - TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) { - return lhs->time < rhs->time; -} +} // namespace -Promise TimerImpl::atTime(TimePoint time) { - return newAdaptedPromise(*impl, time); +const Clock& systemCoarseCalendarClock() { + static constexpr Win32CoarseClock clock; + return clock; +} +const Clock& systemPreciseCalendarClock() { + static constexpr Win32PreciseClock clock; + return clock; } -Promise TimerImpl::afterDelay(Duration delay) { - return newAdaptedPromise(*impl, time + delay); +const MonotonicClock& systemCoarseMonotonicClock() { + static constexpr Win32CoarseMonotonicClock clock; + return clock; +} +const MonotonicClock& systemPreciseMonotonicClock() { + static constexpr Win32PreciseMonotonicClock clock; + return clock; } -TimerImpl::TimerImpl(TimePoint startTime) - : time(startTime), impl(heap()) {} +#else -TimerImpl::~TimerImpl() noexcept(false) {} +namespace { -Maybe TimerImpl::nextEvent() { - auto iter = impl->timers.begin(); - if (iter == impl->timers.end()) { - return nullptr; - } else { - return (*iter)->time; +class PosixClock: public Clock { +public: + constexpr PosixClock(clockid_t clockId): clockId(clockId) {} + + Date now() const override { + struct timespec ts; + KJ_SYSCALL(clock_gettime(clockId, &ts)); + return UNIX_EPOCH + ts.tv_sec * kj::SECONDS + ts.tv_nsec * kj::NANOSECONDS; } -} -Maybe TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max) { - return nextEvent().map([&](TimePoint nextTime) -> uint64_t { - if (nextTime <= start) return 0; +private: + clockid_t clockId; +}; - Duration timeout = nextTime - start; +class PosixMonotonicClock: public MonotonicClock { +public: + constexpr PosixMonotonicClock(clockid_t clockId): clockId(clockId) {} - uint64_t result = timeout / unit; - bool roundUp = timeout % unit > 0 * SECONDS; + TimePoint now() const override { + struct timespec ts; + KJ_SYSCALL(clock_gettime(clockId, &ts)); + return kj::origin() + ts.tv_sec * kj::SECONDS + ts.tv_nsec * kj::NANOSECONDS; + } - if (result >= max) { - return max; - } else { - return result + roundUp; - } - }); +private: + clockid_t clockId; +}; + +} // namespace + +// FreeBSD has "_PRECISE", but Linux just defaults to precise. +#ifndef CLOCK_REALTIME_PRECISE +#define CLOCK_REALTIME_PRECISE CLOCK_REALTIME +#endif + +#ifndef CLOCK_MONOTONIC_PRECISE +#define CLOCK_MONOTONIC_PRECISE CLOCK_MONOTONIC +#endif + +// FreeBSD has "_FAST", Linux has "_COARSE". +// MacOS has an "_APPROX" but only for CLOCK_MONOTONIC_RAW, which isn't helpful. +#ifndef CLOCK_REALTIME_COARSE +#ifdef CLOCK_REALTIME_FAST +#define CLOCK_REALTIME_COARSE CLOCK_REALTIME_FAST +#else +#define CLOCK_REALTIME_COARSE CLOCK_REALTIME +#endif +#endif + +#ifndef CLOCK_MONOTONIC_COARSE +#ifdef CLOCK_MONOTONIC_FAST +#define CLOCK_MONOTONIC_COARSE CLOCK_MONOTONIC_FAST +#else +#define CLOCK_MONOTONIC_COARSE CLOCK_MONOTONIC +#endif +#endif + +const Clock& systemCoarseCalendarClock() { + static constexpr PosixClock clock(CLOCK_REALTIME_COARSE); + return clock; +} +const Clock& systemPreciseCalendarClock() { + static constexpr PosixClock clock(CLOCK_REALTIME_PRECISE); + return clock; +} + +const MonotonicClock& systemCoarseMonotonicClock() { + static constexpr PosixMonotonicClock clock(CLOCK_MONOTONIC_COARSE); + return clock; +} +const MonotonicClock& systemPreciseMonotonicClock() { + static constexpr PosixMonotonicClock clock(CLOCK_MONOTONIC_PRECISE); + return clock; +} + +#endif + +CappedArray KJ_STRINGIFY(TimePoint t) { + return kj::toCharSequence(t - kj::origin()); +} +CappedArray KJ_STRINGIFY(Date d) { + return kj::toCharSequence(d - UNIX_EPOCH); } +CappedArray KJ_STRINGIFY(Duration d) { + bool negative = d < 0 * kj::SECONDS; + uint64_t ns = d / kj::NANOSECONDS; + if (negative) { + ns = -ns; + } + + auto digits = kj::toCharSequence(ns); + ArrayPtr arr = digits; -void TimerImpl::advanceTo(TimePoint newTime) { - KJ_REQUIRE(newTime >= time, "can't advance backwards in time") { return; } + size_t point; + kj::StringPtr suffix; + kj::Duration unit; + if (digits.size() > 9) { + point = arr.size() - 9; + suffix = "s"; + unit = kj::SECONDS; + } else if (digits.size() > 6) { + point = arr.size() - 6; + suffix = "ms"; + unit = kj::MILLISECONDS; + } else if (digits.size() > 3) { + point = arr.size() - 3; + suffix = "μs"; + unit = kj::MICROSECONDS; + } else { + point = arr.size(); + suffix = "ns"; + unit = kj::NANOSECONDS; + } - time = newTime; - for (;;) { - auto front = impl->timers.begin(); - if (front == impl->timers.end() || (*front)->time > time) { - break; + CappedArray result; + char* begin = result.begin(); + char* end; + if (negative) { + *begin++ = '-'; + } + if (d % unit == 0 * kj::NANOSECONDS) { + end = _::fillLimited(begin, result.end(), arr.slice(0, point), suffix); + } else { + while (arr.back() == '0') { + arr = arr.slice(0, arr.size() - 1); } - (*front)->fulfill(); + KJ_DASSERT(arr.size() > point); + end = _::fillLimited(begin, result.end(), arr.slice(0, point), "."_kj, + arr.slice(point, arr.size()), suffix); } + result.setSize(end - result.begin()); + return result; } } // namespace kj diff --git a/c++/src/kj/time.h b/c++/src/kj/time.h index 37d7b8a90e..aaf1031d7e 100644 --- a/c++/src/kj/time.h +++ b/c++/src/kj/time.h @@ -20,16 +20,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_TIME_H_ -#define KJ_TIME_H_ +#pragma once -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif - -#include "async.h" #include "units.h" #include +#include "string.h" + +KJ_BEGIN_HEADER namespace kj { namespace _ { // private @@ -38,6 +35,10 @@ class NanosecondLabel; class TimeLabel; class DateLabel; +static constexpr size_t TIME_STR_LEN = sizeof(int64_t) * 3 + 8; +// Maximum length of a stringified time. 3 digits per byte of integer, plus 8 digits to cover +// negative sign, decimal point, unit, NUL terminator, and anything else that might sneak in. + } // namespace _ (private) using Duration = Quantity; @@ -52,123 +53,67 @@ constexpr Duration HOURS = 60 * MINUTES; constexpr Duration DAYS = 24 * HOURS; using TimePoint = Absolute; -// An absolute time measured by some particular instance of `Timer`. `Time`s from two different -// `Timer`s may be measured from different origins and so are not necessarily compatible. +// An absolute time measured by some particular instance of `Timer` or `MonotonicClock`. `Time`s +// from two different `Timer`s or `MonotonicClock`s may be measured from different origins and so +// are not necessarily compatible. using Date = Absolute; // A point in real-world time, measured relative to the Unix epoch (Jan 1, 1970 00:00:00 UTC). +CappedArray KJ_STRINGIFY(TimePoint); +CappedArray KJ_STRINGIFY(Date); +CappedArray KJ_STRINGIFY(Duration); + constexpr Date UNIX_EPOCH = origin(); // The `Date` representing Jan 1, 1970 00:00:00 UTC. class Clock { // Interface to read the current date and time. public: - virtual Date now() = 0; -}; - -Clock& nullClock(); -// A clock which always returns UNIX_EPOCH as the current time. Useful when you don't care about -// time. - -class Timer { - // Interface to time and timer functionality. - // - // Each `Timer` may have a different origin, and some `Timer`s may in fact tick at a different - // rate than real time (e.g. a `Timer` could represent CPU time consumed by a thread). However, - // all `Timer`s are monotonic: time will never appear to move backwards, even if the calendar - // date as tracked by the system is manually modified. - -public: - virtual TimePoint now() = 0; - // Returns the current value of a clock that moves steadily forward, independent of any - // changes in the wall clock. The value is updated every time the event loop waits, - // and is constant in-between waits. - - virtual Promise atTime(TimePoint time) = 0; - // Returns a promise that returns as soon as now() >= time. - - virtual Promise afterDelay(Duration delay) = 0; - // Equivalent to atTime(now() + delay). - - template - Promise timeoutAt(TimePoint time, Promise&& promise) KJ_WARN_UNUSED_RESULT; - // Return a promise equivalent to `promise` but which throws an exception (and cancels the - // original promise) if it hasn't completed by `time`. The thrown exception is of type - // "OVERLOADED". - - template - Promise timeoutAfter(Duration delay, Promise&& promise) KJ_WARN_UNUSED_RESULT; - // Return a promise equivalent to `promise` but which throws an exception (and cancels the - // original promise) if it hasn't completed after `delay` from now. The thrown exception is of - // type "OVERLOADED". - -private: - static kj::Exception makeTimeoutException(); + virtual Date now() const = 0; }; -class TimerImpl final: public Timer { - // Implementation of Timer that expects an external caller -- usually, the EventPort - // implementation -- to tell it when time has advanced. +class MonotonicClock { + // Interface to read time in a way that increases as real-world time increases, independent of + // any manual changes to the calendar date/time. Such a clock never "goes backwards" even if the + // system administrator changes the calendar time or suspends the system. However, this clock's + // time points are only meaningful in comparison to other time points from the same clock, and + // cannot be used to determine the current calendar date. public: - TimerImpl(TimePoint startTime); - ~TimerImpl() noexcept(false); - - Maybe nextEvent(); - // Returns the time at which the next scheduled timer event will occur, or null if no timer - // events are scheduled. - - Maybe timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max); - // Convenience method which computes a timeout value to pass to an event-waiting system call to - // cause it to time out when the next timer event occurs. - // - // `start` is the time at which the timeout starts counting. This is typically not the same as - // now() since some time may have passed since the last time advanceTo() was called. - // - // `unit` is the time unit in which the timeout is measured. This is often MILLISECONDS. Note - // that this method will fractional values *up*, to guarantee that the returned timeout waits - // until just *after* the time the event is scheduled. - // - // The timeout will be clamped to `max`. Use this to avoid an overflow if e.g. the OS wants a - // 32-bit value or a signed value. - // - // Returns nullptr if there are no future events. - - void advanceTo(TimePoint newTime); - // Set the time to `time` and fire any at() events that have been passed. - - // implements Timer ---------------------------------------------------------- - TimePoint now() override; - Promise atTime(TimePoint time) override; - Promise afterDelay(Duration delay) override; - -private: - struct Impl; - class TimerPromiseAdapter; - TimePoint time; - Own impl; + virtual TimePoint now() const = 0; }; -// ======================================================================================= -// inline implementation details - -template -Promise Timer::timeoutAt(TimePoint time, Promise&& promise) { - return promise.exclusiveJoin(atTime(time).then([]() -> kj::Promise { - return makeTimeoutException(); - })); -} - -template -Promise Timer::timeoutAfter(Duration delay, Promise&& promise) { - return promise.exclusiveJoin(afterDelay(delay).then([]() -> kj::Promise { - return makeTimeoutException(); - })); -} +const Clock& nullClock(); +// A clock which always returns UNIX_EPOCH as the current time. Useful when you don't care about +// time. -inline TimePoint TimerImpl::now() { return time; } +const Clock& systemCoarseCalendarClock(); +const Clock& systemPreciseCalendarClock(); +// A clock that reads the real system time. +// +// In well-designed code, this should only be called by the top-level dependency injector. All +// other modules should request that the caller provide a Clock so that alternate clock +// implementations can be injected for testing, simulation, reproducibility, and other purposes. +// +// The "coarse" version has precision around 1-10ms, while the "precise" version has precision +// better than 1us. The "precise" version may be slightly slower, though on modern hardware and +// a reasonable operating system the difference is usually negligible. +// +// Note: On Windows prior to Windows 8, there is no precise calendar clock; the "precise" clock +// will be no more precise than the "coarse" clock in this case. +const MonotonicClock& systemCoarseMonotonicClock(); +const MonotonicClock& systemPreciseMonotonicClock(); +// A MonotonicClock that reads the real system time. +// +// In well-designed code, this should only be called by the top-level dependency injector. All +// other modules should request that the caller provide a Clock so that alternate clock +// implementations can be injected for testing, simulation, reproducibility, and other purposes. +// +// The "coarse" version has precision around 1-10ms, while the "precise" version has precision +// better than 1us. The "precise" version may be slightly slower, though on modern hardware and +// a reasonable operating system the difference is usually negligible. } // namespace kj -#endif // KJ_TIME_H_ +KJ_END_HEADER diff --git a/c++/src/kj/timer.c++ b/c++/src/kj/timer.c++ new file mode 100644 index 0000000000..dbd66cf3c8 --- /dev/null +++ b/c++/src/kj/timer.c++ @@ -0,0 +1,133 @@ +// Copyright (c) 2014 Google Inc. (contributed by Remy Blank ) +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "timer.h" +#include "debug.h" +#include + +namespace kj { + +kj::Exception Timer::makeTimeoutException() { + return KJ_EXCEPTION(OVERLOADED, "operation timed out"); +} + +struct TimerImpl::Impl { + struct TimerBefore { + bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) const; + }; + using Timers = std::multiset; + Timers timers; +}; + +class TimerImpl::TimerPromiseAdapter { +public: + TimerPromiseAdapter(PromiseFulfiller& fulfiller, TimerImpl::Impl& impl, TimePoint time) + : time(time), fulfiller(fulfiller), impl(impl) { + pos = impl.timers.insert(this); + } + + ~TimerPromiseAdapter() { + if (pos != impl.timers.end()) { + impl.timers.erase(pos); + } + } + + void fulfill() { + fulfiller.fulfill(); + impl.timers.erase(pos); + pos = impl.timers.end(); + } + + const TimePoint time; + +private: + PromiseFulfiller& fulfiller; + TimerImpl::Impl& impl; + Impl::Timers::const_iterator pos; +}; + +inline bool TimerImpl::Impl::TimerBefore::operator()( + TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) const { + return lhs->time < rhs->time; +} + +Promise TimerImpl::atTime(TimePoint time) { + return newAdaptedPromise(*impl, time); +} + +Promise TimerImpl::afterDelay(Duration delay) { + return newAdaptedPromise(*impl, time + delay); +} + +TimerImpl::TimerImpl(TimePoint startTime) + : time(startTime), impl(heap()) {} + +TimerImpl::~TimerImpl() noexcept(false) {} + +Maybe TimerImpl::nextEvent() { + auto iter = impl->timers.begin(); + if (iter == impl->timers.end()) { + return nullptr; + } else { + return (*iter)->time; + } +} + +Maybe TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max) { + return nextEvent().map([&](TimePoint nextTime) -> uint64_t { + if (nextTime <= start) return 0; + + Duration timeout = nextTime - start; + + uint64_t result = timeout / unit; + bool roundUp = timeout % unit > 0 * SECONDS; + + if (result >= max) { + return max; + } else { + return result + roundUp; + } + }); +} + +void TimerImpl::advanceTo(TimePoint newTime) { + // On Macs running an Intel processor, it has been observed that clock_gettime + // may return non monotonic time, even when CLOCK_MONOTONIC is used. + // This workaround is to avoid the assert triggering on these machines. + // See also https://github.com/capnproto/capnproto/issues/1693 +#if __APPLE__ && (defined(__x86_64__) || defined(__POWERPC__)) + time = std::max(time, newTime); +#else + KJ_REQUIRE(newTime >= time, "can't advance backwards in time") { return; } + time = newTime; +#endif + + for (;;) { + auto front = impl->timers.begin(); + if (front == impl->timers.end() || (*front)->time > time) { + break; + } + (*front)->fulfill(); + } +} + +} // namespace kj diff --git a/c++/src/kj/timer.h b/c++/src/kj/timer.h new file mode 100644 index 0000000000..eb9443c23b --- /dev/null +++ b/c++/src/kj/timer.h @@ -0,0 +1,144 @@ +// Copyright (c) 2014 Google Inc. (contributed by Remy Blank ) +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include "async.h" + +KJ_BEGIN_HEADER + +namespace kj { + +class Timer: public MonotonicClock { + // Interface to time and timer functionality. + // + // Each `Timer` may have a different origin, and some `Timer`s may in fact tick at a different + // rate than real time (e.g. a `Timer` could represent CPU time consumed by a thread). However, + // all `Timer`s are monotonic: time will never appear to move backwards, even if the calendar + // date as tracked by the system is manually modified. + // + // That said, the `Timer` returned by `kj::setupAsyncIo().provider->getTimer()` in particular is + // guaranteed to be synchronized with the `MonotonicClock` returned by + // `systemPreciseMonotonicClock()` (or, more precisely, is updated to match that clock whenever + // the loop waits). + // + // Note that the value returned by `Timer::now()` only changes each time the + // event loop waits for I/O from the system. While the event loop is actively + // running, the time stays constant. This is intended to make behavior more + // deterministic and reproducible. However, if you need up-to-the-cycle + // accurate time, then `Timer::now()` is not appropriate. Instead, use + // `systemPreciseMonotonicClock()` directly in this case. + +public: + virtual TimePoint now() const = 0; + // Returns the current value of a clock that moves steadily forward, independent of any + // changes in the wall clock. The value is updated every time the event loop waits, + // and is constant in-between waits. + + virtual Promise atTime(TimePoint time) = 0; + // Returns a promise that returns as soon as now() >= time. + + virtual Promise afterDelay(Duration delay) = 0; + // Equivalent to atTime(now() + delay). + + template + Promise timeoutAt(TimePoint time, Promise&& promise) KJ_WARN_UNUSED_RESULT; + // Return a promise equivalent to `promise` but which throws an exception (and cancels the + // original promise) if it hasn't completed by `time`. The thrown exception is of type + // "OVERLOADED". + + template + Promise timeoutAfter(Duration delay, Promise&& promise) KJ_WARN_UNUSED_RESULT; + // Return a promise equivalent to `promise` but which throws an exception (and cancels the + // original promise) if it hasn't completed after `delay` from now. The thrown exception is of + // type "OVERLOADED". + +private: + static kj::Exception makeTimeoutException(); +}; + +class TimerImpl final: public Timer { + // Implementation of Timer that expects an external caller -- usually, the EventPort + // implementation -- to tell it when time has advanced. + +public: + TimerImpl(TimePoint startTime); + ~TimerImpl() noexcept(false); + + Maybe nextEvent(); + // Returns the time at which the next scheduled timer event will occur, or null if no timer + // events are scheduled. + + Maybe timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max); + // Convenience method which computes a timeout value to pass to an event-waiting system call to + // cause it to time out when the next timer event occurs. + // + // `start` is the time at which the timeout starts counting. This is typically not the same as + // now() since some time may have passed since the last time advanceTo() was called. + // + // `unit` is the time unit in which the timeout is measured. This is often MILLISECONDS. Note + // that this method will fractional values *up*, to guarantee that the returned timeout waits + // until just *after* the time the event is scheduled. + // + // The timeout will be clamped to `max`. Use this to avoid an overflow if e.g. the OS wants a + // 32-bit value or a signed value. + // + // Returns nullptr if there are no future events. + + void advanceTo(TimePoint newTime); + // Set the time to `time` and fire any at() events that have been passed. + + // implements Timer ---------------------------------------------------------- + TimePoint now() const override; + Promise atTime(TimePoint time) override; + Promise afterDelay(Duration delay) override; + +private: + struct Impl; + class TimerPromiseAdapter; + TimePoint time; + Own impl; +}; + +// ======================================================================================= +// inline implementation details + +template +Promise Timer::timeoutAt(TimePoint time, Promise&& promise) { + return promise.exclusiveJoin(atTime(time).then([]() -> kj::Promise { + return makeTimeoutException(); + })); +} + +template +Promise Timer::timeoutAfter(Duration delay, Promise&& promise) { + return promise.exclusiveJoin(afterDelay(delay).then([]() -> kj::Promise { + return makeTimeoutException(); + })); +} + +inline TimePoint TimerImpl::now() const { return time; } + +} // namespace kj + +KJ_END_HEADER diff --git a/c++/src/kj/tuple-test.c++ b/c++/src/kj/tuple-test.c++ index 7196968d98..5a2e606d67 100644 --- a/c++/src/kj/tuple-test.c++ +++ b/c++/src/kj/tuple-test.c++ @@ -83,6 +83,30 @@ TEST(Tuple, Tuple) { i = tuple(tuple(), 456, tuple(tuple(), tuple())); EXPECT_EQ(456u, i); + + EXPECT_EQ(0, (indexOfType>())); + EXPECT_EQ(1, (indexOfType>())); + EXPECT_EQ(2, (indexOfType>())); + EXPECT_EQ(0, (indexOfType())); +} + +TEST(Tuple, RefTuple) { + uint i = 123; + StringPtr s = "foo"; + + Tuple t = refTuple(i, s, 321, "bar"); + EXPECT_EQ(get<0>(t), 123); + EXPECT_EQ(get<1>(t), "foo"); + EXPECT_EQ(get<2>(t), 321); + EXPECT_EQ(get<3>(t), "bar"); + + i = 456; + s = "baz"; + + EXPECT_EQ(get<0>(t), 456); + EXPECT_EQ(get<1>(t), "baz"); + EXPECT_EQ(get<2>(t), 321); + EXPECT_EQ(get<3>(t), "bar"); } } // namespace kj diff --git a/c++/src/kj/tuple.h b/c++/src/kj/tuple.h index 2ea7276ec5..1351912b09 100644 --- a/c++/src/kj/tuple.h +++ b/c++/src/kj/tuple.h @@ -19,7 +19,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -// This file defines a notion of tuples that is simpler that `std::tuple`. It works as follows: +// This file defines a notion of tuples that is simpler than `std::tuple`. It works as follows: // - `kj::Tuple is the type of a tuple of an A, a B, and a C. // - `kj::tuple(a, b, c)` returns a tuple containing a, b, and c. If any of these are themselves // tuples, they are flattened, so `tuple(a, tuple(b, c), d)` is equivalent to `tuple(a, b, c, d)`. @@ -35,15 +35,12 @@ // - It is illegal for an element of `Tuple` to be a reference, due to problems this would cause // with type inference and `tuple()`. -#ifndef KJ_TUPLE_H_ -#define KJ_TUPLE_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "common.h" +KJ_BEGIN_HEADER + namespace kj { namespace _ { // private @@ -101,12 +98,10 @@ struct TupleElement { template struct TupleElement { - // If tuples contained references, one of the following would have to be true: - // - `auto x = tuple(y, z)` would cause x to be a tuple of references to y and z, which is - // probably not what you expected. - // - `Tuple x = tuple(a, b)` would not work, because `tuple()` returned - // Tuple. - static_assert(sizeof(T*) == 0, "Sorry, tuples cannot contain references."); + // A tuple containing references can be constructed using refTuple(). + + T& value; + constexpr inline TupleElement(T& value): value(value) {} }; template @@ -126,6 +121,8 @@ struct TupleImpl, Types...> static_assert(sizeof...(indexes) == sizeof...(Types), "Incorrect use of TupleImpl."); + TupleImpl() = default; + template inline TupleImpl(Params&&... params) : TupleElement(kj::fwd(params))... { @@ -137,7 +134,7 @@ struct TupleImpl, Types...> template constexpr inline TupleImpl(Tuple&& other) - : TupleElement(kj::mv(getImpl(other)))... {} + : TupleElement(kj::fwd(getImpl(other)))... {} template constexpr inline TupleImpl(Tuple& other) : TupleElement(getImpl(other))... {} @@ -147,12 +144,15 @@ struct TupleImpl, Types...> }; struct MakeTupleFunc; +struct MakeRefTupleFunc; template class Tuple { // The actual Tuple class (used for tuples of size other than 1). public: + Tuple() = default; + template constexpr inline Tuple(Tuple&& other): impl(kj::mv(other)) {} template @@ -173,6 +173,7 @@ class Tuple { template friend inline const TypeByIndex& getImpl(const Tuple& tuple); friend struct MakeTupleFunc; + friend struct MakeRefTupleFunc; }; template <> @@ -314,6 +315,17 @@ struct MakeTupleFunc { } }; +struct MakeRefTupleFunc { + template + Tuple operator()(Params&&... params) { + return Tuple(kj::fwd(params)...); + } + template + Param operator()(Param&& param) { + return kj::fwd(param); + } +}; + } // namespace _ (private) template struct Tuple_ { typedef _::Tuple Type; }; @@ -336,6 +348,14 @@ inline auto tuple(Params&&... params) return _::expandAndApply(_::MakeTupleFunc(), kj::fwd(params)...); } +template +inline auto refTuple(Params&&... params) + -> decltype(_::expandAndApply(_::MakeRefTupleFunc(), kj::fwd(params)...)) { + // Like tuple(), but if the params include lvalue references, they will be captured as + // references. rvalue references will still be captured as whole values (moved). + return _::expandAndApply(_::MakeRefTupleFunc(), kj::fwd(params)...); +} + template inline auto get(Tuple&& tuple) -> decltype(_::getImpl(kj::fwd(tuple))) { // Unpack and return the tuple element at the given index. The index is specified as a template @@ -359,6 +379,65 @@ template constexpr size_t tupleSize() { return TupleSize_::size; } // Returns size of the tuple T. +template +struct IndexOfType_; +template +struct HasType_ { + static constexpr bool value = false; +}; + +template +struct IndexOfType_ { + static constexpr size_t value = 0; +}; +template +struct HasType_ { + static constexpr bool value = true; +}; + +template +struct IndexOfType_> { + static constexpr size_t value = 0; + static_assert(!HasType_>::value, + "requested type appears multiple times in tuple"); +}; +template +struct HasType_> { + static constexpr bool value = true; +}; + +template +struct IndexOfType_> { + static constexpr size_t value = IndexOfType_>::value + 1; +}; +template +struct HasType_> { + static constexpr bool value = HasType_>::value; +}; + +template +inline constexpr size_t indexOfType() { + static_assert(HasType_::value, "type not present"); + return IndexOfType_::value; +} + +template +struct TypeOfIndex_; +template +struct TypeOfIndex_<0, T> { + typedef T Type; +}; +template +struct TypeOfIndex_> + : public TypeOfIndex_> {}; +template +struct TypeOfIndex_<0, _::Tuple> { + typedef T Type; +}; + +template +using TypeOfIndex = typename TypeOfIndex_::Type; + } // namespace kj -#endif // KJ_TUPLE_H_ +KJ_END_HEADER diff --git a/c++/src/kj/units-test.c++ b/c++/src/kj/units-test.c++ index 892c1d3986..31a2973793 100644 --- a/c++/src/kj/units-test.c++ +++ b/c++/src/kj/units-test.c++ @@ -341,7 +341,7 @@ TEST(UnitMeasure, BoundedMinMax) { assertTypeAndValue(boundedValue<4,t1>(2), kj::min(bounded<4>(), boundedValue<5,t1>(2))); assertTypeAndValue(boundedValue<4,t1>(2), kj::min(boundedValue<5,t1>(2), bounded<4>())); - // These two are degenerate cases. Currently they fail to compile but meybe they shouldn't? + // These two are degenerate cases. Currently they fail to compile but maybe they shouldn't? // assertTypeAndValue(bounded<5>(), kj::max(boundedValue<4,t2>(3), bounded<5>())); // assertTypeAndValue(bounded<5>(), kj::max(bounded<5>(), boundedValue<4,t2>(3))); diff --git a/c++/src/kj/units.h b/c++/src/kj/units.h index 297e477b9b..530abafbee 100644 --- a/c++/src/kj/units.h +++ b/c++/src/kj/units.h @@ -23,16 +23,13 @@ // time, but should then be optimized down to basic primitives (usually, integers) by the // compiler. -#ifndef KJ_UNITS_H_ -#define KJ_UNITS_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "common.h" #include +KJ_BEGIN_HEADER + namespace kj { // ======================================================================================= @@ -77,19 +74,6 @@ class Bounded; template class BoundedConst; -template constexpr bool isIntegral() { return false; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } -template <> constexpr bool isIntegral() { return true; } - template struct IsIntegralOrBounded_ { static constexpr bool value = isIntegral(); }; template @@ -379,7 +363,7 @@ class Quantity { template friend class Quantity; - template + template friend inline constexpr auto operator*(Number1 a, Quantity b) -> Quantity; }; @@ -399,7 +383,8 @@ inline constexpr auto unit() -> decltype(Unit_::get()) { return Unit_::get // unit>() returns a Quantity of value 1. It also, intentionally, works on basic // numeric types. -template +template ()>> inline constexpr auto operator*(Number1 a, Quantity b) -> Quantity { return Quantity(a * b.value, unsafe); @@ -430,6 +415,13 @@ class Absolute { // units, which is actually totally logical and kind of neat. public: + inline constexpr Absolute(MaxValue_): value(maxValue) {} + inline constexpr Absolute(MinValue_): value(minValue) {} + // Allow initialization from maxValue and minValue. + // TODO(msvc): decltype(maxValue) and decltype(minValue) deduce unknown-type for these function + // parameters, causing the compiler to complain of a duplicate constructor definition, so we + // specify MaxValue_ and MinValue_ types explicitly. + inline constexpr Absolute operator+(const T& other) const { return Absolute(value + other); } inline constexpr Absolute operator-(const T& other) const { return Absolute(value - other); } inline constexpr T operator-(const Absolute& other) const { return value - other.value; } @@ -1045,14 +1037,14 @@ inline constexpr T unboundAs(U value) { template inline constexpr T unboundMax(Bounded value) { - // Explicitly ungaurd expecting a value that is at most `maxN`. + // Explicitly unguard expecting a value that is at most `maxN`. static_assert(maxN <= requestedMax, "possible overflow detected"); return value.unwrap(); } template inline constexpr uint unboundMax(BoundedConst) { - // Explicitly ungaurd expecting a value that is at most `maxN`. + // Explicitly unguard expecting a value that is at most `maxN`. static_assert(value <= requestedMax, "overflow detected"); return value; } @@ -1060,7 +1052,7 @@ inline constexpr uint unboundMax(BoundedConst) { template inline constexpr auto unboundMaxBits(T value) -> decltype(unboundMax()>(value)) { - // Explicitly ungaurd expecting a value that fits into `bits` bits. + // Explicitly unguard expecting a value that fits into `bits` bits. return unboundMax()>(value); } @@ -1175,4 +1167,4 @@ inline constexpr Range, Unit>> } // namespace kj -#endif // KJ_UNITS_H_ +KJ_END_HEADER diff --git a/c++/src/kj/vector.h b/c++/src/kj/vector.h index 44613f3331..d072448fc6 100644 --- a/c++/src/kj/vector.h +++ b/c++/src/kj/vector.h @@ -19,15 +19,12 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_VECTOR_H_ -#define KJ_VECTOR_H_ - -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#pragma once #include "array.h" +KJ_BEGIN_HEADER + namespace kj { template @@ -43,25 +40,27 @@ class Vector { public: inline Vector() = default; inline explicit Vector(size_t capacity): builder(heapArrayBuilder(capacity)) {} + inline Vector(Array&& array): builder(kj::mv(array)) {} - inline operator ArrayPtr() { return builder; } - inline operator ArrayPtr() const { return builder; } - inline ArrayPtr asPtr() { return builder.asPtr(); } - inline ArrayPtr asPtr() const { return builder.asPtr(); } + inline operator ArrayPtr() KJ_LIFETIMEBOUND { return builder; } + inline operator ArrayPtr() const KJ_LIFETIMEBOUND { return builder; } + inline ArrayPtr asPtr() KJ_LIFETIMEBOUND { return builder.asPtr(); } + inline ArrayPtr asPtr() const KJ_LIFETIMEBOUND { return builder.asPtr(); } inline size_t size() const { return builder.size(); } inline bool empty() const { return size() == 0; } inline size_t capacity() const { return builder.capacity(); } - inline T& operator[](size_t index) const { return builder[index]; } - - inline const T* begin() const { return builder.begin(); } - inline const T* end() const { return builder.end(); } - inline const T& front() const { return builder.front(); } - inline const T& back() const { return builder.back(); } - inline T* begin() { return builder.begin(); } - inline T* end() { return builder.end(); } - inline T& front() { return builder.front(); } - inline T& back() { return builder.back(); } + inline T& operator[](size_t index) KJ_LIFETIMEBOUND { return builder[index]; } + inline const T& operator[](size_t index) const KJ_LIFETIMEBOUND { return builder[index]; } + + inline const T* begin() const KJ_LIFETIMEBOUND { return builder.begin(); } + inline const T* end() const KJ_LIFETIMEBOUND { return builder.end(); } + inline const T& front() const KJ_LIFETIMEBOUND { return builder.front(); } + inline const T& back() const KJ_LIFETIMEBOUND { return builder.back(); } + inline T* begin() KJ_LIFETIMEBOUND { return builder.begin(); } + inline T* end() KJ_LIFETIMEBOUND { return builder.end(); } + inline T& front() KJ_LIFETIMEBOUND { return builder.front(); } + inline T& back() KJ_LIFETIMEBOUND { return builder.back(); } inline Array releaseAsArray() { // TODO(perf): Avoid a copy/move by allowing Array to point to incomplete space? @@ -71,8 +70,20 @@ class Vector { return builder.finish(); } + template + inline bool operator==(const U& other) const { return asPtr() == other; } + template + inline bool operator!=(const U& other) const { return asPtr() != other; } + + inline ArrayPtr slice(size_t start, size_t end) KJ_LIFETIMEBOUND { + return asPtr().slice(start, end); + } + inline ArrayPtr slice(size_t start, size_t end) const KJ_LIFETIMEBOUND { + return asPtr().slice(start, end); + } + template - inline T& add(Params&&... params) { + inline T& add(Params&&... params) KJ_LIFETIMEBOUND { if (builder.isFull()) grow(); return builder.add(kj::fwd(params)...); } @@ -103,9 +114,7 @@ class Vector { } inline void clear() { - while (builder.size() > 0) { - builder.removeLast(); - } + builder.clear(); } inline void truncate(size_t size) { @@ -114,7 +123,7 @@ class Vector { inline void reserve(size_t size) { if (size > builder.capacity()) { - setCapacity(size); + grow(size); } } @@ -141,4 +150,4 @@ inline auto KJ_STRINGIFY(const Vector& v) -> decltype(toCharSequence(v.asPtr( } // namespace kj -#endif // KJ_VECTOR_H_ +KJ_END_HEADER diff --git a/c++/src/kj/win32-api-version.h b/c++/src/kj/win32-api-version.h new file mode 100644 index 0000000000..31d34198f9 --- /dev/null +++ b/c++/src/kj/win32-api-version.h @@ -0,0 +1,44 @@ +// Copyright (c) 2013-2017 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +// Request Vista-level APIs. +#ifndef WINVER +#define WINVER 0x0600 +#elif WINVER < 0x0600 +#error "WINVER defined but older than Vista" +#endif + +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0600 +#elif _WIN32_WINNT < 0x0600 +#error "_WIN32_WINNT defined but older than Vista" +#endif + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN // ::eyeroll:: +#endif + +#define NOSERVICE 1 +#define NOMCX 1 +#define NOIME 1 +#define NOMINMAX 1 diff --git a/c++/src/kj/windows-sanity.h b/c++/src/kj/windows-sanity.h index 766ba2cbd6..b2c93678d6 100644 --- a/c++/src/kj/windows-sanity.h +++ b/c++/src/kj/windows-sanity.h @@ -1,4 +1,4 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Copyright (c) 2013-2018 Sandstorm Development Group, Inc. and contributors // Licensed under the MIT License: // // Permission is hereby granted, free of charge, to any person obtaining a copy @@ -19,23 +19,53 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#ifndef KJ_WINDOWS_SANITY_H_ -#define KJ_WINDOWS_SANITY_H_ +// This file replaces poorly-named #defines from windows.h with properly-namespaced versions, so +// that they no longer conflict with similarly-named identifiers in other namespaces. +// +// This file must be #included some time after windows.h has been #included but before any attempt +// to use the names for other purposes. However, this can be difficult to determine in header +// files. Typically KJ / Cap'n Proto headers avoid including windows.h at all, but may use +// conflicting identifiers. In order to relieve application developers from the need to include +// windows-sanity.h themselves, we would like these headers to conditionally apply the fixes if +// and only if windows.h was already included. Therefore, this header checks if windows.h has been +// included and only applies fixups if this is the case. Furthermore, this header is designed such +// that it can be included multiple times, and the fixups will be applied the first time it is +// included *after* windows.h. +// +// Now, as long as any headers which need to use conflicting identifier names be sure to #include +// windows-sanity.h, we can be sure that no conflicts will occur regardless of in what order the +// application chooses to include these headers vs. windows.h. -#if defined(__GNUC__) && !KJ_HEADER_WARNINGS -#pragma GCC system_header -#endif +#if !_WIN32 && !__CYGWIN__ -#ifndef _INC_WINDOWS -#error "windows.h needs to be included before kj/windows-sanity.h (or perhaps you don't need either?)" -#endif +// Not on Windows. Tell the compiler never to try to include this again. +#pragma once + +#elif defined(_INC_WINDOWS) -namespace win32 { +// We're on Windows and windows.h has been included. We need to fixup the namespace. We only need +// to do this once, but we can't do it until windows.h has been included. Since that has happened +// now, we use `#pragma once` to tell the compiler never to include this file again. +#pragma once + +namespace kj_win32_workarounds { + // Namespace containing constant definitions intended to replace constants that are defined as + // macros in the Windows headers. Do not refer to this namespace directly, we'll import it into + // the global scope below. + +#ifdef ERROR // This could be absent if e.g. NOGDI was used. const auto ERROR_ = ERROR; #undef ERROR const auto ERROR = ERROR_; +#endif + + typedef VOID VOID_; +#undef VOID + typedef VOID_ VOID; } -using win32::ERROR; +// Pull our constant definitions into the global namespace -- but only if they don't already exist +// in the global namespace. +using namespace kj_win32_workarounds; -#endif // KJ_WINDOWS_SANITY_H_ +#endif diff --git a/c++/valgrind.supp b/c++/valgrind.supp new file mode 100644 index 0000000000..e4a0e140e0 --- /dev/null +++ b/c++/valgrind.supp @@ -0,0 +1,11 @@ +{ + + Memcheck:Addr8 + fun:check_free + fun:free_key_mem + fun:__dlerror_main_freeres + fun:__libc_freeres + fun:_vgnU_freeres + fun:_ZN2kj22TopLevelProcessContext4exitEv +} + diff --git a/doc/README.md b/doc/README.md index ad3eea43b3..96608b04ac 100644 --- a/doc/README.md +++ b/doc/README.md @@ -7,9 +7,10 @@ Start by installing ruby1.9.1-dev. On Debian-based operating systems: sudo apt-get install ruby-dev -Then install Jekyll: +Then install Jekyll 3.8.1 (Jekyll 4.x will NOT work due as they removed Pygments support): - sudo gem install jekyll pygments.rb + sudo gem install jekyll -v 3.8.1 + sudo gem install pygments.rb Now install Pygments and SetupTools to be able to install the CapnProto lexer. On Debian based operating systems: @@ -24,7 +25,7 @@ Next, install the custom Pygments syntax highlighter: Now you can launch a local server: - jekyll serve --watch + jekyll _3.8.1_ serve --watch Edit, test, commit. diff --git a/doc/_includes/buttons.html b/doc/_includes/buttons.html index 7a7ec7567f..66ec6248ca 100644 --- a/doc/_includes/buttons.html +++ b/doc/_includes/buttons.html @@ -1,6 +1,6 @@
Develop +href="https://github.com/capnproto/capnproto">Develop Discuss - - - - diff --git a/doc/_includes/header.html b/doc/_includes/header.html index f249d01730..e835a3e323 100644 --- a/doc/_includes/header.html +++ b/doc/_includes/header.html @@ -26,7 +26,7 @@
Discuss on Groups - View on GitHub + View on GitHub {% if page.title != "Introduction" %}{% endif %} diff --git a/doc/_layouts/slides.html b/doc/_layouts/slides.html index b3ec20c9b2..f04cf0d08d 100644 --- a/doc/_layouts/slides.html +++ b/doc/_layouts/slides.html @@ -38,17 +38,5 @@

Kenton Varda

- - - - diff --git a/doc/_plugins/capnp_lexer.py b/doc/_plugins/capnp_lexer.py index 0d9a8b1c90..f721fa58cf 100755 --- a/doc/_plugins/capnp_lexer.py +++ b/doc/_plugins/capnp_lexer.py @@ -15,7 +15,7 @@ class CapnpLexer(RegexLexer): (r'=', Literal, 'expression'), (r':', Name.Class, 'type'), (r'\$', Name.Attribute, 'annotation'), - (r'(struct|enum|interface|union|import|using|const|annotation|extends|in|of|on|as|with|from|fixed)\b', + (r'(struct|enum|interface|union|import|using|const|annotation|extends|in|of|on|as|with|from|fixed|bulk|realtime)\b', Token.Keyword), (r'[a-zA-Z0-9_.]+', Token.Name), (r'[^#@=:$a-zA-Z0-9_]+', Text), diff --git a/doc/_posts/2013-12-12-capnproto-0.4-time-travel.md b/doc/_posts/2013-12-12-capnproto-0.4-time-travel.md index a895a9c542..42f4bfef21 100644 --- a/doc/_posts/2013-12-12-capnproto-0.4-time-travel.md +++ b/doc/_posts/2013-12-12-capnproto-0.4-time-travel.md @@ -36,14 +36,14 @@ is just talking about implementing a promise API in C++. Pipelining is another that. Please [see the RPC page]({{ site.baseurl }}rpc.html) if you want to know more about pipelining._ -If you do a lot of serious Javascript programming, you've probably heard of +If you do a lot of serious JavaScript programming, you've probably heard of [Promises/A+](http://promisesaplus.com/) and similar proposals. Cap'n Proto RPC introduces a similar construct in C++. In fact, the API is nearly identical, and its semantics are nearly identical. Compare with -[Domenic Denicola's Javascript example](http://domenic.me/2012/10/14/youre-missing-the-point-of-promises/): +[Domenic Denicola's JavaScript example](http://domenic.me/2012/10/14/youre-missing-the-point-of-promises/): {% highlight c++ %} -// C++ version of Domenic's Javascript promises example. +// C++ version of Domenic's JavaScript promises example. getTweetsFor("domenic") // returns a promise .then([](vector tweets) { auto shortUrls = parseTweetsForUrls(tweets); @@ -63,10 +63,10 @@ getTweetsFor("domenic") // returns a promise {% endhighlight %} This is C++, but it is no more lines -- nor otherwise more complex -- than the equivalent -Javascript. We're doing several I/O operations, we're doing them asynchronously, and we don't +JavaScript. We're doing several I/O operations, we're doing them asynchronously, and we don't have a huge unreadable mess of callback functions. Promises are based on event loop concurrency, which means you can perform concurrent operations with shared state without worrying about mutex -locking -- i.e., the Javascript model. (Of course, if you really want threads, you can run +locking -- i.e., the JavaScript model. (Of course, if you really want threads, you can run multiple event loops in multiple threads and make inter-thread RPC calls between them.) [More on C++ promises.]({{ site.baseurl }}cxxrpc.html#kj_concurrency_framework) diff --git a/doc/_posts/2014-03-11-capnproto-0.4.1-bugfixes.md b/doc/_posts/2014-03-11-capnproto-0.4.1-bugfixes.md index 31e73013ce..777bf22455 100644 --- a/doc/_posts/2014-03-11-capnproto-0.4.1-bugfixes.md +++ b/doc/_posts/2014-03-11-capnproto-0.4.1-bugfixes.md @@ -25,6 +25,6 @@ In the meantime, though, there have been some major updates from the community: C++ and Python), and the second language to implement it from the ground up (Python just wraps the C++ implementation). Check out author [David Renshaw](https://github.com/dwrensha)'s [talk at Mozilla](https://air.mozilla.org/rust-meetup-february-2014/). - * A [Javascript port](https://github.com/jscheid/capnproto-js) has appeared, but it needs help + * A [JavaScript port](https://github.com/jscheid/capnproto-js) has appeared, but it needs help to keep going! diff --git a/doc/_posts/2014-06-17-capnproto-flatbuffers-sbe.md b/doc/_posts/2014-06-17-capnproto-flatbuffers-sbe.md index 746ec21717..68ce6fb13c 100644 --- a/doc/_posts/2014-06-17-capnproto-flatbuffers-sbe.md +++ b/doc/_posts/2014-06-17-capnproto-flatbuffers-sbe.md @@ -98,7 +98,7 @@ The down side of reflection is that it is generally very slow (compared to gener When building a message, depending on how your code is organized, it may be convenient to have flexibility in the order in which you fill in the data. If that flexibility is missing, you may find you have to do extra bookkeeping to store data off to the side until its time comes to be added to the message. -Protocol Buffers is natually completely flexible in terms of initialization order because the mesasge is being built on the heap. There is no reason to impose restrictions. (Although, the C++ Protobuf library heavily encourages top-down building.) +Protocol Buffers is naturally completely flexible in terms of initialization order because the message is being built on the heap. There is no reason to impose restrictions. (Although, the C++ Protobuf library heavily encourages top-down building.) All the zero-copy systems, though, have to use some form of arena allocation to make sure that the message is built in a contiguous block of memory that can be written out all at once. So, things get more complicated. diff --git a/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md b/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md index 345d756672..52f1143338 100644 --- a/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md +++ b/doc/_posts/2015-01-23-capnproto-0.5.1-bugfixes.md @@ -13,4 +13,4 @@ Cap'n Proto 0.5.1 has just been released with some bug fixes: Sorry about the bugs. -In other news, as you can see, the Cap'n Proto web site now lives at `capnproto.org`. Additionally, the Github repo has been moved to the [Sandstorm.io organization](https://github.com/sandstorm-io). Both moves have left behind redirects so that old links / repository references should continue to work. +In other news, as you can see, the Cap'n Proto web site now lives at `capnproto.org`. Additionally, the Github repo has been moved to the [Sandstorm.io organization](https://github.com/capnproto). Both moves have left behind redirects so that old links / repository references should continue to work. diff --git a/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md b/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md index 3b7a3c6470..eaa2b16158 100644 --- a/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md +++ b/doc/_posts/2015-03-02-security-advisory-and-integer-overflow-protection.md @@ -6,11 +6,11 @@ author: kentonv As the installation page has always stated, I do not yet recommend using Cap'n Proto's C++ library for handling possibly-malicious input, and will not recommend it until it undergoes a formal security review. That said, security is obviously a high priority for the project. The security of Cap'n Proto is in fact essential to the security of [Sandstorm.io](https://sandstorm.io), Cap'n Proto's parent project, in which sandboxed apps communicate with each other and the platform via Cap'n Proto RPC. -A few days ago, the first major security bugs were found in Cap'n Proto C++ -- two by security guru [Ben Laurie](http://en.wikipedia.org/wiki/Ben_Laurie) and one by myself during subsequent review (see below). You can read details about each bug in our new [security advisories directory](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories): +A few days ago, the first major security bugs were found in Cap'n Proto C++ -- two by security guru [Ben Laurie](http://en.wikipedia.org/wiki/Ben_Laurie) and one by myself during subsequent review (see below). You can read details about each bug in our new [security advisories directory](https://github.com/capnproto/capnproto/tree/master/security-advisories): -* [Integer overflow in pointer validation.](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md) -* [Integer underflow in pointer validation.](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md) -* [CPU usage amplification attack.](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-2-all-cpu-amplification.md) +* [Integer overflow in pointer validation.](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md) +* [Integer underflow in pointer validation.](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md) +* [CPU usage amplification attack.](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-2-all-cpu-amplification.md) I have backported the fixes to the last two release branches -- 0.5 and 0.4: @@ -50,7 +50,7 @@ In the past, C and C++ code has been plagued by buffer overrun bugs, but these d But developing a similar sense for integer overflow is hard. We do arithmetic in code all the time, and the vast majority of it isn't an issue. The few places where overflow can happen all too easily go unnoticed. -And by the way, integer overflow affects many memory-safe languages too! Java and C# don't protect against overflow. Python does, using slow arbitrary-precision integers. Javascript doesn't use integers, and is instead succeptible to loss-of-precision bugs, which can have similar (but more subtle) consequences. +And by the way, integer overflow affects many memory-safe languages too! Java and C# don't protect against overflow. Python does, using slow arbitrary-precision integers. JavaScript doesn't use integers, and is instead succeptible to loss-of-precision bugs, which can have similar (but more subtle) consequences. While writing Cap'n Proto, I made sure to think carefully about overflow and managed to correct for it most of the time. On learning that I missed a case, I immediately feared that I might have missed many more, and wondered how I might go about systematically finding them. @@ -109,15 +109,15 @@ So, a `Guarded<10, int>` represents a `int` which is statically guaranteed to ho Moreover, because all of `Guarded`'s operators are inline and `constexpr`, a good optimizing compiler will be able to optimize `Guarded` down to the underlying primitive integer type. So, in theory, using `Guarded` has no runtime overhead. (I have not yet verified that real compilers get this right, but I suspect they do.) -Of course, the full implementation is considerably more complicated than this. The code has not been merged into the Cap'n Proto tree yet as we need to do more analysis to make sure it has no negative impact. For now, you can find it in the [overflow-safe](https://github.com/sandstorm-io/capnproto/tree/overflow-safe) branch, specifically in the second half of [kj/units.h](https://github.com/sandstorm-io/capnproto/blob/overflow-safe/c++/src/kj/units.h). (This header also contains metaprogramming for compile-time unit analysis, which Cap'n Proto has been using since its first release.) +Of course, the full implementation is considerably more complicated than this. The code has not been merged into the Cap'n Proto tree yet as we need to do more analysis to make sure it has no negative impact. For now, you can find it in the [overflow-safe](https://github.com/capnproto/capnproto/tree/overflow-safe) branch, specifically in the second half of [kj/units.h](https://github.com/capnproto/capnproto/blob/overflow-safe/c++/src/kj/units.h). (This header also contains metaprogramming for compile-time unit analysis, which Cap'n Proto has been using since its first release.) ### Results I switched Cap'n Proto's core pointer validation code (`capnp/layout.c++`) over to `Guarded`. In the process, I found: * Several overflows that could be triggered by the application calling methods with invalid parameters, but not by a remote attacker providing invalid message data. We will change the code to check these in the future, but they are not critical security problems. -* The overflow that Ben had already reported ([2015-03-02-0](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md)). I had intentionally left this unfixed during my analysis to verify that `Guarded` would catch it. -* One otherwise-undiscovered integer underflow ([2015-03-02-1](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md)). +* The overflow that Ben had already reported ([2015-03-02-0](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md)). I had intentionally left this unfixed during my analysis to verify that `Guarded` would catch it. +* One otherwise-undiscovered integer underflow ([2015-03-02-1](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md)). Based on these results, I conclude that `Guarded` is in fact effective at finding overflow bugs, and that such bugs are thankfully _not_ endemic in Cap'n Proto's code. diff --git a/doc/_posts/2015-03-05-another-cpu-amplification.md b/doc/_posts/2015-03-05-another-cpu-amplification.md index d074c9db80..4c8ae97591 100644 --- a/doc/_posts/2015-03-05-another-cpu-amplification.md +++ b/doc/_posts/2015-03-05-another-cpu-amplification.md @@ -8,7 +8,7 @@ Unfortunately, it turns out that our fix for one of [the security advisories iss Fortunately, the incomplete fix is for the non-critical vulnerability. The worst case is that an attacker could consume excessive CPU time. -Nevertheless, we've issued [a new advisory](https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md) and pushed a new release: +Nevertheless, we've issued [a new advisory](https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md) and pushed a new release: - Release 0.5.1.2: [source](https://capnproto.org/capnproto-c++-0.5.1.2.tar.gz), [win32](https://capnproto.org/capnproto-c++-win32-0.5.1.2.zip) - Release 0.4.1.2: [source](https://capnproto.org/capnproto-c++-0.4.1.2.tar.gz) diff --git a/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md b/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md index bd36cbbace..336f0d838f 100644 --- a/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md +++ b/doc/_posts/2017-05-01-capnproto-0.6-msvc-json-http-more.md @@ -40,7 +40,7 @@ The 0.6 release includes a number of measures designed to harden Cap'n Proto's C Cap'n Proto messages can now be converted to and from JSON using `libcapnp-json`. This makes it easy to integrate your JSON front-end API with your Cap'n Proto back-end. -See the capnp/compat/json.h header for API details. +See the capnp/compat/json.h header for API details. This library was primarily built by [**Kamal Marhubi**](https://github.com/kamalmarhubi) and [**Branislav Katreniak**](https://github.com/katreniak), using Cap'n Proto's [dynamic API]({{site.baseurl}}cxx.html#dynamic-reflection). @@ -48,7 +48,7 @@ This library was primarily built by [**Kamal Marhubi**](https://github.com/kamal KJ (the C++ framework library bundled with Cap'n Proto) now ships with a minimalist HTTP library, `libkj-http`. The library is based on the KJ asynchronous I/O framework and covers both client-side and server-side use cases. Although functional and used in production today, the library should be considered a work in progress -- expect improvements in future releases, such as client connection pooling and TLS support. -See the kj/compat/http.h header for API details. +See the kj/compat/http.h header for API details. #### Smaller things diff --git a/doc/_posts/2018-08-28-capnproto-0.7.md b/doc/_posts/2018-08-28-capnproto-0.7.md new file mode 100644 index 0000000000..2e4a94a808 --- /dev/null +++ b/doc/_posts/2018-08-28-capnproto-0.7.md @@ -0,0 +1,39 @@ +--- +layout: post +title: "Cap'n Proto 0.7 Released" +author: kentonv +--- + + + +Today we're releasing Cap'n Proto 0.7. + +### As used in Cloudflare Workers + +The biggest high-level development in Cap'n Proto since the last release is its use in the implementation of [Cloudflare Workers](https://blog.cloudflare.com/cloudflare-workers-unleashed/) (of which I am the tech lead). + +Cloudflare operates a global network of 152 datacenters and growing, and Cloudflare Workers allows you to deploy "serveless" JavaScript to all of those locations in under 30 seconds. Your code is written against the W3C standard [Service Workers API](https://developer.mozilla.org/en-US/docs/Web/API/Service_Worker_API) and handles HTTP traffic for your web site. + +The Cloudflare Workers runtime implementation is written in C++, leveraging the V8 JavaScript engine and libKJ, the C++ toolkit library distributed with Cap'n Proto. + +Cloudflare Workers are all about handling HTTP traffic, and the runtime uses KJ's HTTP library to do it. This means the KJ HTTP library is now battle-tested in production. Every package downloaded from [npm](https://npmjs.org), for example, passes through KJ's HTTP client and server libraries on the way (since npm uses Workers). + +The Workers runtime makes heavy use of KJ, but so far only makes light use of Cap'n Proto serialization. Cap'n Proto is used as a format for distributing configuration as well as (ironically) to handle JSON. We anticipate, however, making deeper use of Cap'n Proto in the future, including RPC. + +### What else is new? + +* The C++ library now requires C++14 or newer. It requires GCC 4.9+, Clang 3.6+, or Microsoft Visual Studio 2017. This change allows us to make faster progress and provide cleaner APIs by utilizing newer language features. +* The JSON parser now supports [annotations to customize conversion behavior](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/compat/json.capnp). These allow you to override field names (e.g. to use underscores instead of camelCase), flatten sub-objects, and express unions in various more-idiomatic ways. +* The KJ HTTP library supports WebSockets, and has generally become much higher-quality as it has been battle-tested in Cloudflare Workers. +* KJ now offers its own [hashtable- and b-tree-based container implementations](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/map.h). `kj::HashMap` is significantly faster and more memory-efficient than `std::unordered_map`, with more optimizations coming. `kj::TreeMap` is somewhat slower than `std::map`, but uses less memory and has a smaller code footprint. Both are implemented on top of `kj::Table`, a building block that can also support multi-maps. Most importantly, all these interfaces are cleaner and more modern than their ancient STL counterparts. +* KJ now includes [TLS bindings](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/compat/tls.h). `libkj-tls` wraps OpenSSL or BoringSSL and provides a simple, hard-to-mess-up API integrated with the KJ event loop. +* KJ now includes [gzip bindings](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/compat/gzip.h), which wrap zlib in KJ stream interfaces (sync and async). +* KJ now includes [helpers for encoding/decoding Unicode (UTF-8/UTF-16/UTF-32), base64, hex, URI-encoding, and C-escaped text](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/encoding.h). +* The [`kj::Url` helper class](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/compat/url.h) is provided to parse and compose URLs. +* KJ now includes [a filesystem API](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/filesystem.h) which is designed to be resistant to path injection attacks, is dependency-injection-friendly to ease unit testing, is cross-platform (Unix and Windows), makes atomic file replacement easy, makes mmap easy, and [other neat features](https://github.com/capnproto/capnproto/pull/384). +* The `capnp` tool now has a `convert` command which can be used to convert between all known message encodings, such as binary, packed, text, JSON, canonical, etc. This obsoletes the old `encode` and `decode` commands. +* Many smaller tweaks and bug fixes. + + diff --git a/doc/_posts/2020-04-23-capnproto-0.8.md b/doc/_posts/2020-04-23-capnproto-0.8.md new file mode 100644 index 0000000000..3d1bd2e790 --- /dev/null +++ b/doc/_posts/2020-04-23-capnproto-0.8.md @@ -0,0 +1,115 @@ +--- +layout: post +title: "Cap'n Proto 0.8: Streaming flow control, HTTP-over-RPC, fibers, etc." +author: kentonv +--- + + + +Today I'm releasing Cap'n Proto 0.8. + +### What's new? + +* [Multi-stream Flow Control](#multi-stream-flow-control) +* [HTTP-over-Cap'n-Proto](#http-over-capn-proto) +* [KJ improvements](#kj-improvements) +* Lots and lots of minor tweaks and fixes. + +#### Multi-stream Flow Control + +It is commonly believed, wrongly, that Cap'n Proto doesn't support "streaming", in the way that gRPC does. In fact, Cap'n Proto's object-capability model and promise pipelining make it much more expressive than gRPC. In Cap'n Proto, "streaming" is just a pattern, not a built-in feature. + +Streaming is accomplished by introducing a temporary RPC object as part of a call. Each streamed message becomes a call to the temporary object. Think of this like providing a callback function in an object-oriented language. + +For instance, server -> client streaming ("returning multiple responses") can look like this: + +{% highlight capnp %} +# NOT NEW: Server -> client streaming example. +interface MyInterface { + streamingCall @0 (callback :Callback) -> (); + + interface Callback { + sendChunk @0 (chunk :Data) -> (); + } +} +{% endhighlight %} + +Or for client -> server streaming, the server returns a callback: + +{% highlight capnp %} +# NOT NEW: Client -> Server streaming example. +interface MyInterface { + streamingCall @0 () -> (callback :Callback); + + interface Callback { + sendChunk @0 (chunk :Data) -> (); + } +} +{% endhighlight %} + +Note that the client -> server example relies on [promise pipelining](https://capnproto.org/rpc.html#time-travel-promise-pipelining): When the client invokes `streamingCall()`, it does NOT have to wait for the server to respond before it starts making calls to the `callback`. Using promise pipelining (which has been a built-in feature of Cap'n Proto RPC since its first release in 2013), the client sends messages to the server that say: "Once my call to `streamingCall()` call is finished, take the returned callback and call this on it." + +Obviously, you can also combine the two examples to create bidirectional streams. You can also introduce "callback" objects that have multiple methods, methods that themselves return values (maybe even further streaming callbacks!), etc. You can send and receive multiple new RPC objects in a single call. Etc. + +But there has been one problem that arises in the context of streaming specifically: flow control. Historically, if an app wanted to stream faster than the underlying network connection would allow, then it could end up queuing messages in memory. Worse, if other RPC calls were happening on the same connection concurrently, they could end up blocked behind these queued streaming calls. + +In order to avoid such problems, apps needed to implement some sort of flow control strategy. An easy strategy was to wait for each `sendChunk()` call to return before starting the next call, but this would incur an unnecessary network round trip for each chunk. A better strategy was for apps to allow multiple concurrent calls, but only up to some limit before waiting for in-flight calls to return. For example, an app could limit itself to four in-flight stream calls at a time, or to 64kB worth of chunks. + +This sort of worked, but there were two problems. First, this logic could get pretty complicated, distracting from the app's business logic. Second, the "N-bytes-in-flight-at-a-time" strategy only works well if the value of N is close to the [bandwidth-delay product (BDP)](https://en.wikipedia.org/wiki/Bandwidth-delay_product) of the connection. If N was chosen too low, the connection would be under-utilized. If too high, it would increase queuing latency for all users of the connection. + +Cap'n Proto 0.8 introduces a built-in feature to manage flow control. Now, you can declare your streaming calls like this: + +{% highlight capnp %} +interface MyInterface { + streamingCall @0 (callback :Callback) -> (); + + interface Callback { + # NEW: This streaming call features flow control! + sendChunk @0 (chunk :Data) -> stream; + done @1 (); + } +} +{% endhighlight %} + +Methods declared with `-> stream` behave like methods with empty return types (`-> ()`), but with special behavior when the call is sent over a network connection. Instead of waiting for the remote site to respond to the call, the Cap'n Proto client library will act as if the call has "returned" as soon as it thinks the app should send the next call. So, now the app can use a simple loop that calls `sendChunk()`, waits for it to "complete", then sends the next chunk. Each call will appear to "return immediately" until such a time as Cap'n Proto thinks the connection is fully-utilized, and then each call will block until space frees up. + +When using streaming, it is important that apps be aware that error handling works differently. Since the client side may indicate completion of the call before the call has actually executed on the server, any exceptions thrown on the server side obviously cannot propagate to the client. Instead, we introduce a new rule: If a streaming call ends up throwing an exception, then all later method invocations on the same object (streaming or not) will also throw the same exception. You'll notice that we added a `done()` method to the callback interface above. After completing all streaming calls, the caller _must_ call `done()` to check for errors. If any previous streaming call failed, then `done()` will fail too. + +Under the hood, Cap'n Proto currently implements flow control using a simple hack: it queries the send buffer size of the underlying network socket, and sets that as the "window size" for each stream. The operating system will typically increase the socket buffer as needed to match the TCP congestion window, and Cap'n Proto's streaming window size will increase to match. This is not a very good implementation for a number of reasons. The biggest problem is that it doesn't account for proxying: with Cap'n Proto it is common to pass objects through multiple nodes, which automatically arranges for calls to the object to be proxied though the middlemen. But, the TCP socket buffer size only approximates the BDP of the first hop. A better solution would measure the end-to-end BDP using an algorithm like [BBR](https://queue.acm.org/detail.cfm?id=3022184). Expect future versions of Cap'n Proto to improve on this. + +Note that this new feature does not come with any change to the underlying RPC protocol! The flow control behavior is implemented entirely on the client side. The `-> stream` declaration in the schema is merely a hint to the client that it should use this behavior. Methods declared with `-> stream` are wire-compatible with methods declared with `-> ()`. Currently, flow control is only implemented in the C++ library. RPC implementations in other languages will treat `-> stream` the same as `-> ()` until they add explicit support for it. Apps in those languages will need to continue doing their own flow control in the meantime, as they did before this feature was added. + +#### HTTP-over-Cap'n-Proto + +Cap'n Proto 0.8 defines [a protocol for tunnelling HTTP calls over Cap'n Proto RPC](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/compat/http-over-capnp.capnp), along with an [adapter library](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/compat/http-over-capnp.h) adapting it to the [KJ HTTP API](https://github.com/capnproto/capnproto/blob/master/c++/src/kj/compat/http.h). Thus, programs written to send or receive HTTP requests using KJ HTTP can easily be adapted to communicate over Cap'n Proto RPC instead. It's also easy to build a proxy that converts regular HTTP protocol into Cap'n Proto RPC and vice versa. + +In principle, http-over-capnp can achieve similar advantages to HTTP/2: Multiple calls can multiplex over the same connection with arbitrary ordering. But, unlike HTTP/2, calls can be initiated in either direction, can be addressed to multiple virtual endpoints (without relying on URL-based routing), and of course can be multiplexed with non-HTTP Cap'n Proto traffic. + +In practice, however, http-over-capnp is new, and should not be expected to perform as well as mature HTTP/2 implementations today. More work is needed. + +We use http-over-capnp in [Cloudflare Workers](https://workers.cloudflare.com/) to communicate HTTP requests between components of the system, especially into and out of sandboxes. Using this protocol, instead of plain HTTP or HTTP/2, allows us to communicate routing and metadata out-of-band (rather than e.g. stuffing it into private headers). It also allows us to design component APIs using an [object-capability model](http://erights.org/elib/capability/ode/ode-capabilities.html), which turns out to be an excellent choice when code needs to be securely sandboxed. + +Today, our use of this protocol is fairly experimental, but we plan to use it more heavily as the code matures. + +#### KJ improvements + +KJ is the C++ toolkit library developed together with Cap'n Proto's C++ implementation. Ironically, most of the development in the Cap'n Proto repo these days is actually improvements to KJ, in part because it is used heavily in the implementation of [Cloudflare Workers](https://workers.cloudflare.com/). + +* The KJ Promise API now supports fibers. Fibers allow you to execute code in a synchronous style within a thread driven by an asynchronous event loop. The synchronous code runs on an alternate call stack. The code can synchronously wait on a promise, at which point the thread switches back to the main stack and runs the event loop. We generally recommend that new code be written in asynchronous style rather than using fibers, but fibers can be useful in cases where you want to call a synchronous library, and then perform asynchronous tasks in callbacks from said library. [See the pull request for more details.](https://github.com/capnproto/capnproto/pull/913) +* New API `kj::Executor` can be used to communicate directly between event loops on different threads. You can use it to execute an arbitrary lambda on a different thread's event loop. Previously, it was necessary to use some OS construct like a pipe, signal, or eventfd to wake up the receiving thread. +* KJ's mutex API now supports conditional waits, meaning you can unlock a mutex and sleep until such a time as a given lambda function, applied to the mutex's protected state, evaluates to true. +* The KJ HTTP library has continued to be developed actively for its use in [Cloudflare Workers](https://workers.cloudflare.com/). This library now handles millions of requests per second worldwide, both as a client and as a server (since most Workers are proxies), for a wide variety of web sites big and small. + +### Towards 1.0 + +Cap'n Proto has now been around for seven years, with many huge production users (such as Cloudflare). But, we're still on an 0.x release? What gives? + +Well, to be honest, there are still a lot of missing features that I feel like are critical to Cap'n Proto's vision, the most obvious one being three-party handoff. But, so far I just haven't had a real production need to implement those features. Clearly, I should stop waiting for perfection. + +Still, there are a couple smaller things I want to do for an upcoming 1.0 release: + +1. Properly document KJ, independent of Cap'n Proto. KJ has evolved into an extremely useful general-purpose C++ toolkit library. +2. Fix a mistake in the design of KJ's `AsyncOutputStream` interface. The interface currently does not have a method to write EOF; instead, EOF is implied by the destructor. This has proven to be the wrong design. Since fixing it will be a breaking API change for anyone using this interface, I want to do it before declaring 1.0. + +I aim to get these done sometime this summer... diff --git a/doc/_posts/2021-08-14-capnproto-0.9.md b/doc/_posts/2021-08-14-capnproto-0.9.md new file mode 100644 index 0000000000..edfb8d2024 --- /dev/null +++ b/doc/_posts/2021-08-14-capnproto-0.9.md @@ -0,0 +1,14 @@ +--- +layout: post +title: "Cap'n Proto 0.9" +author: kentonv +--- + + + +Today I'm releasing Cap'n Proto 0.9. + +There's no huge new features in this release, but there are many minor improvements and bug fixes. You can [read the PR history](https://github.com/capnproto/capnproto/pulls?q=is%3Apr+is%3Aclosed) to find out what has changed. + +Cap'n Proto development has continued to be primarily driven by the [Cloudflare Workers](https://workers.cloudflare.com/) project (of which I'm the lead engineer). As of the previous release, Cloudflare Workers primarily used the [KJ C++ toolkit](https://github.com/capnproto/capnproto/blob/master/kjdoc/tour.md) that is developed with Cap'n Proto, but made only light use of Cap'n Proto serialization and RPC itself. That has now changed: the implementation of [Durable Objects](https://blog.cloudflare.com/introducing-workers-durable-objects/) makes heavy use of Cap'n Proto RPC for essentially all communication within the system. diff --git a/doc/_posts/2022-06-03-capnproto-0.10.md b/doc/_posts/2022-06-03-capnproto-0.10.md new file mode 100644 index 0000000000..999fd10fbc --- /dev/null +++ b/doc/_posts/2022-06-03-capnproto-0.10.md @@ -0,0 +1,12 @@ +--- +layout: post +title: "Cap'n Proto 0.10" +author: kentonv +--- + + + +Today I'm releasing Cap'n Proto 0.10. + +Like last time, there's no huge new features in this release, but there are many minor improvements and bug fixes. You can [read the PR history](https://github.com/capnproto/capnproto/pulls?q=is%3Apr+is%3Aclosed) to find out what has changed. diff --git a/doc/_posts/2022-11-30-CVE-2022-46149-security-advisory.md b/doc/_posts/2022-11-30-CVE-2022-46149-security-advisory.md new file mode 100644 index 0000000000..01284cca4a --- /dev/null +++ b/doc/_posts/2022-11-30-CVE-2022-46149-security-advisory.md @@ -0,0 +1,13 @@ +--- +layout: post +title: "CVE-2022-46149: Possible out-of-bounds read related to list-of-pointers" +author: kentonv +--- + +David Renshaw, the author of the Rust implementation of Cap'n Proto, discovered a security vulnerability affecting both the C++ and Rust implementations of Cap'n Proto. The vulnerability was discovered using fuzzing. In theory, the vulnerability could lead to out-of-bounds reads which could cause crashes or perhaps exfiltration of memory. + +The vulnerability is exploitable only if an application performs a certain unusual set of actions. As of this writing, we are not aware of any applications that are actually affected. However, out of an abundance of caution, we are issuing a security advisory and advising everyone to patch. + +[Our security advisory](https://github.com/capnproto/capnproto/blob/master/security-advisories/2022-11-30-0-pointer-list-bounds.md) explains the impact of the bug, what an app must do to be affected, and where to find the fix. + +Check out [David's blog post](https://dwrensha.github.io/capnproto-rust/2022/11/30/out_of_bounds_memory_access_bug.html) for an in-depth explanation of the bug itself, including some of the inner workings of Cap'n Proto. diff --git a/doc/_posts/2023-07-28-capnproto-1.0.md b/doc/_posts/2023-07-28-capnproto-1.0.md new file mode 100644 index 0000000000..94073efa3c --- /dev/null +++ b/doc/_posts/2023-07-28-capnproto-1.0.md @@ -0,0 +1,74 @@ +--- +layout: post +title: "Cap'n Proto 1.0" +author: kentonv +--- + + + +It's been a little over ten years since the first release of Cap'n Proto, on April 1, 2013. Today I'm releasing version 1.0 of Cap'n Proto's C++ reference implementation. + +Don't get too excited! There's not actually much new. Frankly, I should have declared 1.0 a long time ago – probably around version 0.6 (in 2017) or maybe even 0.5 (in 2014). I didn't mostly because there were a few advanced features (like three-party handoff, or shared-memory RPC) that I always felt like I wanted to finish before 1.0, but they just kept not reaching the top of my priority list. But the reality is that Cap'n Proto has been relied upon in production for a long time. In fact, you are using Cap'n Proto right now, to view this site, which is served by Cloudflare, which uses Cap'n Proto extensively (and is also my employer, although they used Cap'n Proto before they hired me). Cap'n Proto is used to encode millions (maybe billions) of messages and gigabits (maybe terabits) of data every single second of every day. As for those still-missing features, the real world has seemingly proven that they aren't actually that important. (I still do want to complete them though.) + +Ironically, the thing that finally motivated the 1.0 release is so that we can start working on 2.0. But again here, don't get too excited! Cap'n Proto 2.0 is not slated to be a revolutionary change. Rather, there are a number of changes we (the Cloudflare Workers team) would like to make to Cap'n Proto's C++ API, and its companion, the KJ C++ toolkit library. Over the ten years these libraries have been available, I have kept their APIs pretty stable, despite being 0.x versioned. But for 2.0, we want to make some sweeping backwards-incompatible changes, in order to fix some footguns and improve developer experience for those on our team. + +Some users probably won't want to keep up with these changes. Hence, I'm releasing 1.0 now as a sort of "long-term support" release. We'll backport bugfixes as appropriate to the 1.0 branch for the long term, so that people who aren't interested in changes can just stick with it. + +## What's actually new in 1.0? + +Again, not a whole lot has changed since the last version, 0.10. But there are a few things worth mentioning: + +* A number of optimizations were made to improve performance of Cap'n Proto RPC. These include reducing the amount of memory allocation done by the RPC implementation and KJ I/O framework, adding the ability to elide certain messages from the RPC protocol to reduce traffic, and doing better buffering of small messages that are sent and received together to reduce syscalls. These are incremental improvements. + +* **Breaking change:** Previously, servers could opt into allowing RPC cancellation by calling `context.allowCancellation()` after a call was delivered. In 1.0, opting into cancellation is instead accomplished using an annotation on the schema (the `allowCancellation` annotation defined in `c++.capnp`). We made this change after observing that in practice, we almost always wanted to allow cancellation, but we almost always forgot to do so. The schema-level annotation can be set on a whole file at a time, which is easier not to forget. Moreover, the dynamic opt-in required a lot of bookkeeping that had a noticeable performance impact in practice; switching to the annotation provided a performance boost. For users that never used `context.allowCancellation()` in the first place, there's no need to change anything when upgrading to 1.0 – cancellation is still disallowed by default. (If you are affected, you will see a compile error. If there's no compile error, you have nothing to worry about.) + +* KJ now uses `kqueue()` to handle asynchronous I/O on systems that have it (MacOS and BSD derivatives). KJ has historically always used `epoll` on Linux, but until now had used a slower `poll()`-based approach on other Unix-like platforms. + +* KJ's HTTP client and server implementations now support the `CONNECT` method. + +* [A new class `capnp::RevocableServer` was introduced](https://github.com/capnproto/capnproto/pull/1700) to assist in exporting RPC wrappers around objects whose lifetimes are not controlled by the wrapper. Previously, avoiding use-after-free bugs in such scenarios was tricky. + +* Many, many smaller bug fixes and improvements. [See the PR history](https://github.com/capnproto/capnproto/pulls?q=is%3Apr+is%3Aclosed) for details. + +## What's planned for 2.0? + +The changes we have in mind for version 2.0 of Cap'n Proto's C++ implementation are mostly NOT related to the protocol itself, but rather to the C++ API and especially to KJ, the C++ toolkit library that comes with Cap'n Proto. These changes are motivated by our experience building a large codebase on top of KJ: namely, the Cloudflare Workers runtime, [`workerd`](https://github.com/cloudflare/workerd). + +KJ is a C++ toolkit library, arguably comparable to things like Boost, Google's Abseil, or Facebook's Folly. I started building KJ at the same time as Cap'n Proto in 2013, at a time when C++11 was very new and most libraries were not really designing around it yet. The intent was never to create a new standard library, but rather to address specific needs I had at the time. But over many years, I ended up building a lot of stuff. By the time I joined Cloudflare and started the Workers Runtime, KJ already featured a powerful async I/O framework, HTTP implementation, TLS bindings, and more. + +Of course, KJ has nowhere near as much stuff as Boost or Abseil, and nowhere near as much engineering effort behind it. You might argue, therefore, that it would have been better to choose one of those libraries to build on. However, KJ had a huge advantage: that we own it, and can shape it to fit our specific needs, without having to fight with anyone to get those changes upstreamed. + +One example among many: KJ's HTTP implementation features the ability to "suspend" the state of an HTTP connection, after receiving headers, and transfer it to a different thread or process to be resumed. This is an unusual thing to want, but is something we needed for resource management in the Workers Runtime. Implementing this required some deep surgery in KJ HTTP and definitely adds complexity. If we had been using someone else's HTTP library, would they have let us upstream such a change? + +That said, even though we own KJ, we've still tried to avoid making any change that breaks third-party users, and this has held back some changes that would probably benefit Cloudflare Workers. We have therefore decided to "fork" it. Version 2.0 is that fork. + +Development of version 2.0 will take place on Cap'n Proto's new `v2` branch. The `master` branch will become the 1.0 LTS branch, so that existing projects which track `master` are not disrupted by our changes. + +We don't yet know all the changes we want to make as we've only just started thinking seriously about it. But, here's some ideas we've had so far: + +* We will require a compiler with support for C++20, or maybe even C++23. Cap'n Proto 1.0 only requires C++14. + +* In particular, we will require a compiler that supports C++20 coroutines, as lots of KJ async code will be refactored to rely on coroutines. This should both make the code clearer and improve performance by reducing memory allocations. However, coroutine support is still spotty – as of this writing, GCC seems to ICE on KJ's coroutine implementation. + +* Cap'n Proto's RPC API, KJ's HTTP APIs, and others are likely to be revised to make them more coroutine-friendly. + +* `kj::Maybe` will become more ergonomic. It will no longer overload `nullptr` to represent the absence of a value; we will introduce `kj::none` instead. `KJ_IF_MAYBE` will no longer produce a pointer, but instead a reference (a trick that becomes possible by utilizing C++17 features). + +* We will drop support for compiling with exceptions disabled. KJ's coding style uses exceptions as a form of software fault isolation, or "catchable panics", such that errors can cause the "current task" to fail out without disrupting other tasks running concurrently. In practice, this ends up affecting every part of how KJ-style code is written. And yet, since the beginning, KJ and Cap'n Proto have been designed to accommodate environments where exceptions are turned off at compile time, using an elaborate system to fall back to callbacks and distinguish between fatal and non-fatal exceptions. In practice, maintaining this ability has been a drag on development – no-exceptions mode is constantly broken and must be tediously fixed before each release. Even when the tests are passing, it's likely that a lot of KJ's functionality realistically cannot be used in no-exceptions mode due to bugs and fragility. Today, I would strongly recommend against anyone using this mode except maybe for the most basic use of Cap'n Proto's serialization layer. Meanwhile, though, I'm honestly not sure if anyone uses this mode at all! In theory I would expect many people do, since many people choose to use C++ with exceptions disabled, but I've never actually received a single question or bug report related to it. It seems very likely that this was wasted effort all along. By removing support, we can simplify a lot of stuff and probably do releases more frequently going forward. + +* Similarly, we'll drop support for no-RTTI mode and other exotic modes that are a maintenance burden. + +* We may revise KJ's approach to reference counting, as the current design has proven to be unintuitive to many users. + +* We will fix a longstanding design flaw in `kj::AsyncOutputStream`, where EOF is currently signaled by destroying the stream. Instead, we'll add an explicit `end()` method that returns a Promise. Destroying the stream without calling `end()` will signal an erroneous disconnect. (There are several other aesthetic improvements I'd like to make to the KJ stream APIs as well.) + +* We may want to redesign several core I/O APIs to be a better fit for Linux's new-ish io_uring event notification paradigm. + +* The RPC implementation may switch to allowing cancellation by default. As discussed above, this is opt-in today, but in practice I find it's almost always desirable, and disallowing it can lead to subtle problems. + +* And so on. + +It's worth noting that at present, there is no plan to make any backwards-incompatible changes to the serialization format or RPC protocol. The changes being discussed only affect the C++ API. Applications written in other languages are completely unaffected by all this. + +It's likely that a formal 2.0 release will not happen for some time – probably a few years. I want to make sure we get through all the really big breaking changes we want to make, before we inflict update pain on most users. Of course, if you're willing to accept breakages, you can always track the `v2` branch. Cloudflare Workers releases from `v2` twice a week, so it should always be in good working order. diff --git a/doc/capnp-tool.md b/doc/capnp-tool.md index 88a78ae565..228dc3d660 100644 --- a/doc/capnp-tool.md +++ b/doc/capnp-tool.md @@ -65,7 +65,7 @@ This prints the value of `myConstant`, a [const](language.html#constants) declar applying variable substitution. It can also output the value in binary format (`--binary` or `--packed`). -At first glance, this may seem no more interesting that `capnp encode`: the syntax used to define +At first glance, this may seem no more interesting than `capnp encode`: the syntax used to define constants in schema files is the same as the format accepted by `capnp encode`, right? There is, however, a big difference: constants in schema files may be defined in terms of other constants, which may even be imported from other files. diff --git a/doc/cxx.md b/doc/cxx.md index fd0ebe8670..dcd8b4cca9 100644 --- a/doc/cxx.md +++ b/doc/cxx.md @@ -160,7 +160,7 @@ See the header `kj/exception.h` for details on how to register an exception call Cap'n Proto is built on top of a basic utility library called KJ. The two were actually developed together -- KJ is simply the stuff which is not specific to Cap'n Proto serialization, and may be -useful to others independently of Cap'n Proto. For now, the the two are distributed together. The +useful to others independently of Cap'n Proto. For now, the two are distributed together. The name "KJ" has no particular meaning; it was chosen to be short and easy-to-type. As of v0.3, KJ is distributed with Cap'n Proto but built as a separate library. You may need @@ -179,7 +179,7 @@ To use this code in your app, you must link against both `libcapnp` and `libkj`. flags. If you use [RPC](cxxrpc.html) (i.e., your schema defines [interfaces](language.html#interfaces)), -then you will additionally nead to link against `libcapnp-rpc` and `libkj-async`, or use the +then you will additionally need to link against `libcapnp-rpc` and `libkj-async`, or use the `capnp-rpc` `pkg-config` module. ### Setting a Namespace @@ -794,7 +794,7 @@ Here are some tips for using the C++ Cap'n Proto runtime most effectively: dead space. In the future, Cap'n Proto may be improved such that it can re-use dead space in a message. - However, this will only improve things, not fix them entirely: fragementation could still leave + However, this will only improve things, not fix them entirely: fragmentation could still leave dead space. ### Build Tips @@ -877,7 +877,7 @@ tips will apply. ## Lessons Learned from Protocol Buffers -The author of Cap'n Proto's C++ implementation also wrote (in the past) verison 2 of Google's +The author of Cap'n Proto's C++ implementation also wrote (in the past) version 2 of Google's Protocol Buffers. As a result, Cap'n Proto's implementation benefits from a number of lessons learned the hard way: diff --git a/doc/cxxrpc.md b/doc/cxxrpc.md index 22150d9a2d..3e55bcade5 100644 --- a/doc/cxxrpc.md +++ b/doc/cxxrpc.md @@ -16,13 +16,13 @@ not yet implemented. ## Sample Code -The [Calculator example](https://github.com/sandstorm-io/capnproto/tree/master/c++/samples) implements +The [Calculator example](https://github.com/capnproto/capnproto/tree/master/c++/samples) implements a fully-functional Cap'n Proto client and server. ## KJ Concurrency Framework RPC naturally requires a notion of concurrency. Unfortunately, -[all concurrency models suck](https://plus.google.com/u/0/+KentonVarda/posts/D95XKtB5DhK). +[all concurrency models suck](https://web.archive.org/web/20170718202612/https://plus.google.com/+KentonVarda/posts/D95XKtB5DhK). Cap'n Proto's RPC is based on the [KJ library](cxx.html#kj-library)'s event-driven concurrency framework. The core of the KJ asynchronous framework (events, promises, callbacks) is defined in @@ -35,8 +35,8 @@ must have its own event loop. KJ discourages fine-grained interaction between t synchronization is expensive and error-prone. Instead, threads are encouraged to communicate through Cap'n Proto RPC. -KJ's event loop model bears a lot of similarity to the Javascript concurrency model. Experienced -Javascript hackers -- especially node.js hackers -- will feel right at home. +KJ's event loop model bears a lot of similarity to the JavaScript concurrency model. Experienced +JavaScript hackers -- especially node.js hackers -- will feel right at home. _As of version 0.4, the only supported way to communicate between threads is over pipes or socketpairs. This will be improved in future versions. For now, just set up an RPC connection @@ -64,7 +64,7 @@ kj::Promise sendEmail(kj::StringPtr address, // the message has been successfully sent. {% endhighlight %} -As you will see, KJ promises are very similar to the evolving Javascript promise standard, and +As you will see, KJ promises are very similar to the evolving JavaScript promise standard, and much of the [wisdom around it](https://www.google.com/search?q=javascript+promises) can be directly applied to KJ promises. @@ -390,10 +390,11 @@ int main(int argc, const char* argv[]) { {% endhighlight %} Note that for the connect address, Cap'n Proto supports DNS host names as well as IPv4 and IPv6 -addresses. Additionally, a Unix domain socket can be specified as `unix:` followed by a path name. +addresses. Additionally, a Unix domain socket can be specified as `unix:` followed by a path name, +and an abstract Unix domain socket can be specified as `unix-abstract:` followed by an identifier. For a more complete example, see the -[calculator client sample](https://github.com/sandstorm-io/capnproto/tree/master/c++/samples/calculator-client.c++). +[calculator client sample](https://github.com/capnproto/capnproto/tree/master/c++/samples/calculator-client.c++). ### Starting a server @@ -429,10 +430,11 @@ int main(int argc, const char* argv[]) { Note that for the bind address, Cap'n Proto supports DNS host names as well as IPv4 and IPv6 addresses. The special address `*` can be used to bind to the same port on all local IPv4 and IPv6 interfaces. Additionally, a Unix domain socket can be specified as `unix:` followed by a -path name. +path name, and an abstract Unix domain socket can be specified as `unix-abstract:` followed by +an identifier. For a more complete example, see the -[calculator server sample](https://github.com/sandstorm-io/capnproto/tree/master/c++/samples/calculator-server.c++). +[calculator server sample](https://github.com/capnproto/capnproto/tree/master/c++/samples/calculator-server.c++). ## Debugging diff --git a/doc/encoding.md b/doc/encoding.md index fe9304d55a..78a203249f 100644 --- a/doc/encoding.md +++ b/doc/encoding.md @@ -119,6 +119,44 @@ Field offsets are computed by the Cap'n Proto compiler. The precise algorithm i to describe here, but you need not implement it yourself, as the compiler can produce a compiled schema format which includes offset information. +#### Default Values + +A default struct is always all-zeros. To achieve this, fields in the data section are stored xor'd +with their defined default values. An all-zero pointer is considered "null"; accessor methods +for pointer fields check for null and return a pointer to their default value in this case. + +There are several reasons why this is desirable: + +* Cap'n Proto messages are often "packed" with a simple compression algorithm that deflates + zero-value bytes. +* Newly-allocated structs only need to be zero-initialized, which is fast and requires no knowledge + of the struct type except its size. +* If a newly-added field is placed in space that was previously padding, messages written by old + binaries that do not know about this field will still have its default value set correctly -- + because it is always zero. + +#### Zero-sized structs. + +As stated above, a pointer whose bits are all zero is considered a null pointer, *not* a struct of +zero size. To encode a struct of zero size, set A, C, and D to zero, and set B (the offset) to -1. + +**Historical explanation:** A null pointer is intended to be treated as equivalent to the field's +default value. Early on, it was thought that a zero-sized struct was a suitable synonym for +null, since interpreting an empty struct as any struct type results in a struct whose fields are +all default-valued. So, the pointer encoding was designed such that a zero-sized struct's pointer +would be all-zero, so that it could conveniently be overloaded to mean "null". + +However, it turns out there are two important differences between a zero-sized struct and a null +pointer. First, applications often check for null explicitly when implementing optional fields. +Second, an empty struct is technically equivalent to the default value for the struct *type*, +whereas a null pointer is equivalent to the default value for the particular *field*. These are +not necessarily the same. + +It therefore became necessary to find a different encoding for zero-sized structs. Since the +struct has zero size, the pointer's offset can validly point to any location so long as it is +in-bounds. Since an offset of -1 points to the beginning of the pointer itself, it is known to +be in-bounds. So, we use an offset of -1 when the struct has zero size. + ### Lists A list value is encoded as a pointer to a flat array of values. @@ -140,7 +178,9 @@ A list value is encoded as a pointer to a flat array of values. 5 = 8 bytes (non-pointer) 6 = 8 bytes (pointer) 7 = composite (see below) - D (29 bits) = Number of elements in the list, except when C is 7 + D (29 bits) = Size of the list: + when C <> 7: Number of elements in the list. + when C = 7: Number of words in the list, not counting the tag word (see below). The pointed-to values are tightly-packed. In particular, `Bool`s are packed bit-by-bit in @@ -174,23 +214,6 @@ unreasonable implementation burden.) Note that even though struct lists can be d element size (except C = 1), it is NOT permitted to encode a struct list using any type other than C = 7 because doing so would interfere with the [canonicalization algorithm](#canonicalization). -#### Default Values - -A default struct is always all-zeros. To achieve this, fields in the data section are stored xor'd -with their defined default values. An all-zero pointer is considered "null" (such a pointer would -otherwise point to a zero-size struct, which might as well be considered null); accessor methods -for pointer fields check for null and return a pointer to their default value in this case. - -There are several reasons why this is desirable: - -* Cap'n Proto messages are often "packed" with a simple compression algorithm that deflates - zero-value bytes. -* Newly-allocated structs only need to be zero-initialized, which is fast and requires no knowledge - of the struct type except its size. -* If a newly-added field is placed in space that was previously padding, messages written by old - binaries that do not know about this field will still have its default value set correctly -- - because it is always zero. - ### Inter-Segment Pointers When a pointer needs to point to a different segment, offsets no longer work. We instead encode @@ -245,7 +268,7 @@ A capability pointer, then, simply contains an index into the separate capabilit C (32 bits) = Index of the capability in the message's capability table. -In [rpc.capnp](https://github.com/sandstorm-io/capnproto/blob/master/c++/src/capnp/rpc.capnp), the +In [rpc.capnp](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/rpc.capnp), the capability table is encoded as a list of `CapDescriptors`, appearing along-side the message content in the `Payload` struct. However, some use cases may call for different approaches. A message that is built and consumed within the same process need not encode the capability table at all @@ -351,6 +374,10 @@ A canonical Cap'n Proto message must adhere to the following rules: * Similarly, for a struct list, if a trailing word in a section of all structs in the list is zero, then it must be truncated from all structs in the list. (All structs in a struct list must have equal sizes, hence a trailing zero can only be removed if it is zero in all elements.) +* Any struct pointer pointing to a zero-sized struct should have an + offset of -1. + * Note that this applies _only_ to structs; other zero-sized values should have offsets + allocated in preorder, as normal. * Canonical messages are not packed. However, packing can still be applied for transmission purposes; the message must simply be unpacked before checking signatures. @@ -388,7 +415,7 @@ different limit if desired. Another reasonable strategy is to set the limit to s the original message size; however, most applications should place limits on overall message sizes anyway, so it makes sense to have one check cover both. -**List amplification:** A list of `Void` values or zero-size structs can have a very large element count while taking constant space on the wire. If the receiving application expects a list of structs, it will see these zero-sized elements as valid structs set to their default values. If it iterates through the list processing each element, it could spend a large amount of CPU time or other resources despite the message being small. To defend against this, the "traversal limit" should count a list of zero-sized elements as if each element were one word instead. This rule was introduced in the C++ implementation in [commit 1048706](https://github.com/sandstorm-io/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d). +**List amplification:** A list of `Void` values or zero-size structs can have a very large element count while taking constant space on the wire. If the receiving application expects a list of structs, it will see these zero-sized elements as valid structs set to their default values. If it iterates through the list processing each element, it could spend a large amount of CPU time or other resources despite the message being small. To defend against this, the "traversal limit" should count a list of zero-sized elements as if each element were one word instead. This rule was introduced in the C++ implementation in [commit 1048706](https://github.com/capnproto/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d). ### Stack overflow DoS attack diff --git a/doc/faq.md b/doc/faq.md index e3bf4becb5..9443e12683 100644 --- a/doc/faq.md +++ b/doc/faq.md @@ -197,15 +197,26 @@ Cap'n Proto may be layered on top of an existing encrypted transport, such as TL ### How do I report security bugs? -Please email [security@sandstorm.io](mailto:security@sandstorm.io). +Please email [kenton@cloudflare.com](mailto:kenton@cloudflare.com). ## Sandstorm ### How does Cap'n Proto relate to Sandstorm.io? -[Sandstorm.io](https://sandstorm.io) is an Open Source project and startup founded by Kenton, the author of Cap'n Proto. Cap'n Proto is owned and developed by Sandstorm the company and heavily used in Sandstorm the project. +[Sandstorm.io](https://sandstorm.io) is an Open Source project and startup founded by Kenton, the author of Cap'n Proto. Cap'n Proto was developed by Sandstorm the company and heavily used in Sandstorm the project. Sandstorm ceased most operations in 2017 and formally dissolved as a company in 2022, but the open source project continues to be developed by the community. ### How does Sandstorm use Cap'n Proto? See [this Sandstorm blog post](https://blog.sandstorm.io/news/2014-12-15-capnproto-0.5.html). +## Cloudflare + +### How does Cap'n Proto relate to Cloudflare? + +[Cloudflare Workers](https://workers.dev) is a next-generation cloud application platform. Kenton, the author of Cap'n Proto, is the lead engineer on the Workers project. Workers heavily uses Cap'n Proto in its implementation, and the Cloudflare Workers team are now the primarily developers and maintainers of Cap'n Proto's primary C++ implementation. + +### How does Cloudflare use Cap'n Proto? + +The Cloudflare Workers runtime is built on Cap'n Proto and it's associated C++ toolkit library, KJ. Cap'n Proto is used for a variety of things, such as communication between sandbox processes and their supervisors, as well between machines and datacenters, especially in the implementation of [Durable Objects](https://blog.cloudflare.com/introducing-workers-durable-objects/). + +Cloudflare has also [long used Cap'n Proto in its logging pipeline](http://www.thedotpost.com/2015/06/john-graham-cumming-i-got-10-trillion-problems-but-logging-aint-one) and [developed the Lua implementation of Cap'n Proto](https://blog.cloudflare.com/introducing-lua-capnproto-better-serialization-in-lua/) -- both of these actually predate Kenton joining the company. diff --git a/doc/feed.xml b/doc/feed.xml index 3020b0417e..d13f3edbba 100644 --- a/doc/feed.xml +++ b/doc/feed.xml @@ -14,8 +14,8 @@ layout: none {{ post.title | xml_escape }} {{ post.content | xml_escape }} {{ post.date | date: "%a, %d %b %Y %H:%M:%S %z" }} - {{ site.baseurl }}{{ post.url }} - {{ site.baseurl }}{{ post.url }} + {{ post.url }} + {{ post.url }} {% endfor %} diff --git a/doc/go/capnp/index.html b/doc/go/capnp/index.html new file mode 100644 index 0000000000..ba22070244 --- /dev/null +++ b/doc/go/capnp/index.html @@ -0,0 +1,19 @@ +--- +layout: none +--- + + + + + + + + + + + diff --git a/doc/index.md b/doc/index.md index 893d62b663..ae57b83979 100644 --- a/doc/index.md +++ b/doc/index.md @@ -8,7 +8,7 @@ title: Introduction Cap'n Proto is an insanely fast data interchange format and capability-based RPC system. Think -JSON, except binary. Or think [Protocol Buffers](http://protobuf.googlecode.com), except faster. +JSON, except binary. Or think [Protocol Buffers](https://github.com/protocolbuffers/protobuf), except faster. In fact, in benchmarks, Cap'n Proto is INFINITY TIMES faster than Protocol Buffers. This benchmark is, of course, unfair. It is only measuring the time to encode and decode a message @@ -51,7 +51,7 @@ Cap'n Proto generates classes with accessor methods that you use to traverse the Thus, Cap'n Proto checks the structural integrity of the message just like any other serialization protocol would. And, just like any other protocol, it is up to the app to check the validity of the content. -Cap'n Proto was built to be used in [Sandstorm.io](https://sandstorm.io), where security is a major concern. As of this writing, Cap'n Proto has not undergone a security review, therefore we suggest caution when handling messages from untrusted sources. That said, our response to security issues was once described by security guru Ben Laurie as ["the most awesome response I've ever had."](https://twitter.com/BenLaurie/status/575079375307153409) (Please report all security issues to [security@sandstorm.io](mailto:security@sandstorm.io).) +Cap'n Proto was built to be used in [Sandstorm.io](https://sandstorm.io), and is now heavily used in [Cloudflare Workers](https://workers.dev), two environments where security is a major concern. Cap'n Proto has undergone fuzzing and expert security review. Our response to security issues was once described by security guru Ben Laurie as ["the most awesome response I've ever had."](https://twitter.com/BenLaurie/status/575079375307153409) (Please report all security issues to [kenton@cloudflare.com](mailto:kenton@cloudflare.com).) **_Are there other advantages?_** @@ -90,7 +90,7 @@ version 2, which is the version that Google released open source. Cap'n Proto is years of experience working on Protobufs, listening to user feedback, and thinking about how things could be done better. -Note that I no longer work for Google. Cap'n Proto is not, and never has been, affiliated with Google; in fact, it is a property of [Sandstorm.io](https://sandstorm.io), of which I am co-founder. +Note that I no longer work for Google. Cap'n Proto is not, and never has been, affiliated with Google. **_OK, how do I get started?_** diff --git a/doc/install.md b/doc/install.md index 72eef02801..5a036b2ca4 100644 --- a/doc/install.md +++ b/doc/install.md @@ -18,16 +18,16 @@ This package is licensed under the [MIT License](http://opensource.org/licenses/ ### Supported Compilers -Cap'n Proto makes extensive use of C++11 language features. As a result, it requires a relatively +Cap'n Proto makes extensive use of C++14 language features. As a result, it requires a relatively new version of a well-supported compiler. The minimum versions are: -* GCC 4.8 -* Clang 3.5 -* Visual C++ 2015 +* GCC 7.0 +* Clang 6.0 +* Visual C++ 2019 If your system's default compiler is older that the above, you will need to install a newer compiler and set the `CXX` environment variable before trying to build Cap'n Proto. For example, -after installing GCC 4.8, you could set `CXX=g++-4.8` to use this compiler. +after installing GCC 7, you could set `CXX=g++-7` to use this compiler. ### Supported Operating Systems @@ -37,11 +37,10 @@ as well as on Windows. We test every Cap'n Proto release on the following platfo * Android * Linux * Mac OS X -* Windows - Cygwin * Windows - MinGW-w64 * Windows - Visual C++ -**Windows users:** Cap'n Proto requires Visual Studio 2015 Update 3 or newer. All features +**Windows users:** Cap'n Proto requires Visual Studio 2019 or newer. All features of Cap'n Proto -- including serialization, dynamic API, RPC, and schema parser -- are now supported. **Mac OS X users:** You should use the latest Xcode with the Xcode command-line @@ -74,6 +73,7 @@ Some package managers include Cap'n Proto packages. Note: These packages are not maintained by us and are sometimes not up to date with the latest Cap'n Proto release. * Debian / Ubuntu: `apt-get install capnproto` +* Arch Linux: `sudo pacman -S capnproto` * Homebrew (OSX): `brew install capnp` **From Git** @@ -83,7 +83,7 @@ If you download directly from Git, you will need to have the GNU autotools -- [automake](http://www.gnu.org/software/automake/), and [libtool](http://www.gnu.org/software/libtool/) -- installed. - git clone https://github.com/sandstorm-io/capnproto.git + git clone https://github.com/capnproto/capnproto.git cd capnproto/c++ autoreconf -i ./configure @@ -101,23 +101,27 @@ If you download directly from Git, you will need to have the GNU autotools -- 2. Find `capnp.exe`, `capnpc-c++.exe`, and `capnpc-capnp.exe` under `capnproto-tools-win32-0.6.1` in the zip and copy them somewhere. +3. If your `.capnp` files will import any of the `.capnp` files provided by the core project, or + if you use the `stream` keyword (which implicitly imports `capnp/stream.capnp`), then you need + to put those files somewhere where the capnp compiler can find them. To do this, copy the + directory `capnproto-c++-0.0.0/src` to the location of your choice, then make sure to pass the + flag `-I ` to `capnp` when you run it. + If you don't care about C++ support, you can stop here. The compiler exe can be used with plugins provided by projects implementing Cap'n Proto in other languages. If you want to use Cap'n Proto in C++ with Visual Studio, do the following: -1. Make sure that you are using Visual Studio 2015 or newer, with all updates installed. Cap'n - Proto uses C++11 language features that did not work in previous versions of Visual Studio, +1. Make sure that you are using Visual Studio 2019 or newer, with all updates installed. Cap'n + Proto uses C++14 language features that did not work in previous versions of Visual Studio, and the updates include many bug fixes that Cap'n Proto requires. 2. Install [CMake](http://www.cmake.org/) version 3.1 or later. -3. Use CMake to generate Visual Studio project files under `capnproto-c++-0.6.1` in the zip file. +3. Use CMake to generate Visual Studio project files under `capnproto-c++-0.0.0` in the zip file. You can use the CMake UI for this or run this shell command: - cmake -G "Visual Studio 14 2015" - - (For VS2017, you can use "Visual Studio 15 2017" as the generator name.) + cmake -G "Visual Studio 16 2019" 3. Open the "Cap'n Proto" solution in Visual Studio. diff --git a/doc/language.md b/doc/language.md index a277b32e9c..5b638ee840 100644 --- a/doc/language.md +++ b/doc/language.md @@ -138,11 +138,11 @@ struct Person { # ... employment :union { + # We assume that a person is only one of these. unemployed @4 :Void; employer @5 :Company; school @6 :School; selfEmployed @7 :Void; - # We assume that a person is only one of these. } } {% endhighlight %} @@ -222,9 +222,9 @@ A group is a set of fields that are encapsulated in their own scope. struct Person { # ... - # Note: This is a terrible way to use groups, and meant - # only to demonstrate the syntax. address :group { + # Note: This is a terrible way to use groups, and meant + # only to demonstrate the syntax. houseNumber @8 :UInt32; street @9 :Text; city @10 :Text; @@ -393,7 +393,7 @@ Cap'n Proto generics work very similarly to Java generics or C++ templates. Some a type is wire-compatible with any specific parameterization, so long as you interpret the `AnyPointer`s as the correct type at runtime. -* Relatedly, it is safe to cast an generic interface of a specific parameterization to a generic +* Relatedly, it is safe to cast a generic interface of a specific parameterization to a generic interface where all parameters are `AnyPointer` and vice versa, as long as the `AnyPointer`s are treated as the correct type at runtime. This means that e.g. you can implement a server in a generic way that is correct for all parameterizations but call it from clients using a specific @@ -403,17 +403,27 @@ Cap'n Proto generics work very similarly to Java generics or C++ templates. Some substituting the type parameters manually. For example, `Map(Text, Person)` is encoded exactly the same as: -
{% highlight capnp %} - struct PersonMap { - # Encoded the same as Map(Text, Person). - entries @0 :List(Entry); - struct Entry { - key @0 :Text; - value @1 :Person; +
struct PersonMap {
+    # Encoded the same as Map(Text, Person).
+    entries @0 :List(Entry);
+    struct Entry {
+      key @0 :Text;
+      value @1 :Person;
     }
-  }
-  {% endhighlight %}
-  
+ } + + {% comment %} + Highlighter manually invoked because of: https://github.com/jekyll/jekyll/issues/588 + Original code was: + struct PersonMap { + # Encoded the same as Map(Text, Person). + entries @0 :List(Entry); + struct Entry { + key @0 :Text; + value @1 :Person; + } + } + {% endcomment %} Therefore, it is possible to upgrade non-generic types to generic types while retaining backwards-compatibility. @@ -542,8 +552,8 @@ An `import` expression names the scope of some other file: {% highlight capnp %} struct Foo { - # Use type "Baz" defined in bar.capnp. baz @0 :import "bar.capnp".Baz; + # Use type "Baz" defined in bar.capnp. } {% endhighlight %} @@ -553,8 +563,8 @@ Of course, typically it's more readable to define an alias: using Bar = import "bar.capnp"; struct Foo { - # Use type "Baz" defined in bar.capnp. baz @0 :Bar.Baz; + # Use type "Baz" defined in bar.capnp. } {% endhighlight %} @@ -565,12 +575,15 @@ using import "bar.capnp".Baz; struct Foo { baz @0 :Baz; + # Use type "Baz" defined in bar.capnp. } {% endhighlight %} The above imports specify relative paths. If the path begins with a `/`, it is absolute -- in this case, the `capnp` tool searches for the file in each of the search path directories specified -with `-I`. +with `-I`, appending the path you specify to the path given to the `-I` flag. So, for example, +if you ran `capnp` with `-Ifoo/bar`, and the import statement is `import "/baz/qux.capnp"`, then +the compiler would open the file `foo/bar/baz/qux.capnp`. ### Annotations @@ -584,22 +597,23 @@ removes all of these hidden fields. You may declare annotations and use them like so: {% highlight capnp %} -# Declare an annotation 'foo' which applies to struct and enum types. annotation foo(struct, enum) :Text; +# Declare an annotation 'foo' which applies to struct and enum types. -# Apply 'foo' to to MyType. struct MyType $foo("bar") { + # Apply 'foo' to to MyType. + # ... } {% endhighlight %} -The possible targets for an annotation are: `file`, `struct`, `field`, `union`, `enum`, `enumerant`, -`interface`, `method`, `parameter`, `annotation`, `const`. You may also specify `*` to cover them -all. +The possible targets for an annotation are: `file`, `struct`, `field`, `union`, `group`, `enum`, +`enumerant`, `interface`, `method`, `param`, `annotation`, `const`. +You may also specify `*` to cover them all. {% highlight capnp %} -# 'baz' can annotate anything! annotation baz(*) :Int32; +# 'baz' can annotate anything! $baz(1); # Annotate the file. @@ -654,8 +668,8 @@ A Cap'n Proto file must have a unique 64-bit ID, and each type and annotation de also have an ID. Use `capnp id` to generate a new ID randomly. ID specifications begin with `@`: {% highlight capnp %} -# file ID @0xdbb9ad1f14bf0b36; +# file ID struct Foo @0x8db435604d0d3723 { # ... @@ -724,36 +738,62 @@ without changing the [canonical](encoding.html#canonicalization) encoding of a m * A field can be moved into a group or a union, as long as the group/union and all other fields within it are new. In other words, a field can be replaced with a group or union containing an - equivalent field and some new fields. + equivalent field and some new fields. Note that when creating a union this way, this particular + change is not fully forwards-compatible: if you create a message where one of the union's new + fields are set, and the message is read by an old program that dosen't know about the union, then + it may expect the original field to be present, and if it tries to read that field, may see a + garbage value or throw an exception. To avoid this problem, make sure to only use the new union + members when talking to programs that know about the union. This caveat only applies when moving + an existing field into a new union; adding new fields to an existing union does not create a + problem, because existing programs should already know to check the union's tag (although they + may or may not behave reasonably when the tag has a value they don't recognize). * A non-generic type can be made [generic](#generic-types), and new generic parameters may be added to an existing generic type. Other types used inside the body of the newly-generic type can be replaced with the new generic parameter so long as all existing users of the type are updated to bind that generic parameter to the type it replaced. For example: -
{% highlight capnp %} - struct Map { - entries @0 :List(Entry); - struct Entry { - key @0 :Text; - value @1 :Text; +
struct Map {
+    entries @0 :List(Entry);
+    struct Entry {
+      key @0 :Text;
+      value @1 :Text;
     }
-  }
-  {% endhighlight %}
-  
+ } + + {% comment %} + Highlighter manually invoked because of: https://github.com/jekyll/jekyll/issues/588 + Original code was: + struct Map { + entries @0 :List(Entry); + struct Entry { + key @0 :Text; + value @1 :Text; + } + } + {% endcomment %} Can change to: -
{% highlight capnp %} - struct Map(Key, Value) { - entries @0 :List(Entry); - struct Entry { - key @0 :Key; - value @1 :Value; +
struct Map(Key, Value) {
+    entries @0 :List(Entry);
+    struct Entry {
+      key @0 :Key;
+      value @1 :Value;
     }
-  }
-  {% endhighlight %}
-  
+ } + + {% comment %} + Highlighter manually invoked because of: https://github.com/jekyll/jekyll/issues/588 + Original code was: + struct Map(Key, Value) { + entries @0 :List(Entry); + struct Entry { + key @0 :Key; + value @1 :Value; + } + } + {% endcomment %} As long as all existing uses of `Map` are replaced with `Map(Text, Text)` (and any uses of `Map.Entry` are replaced with `Map(Text, Text).Entry`). diff --git a/doc/otherlang.md b/doc/otherlang.md index d63f3e29b2..b64157391b 100644 --- a/doc/otherlang.md +++ b/doc/otherlang.md @@ -14,23 +14,24 @@ project's documentation for details. ##### Serialization + RPC * [C++](cxx.html) by [@kentonv](https://github.com/kentonv) +* [C#](https://github.com/c80k/capnproto-dotnetcore) by [@c80k](https://github.com/c80k) * [Erlang](http://ecapnp.astekk.se/) by [@kaos](https://github.com/kaos) -* [Go](https://github.com/zombiezen/go-capnproto2) by [@zombiezen](https://github.com/zombiezen) (forked from [@glycerine](https://github.com/glycerine)'s serialization-only version, below) -* [Javascript (Node.js only)](https://github.com/kentonv/node-capnp) by [@kentonv](https://github.com/kentonv) -* [Python](http://jparyani.github.io/pycapnp/) by [@jparyani](https://github.com/jparyani) +* [Go](https://github.com/capnproto/go-capnp) currently maintained by [@zenhack](https://github.com/zenhack) and [@lthibault](https://github.com/lthibault) +* [Haskell](https://github.com/zenhack/haskell-capnp) by [@zenhack](https://github.com/zenhack) +* [JavaScript (Node.js only)](https://github.com/capnproto/node-capnp) by [@kentonv](https://github.com/kentonv) +* [OCaml](https://github.com/capnproto/capnp-ocaml) by [@pelzlpj](https://github.com/pelzlpj) with [RPC](https://github.com/mirage/capnp-rpc) by [@talex5](https://github.com/talex5) +* [Python](http://capnproto.github.io/pycapnp/) by [@jparyani](https://github.com/jparyani) * [Rust](https://github.com/dwrensha/capnproto-rust) by [@dwrensha](https://github.com/dwrensha) ##### Serialization only * [C](https://github.com/opensourcerouting/c-capnproto) by [OpenSourceRouting](https://www.opensourcerouting.org/) / [@eqvinox](https://github.com/eqvinox) (originally by [@jmckaskill](https://github.com/jmckaskill)) -* [C#](https://github.com/mgravell/capnproto-net) by [@mgravell](https://github.com/mgravell) -* [Go](https://github.com/glycerine/go-capnproto) by [@glycerine](https://github.com/glycerine) (originally by [@jmckaskill](https://github.com/jmckaskill)) -* [Java](https://github.com/dwrensha/capnproto-java/) by [@dwrensha](https://github.com/dwrensha) -* [Javascript](https://github.com/popham/capnp-js-base) by [@popham](https://github.com/popham) -* [Javascript](https://github.com/jscheid/capnproto-js) (older, abandoned) by [@jscheid](https://github.com/jscheid) +* [D](https://github.com/capnproto/capnproto-dlang) by [@ThomasBrixLarsen](https://github.com/ThomasBrixLarsen) +* [Java](https://github.com/capnproto/capnproto-java/) by [@dwrensha](https://github.com/dwrensha) +* [JavaScript](https://github.com/capnp-js/plugin/) by [@popham](https://github.com/popham) +* [JavaScript](https://github.com/jscheid/capnproto-js) (older, abandoned) by [@jscheid](https://github.com/jscheid) * [Lua](https://github.com/cloudflare/lua-capnproto) by [CloudFlare](http://www.cloudflare.com/) / [@calio](https://github.com/calio) * [Nim](https://github.com/zielmicha/capnp.nim) by [@zielmicha](https://github.com/zielmicha) -* [OCaml](https://github.com/pelzlpj/capnp-ocaml) by [@pelzlpj](https://github.com/pelzlpj) * [Ruby](https://github.com/cstrahan/capnp-ruby) by [@cstrahan](https://github.com/cstrahan) * [Scala](https://github.com/katis/capnp-scala) by [@katis](https://github.com/katis) @@ -42,9 +43,10 @@ new languages. * [Common Test Framework](https://github.com/kaos/capnp_test) by [@kaos](https://github.com/kaos) * [Sublime Syntax Highlighting](https://github.com/joshuawarner32/capnproto-sublime) by [@joshuawarner32](https://github.com/joshuawarner32) -* [Vim Syntax Highlighting](https://github.com/peter-edge/vim-capnp) by [@peter-edge](https://github.com/peter-edge) - (originally by [@cstrahan](https://github.com/cstrahan)) +* [Vim Syntax Highlighting](https://github.com/cstrahan/vim-capnp) by [@cstrahan](https://github.com/cstrahan) * [Wireshark Dissector Plugin](https://github.com/kaos/wireshark-plugins) by [@kaos](https://github.com/kaos) +* [VS Code Syntax Highlighter](https://marketplace.visualstudio.com/items?itemName=xmonader.vscode-capnp) by [@xmonader](https://github.com/xmonader) +* [IntelliJ Syntax Highlighter](https://github.com/xmonader/sercapnp) by [@xmonader](https://github.com/xmonader) ## Contribute Your Own! @@ -69,7 +71,7 @@ then hands the parse tree off to another binary -- known as a "plugin" -- which Plugins are independent executables (written in any language) which read a description of the schema from standard input and then generate the necessary code. The description is itself a Cap'n Proto message, defined by -[schema.capnp](https://github.com/sandstorm-io/capnproto/blob/master/c%2B%2B/src/capnp/schema.capnp). +[schema.capnp](https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/schema.capnp). Specifically, the plugin receives a `CodeGeneratorRequest`, using [standard serialization](encoding.html#serialization-over-a-stream) (not packed). (Note that installing the C++ runtime causes schema.capnp to be placed in @@ -97,8 +99,8 @@ If the user specifies an output directory, the compiler will run the plugin with as the working directory, so you do not need to worry about this. For examples of plugins, take a look at -[capnpc-capnp](https://github.com/sandstorm-io/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-capnp.c%2B%2B) -or [capnpc-c++](https://github.com/sandstorm-io/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-c%2B%2B.c%2B%2B). +[capnpc-capnp](https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-capnp.c%2B%2B) +or [capnpc-c++](https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/compiler/capnpc-c%2B%2B.c%2B%2B). ### Supporting Dynamic Languages diff --git a/doc/push-site.sh b/doc/push-site.sh index 83cfad2bd7..48f89e62c0 100755 --- a/doc/push-site.sh +++ b/doc/push-site.sh @@ -9,7 +9,7 @@ if grep 'localhost:4000' *.md _posts/*.md; then fi if [ "x$(git status --porcelain)" != "x" ]; then - echo -n "git repo has uncommited changes. Continue anyway? (y/N) " >&2 + echo -n "git repo has uncommitted changes. Continue anyway? (y/N) " >&2 read -n 1 YESNO echo >&2 if [ "x$YESNO" != xy ]; then @@ -44,7 +44,7 @@ echo "Regenerating site..." rm -rf _site _site.tar.gz -jekyll build --safe $FUTURE --config $CONFIG +jekyll _3.8.1_ build --safe $FUTURE --config $CONFIG echo -n "Push now? (y/N)" read -n 1 YESNO @@ -52,7 +52,7 @@ echo if [ "x$YESNO" == "xy" ]; then echo "Pushing..." - tar cz --xform='s,_site/,,' _site/* | gce-ss ssh fe --command "cd /var/www/capnproto.org$PREFIX && tar xz" + tar cz --xform='s,_site/,,' _site/* | gce-ss ssh alpha2 --command "cd /var/www/capnproto.org$PREFIX && tar xz" else echo "Push CANCELED" fi diff --git a/doc/roadmap.md b/doc/roadmap.md index 097828609e..6b7fcc1e4d 100644 --- a/doc/roadmap.md +++ b/doc/roadmap.md @@ -50,7 +50,7 @@ these will actually happen; as always, real work is driven by real-world needs. to each struct type. The POCS type would use traditional memory allocation, thus would not support zero-copy, but would support a more traditional and easy-to-use C++ API, including the ability to mutate the object over time without convoluted memory management. POCS types - could be extracted from an inserted into messages with a single copy, allowing them to be + could be extracted from and inserted into messages with a single copy, allowing them to be used easily in non-performance-critical code. * **Multi-threading:** It should be made easy to assign different Cap'n Proto RPC objects to different threads and have them be able to safely call each other. Each thread would still diff --git a/doc/rpc.md b/doc/rpc.md index bb8e2d881e..9ef2a49fe4 100644 --- a/doc/rpc.md +++ b/doc/rpc.md @@ -30,7 +30,7 @@ bar() on the result of the first call". These messages can be sent together -- to wait for the first call to actually return. To make programming to this model easy, in your code, each call returns a "promise". Promises -work much like Javascript promises or promises/futures in other languages: the promise is returned +work much like JavaScript promises or promises/futures in other languages: the promise is returned immediately, but you must later call `wait()` on it, or call `then()` to register an asynchronous callback. @@ -142,7 +142,7 @@ performs as well as we can possibly hope for. #### Example code -[The calculator example](https://github.com/sandstorm-io/capnproto/blob/master/c++/samples/calculator-client.c++) +[The calculator example](https://github.com/capnproto/capnproto/blob/master/c++/samples/calculator-client.c++) uses promise pipelining. Take a look at the client side in particular. ### Distributed Objects @@ -244,7 +244,7 @@ stream protocol, it can easily be layered on top of SSL/TLS or other such protoc The Cap'n Proto RPC protocol is defined in terms of Cap'n Proto serialization schemas. The documentation is inline. See -[rpc.capnp](https://github.com/sandstorm-io/capnproto/blob/master/c++/src/capnp/rpc.capnp). +[rpc.capnp](https://github.com/capnproto/capnproto/blob/master/c++/src/capnp/rpc.capnp). Cap'n Proto's RPC protocol is based heavily on [CapTP](http://www.erights.org/elib/distrib/captp/index.html), the distributed capability protocol diff --git a/doc/slides-2017.05.18/index.md b/doc/slides-2017.05.18/index.md index a0e0937331..ad7ec6ebb6 100644 --- a/doc/slides-2017.05.18/index.md +++ b/doc/slides-2017.05.18/index.md @@ -531,7 +531,7 @@ getAll @3 (page :UInt32 = 0 $httpQuery) $http(method = get); # GET /?page= # Query is optional. -# JSAN (JSON array) repsonse body. +# JSAN (JSON array) response body. {% endhighlight %} diff --git a/highlighting/emacs/README.md b/highlighting/emacs/README.md index ddfc29ff4f..1eeab9eaf1 100644 --- a/highlighting/emacs/README.md +++ b/highlighting/emacs/README.md @@ -9,5 +9,4 @@ capnproto directory lives): ```elisp (add-to-list 'load-path "~/src/capnproto/highlighting/emacs") (require 'capnp-mode) -(add-to-list 'auto-mode-alist '("\\.capnp\\'" . capnp-mode)) ``` diff --git a/highlighting/emacs/capnp-mode.el b/highlighting/emacs/capnp-mode.el index 688bf3850d..2681ca80a8 100644 --- a/highlighting/emacs/capnp-mode.el +++ b/highlighting/emacs/capnp-mode.el @@ -1,9 +1,10 @@ -;;; capnp-mode.el --- major mode for editing Capn' Proto Files +;;; capnp-mode.el --- Major mode for editing Capn' Proto Files ;; This is free and unencumbered software released into the public domain. ;; Author: Brian Taylor ;; Version: 1.0.0 +;; URL: https://github.com/capnproto/capnproto ;;; Commentary: @@ -15,7 +16,6 @@ ;; ;; (add-to-list 'load-path "~/src/capnproto/highlighting/emacs") ;; (require 'capnp-mode) -;; (add-to-list 'auto-mode-alist '("\\.capnp\\'" . capnp-mode)) ;; ;;; Code: @@ -23,12 +23,11 @@ ;; command to comment/uncomment text (defun capnp-comment-dwim (arg) "Comment or uncomment current line or region in a smart way. -For detail, see `comment-dwim'." +For detail, see `comment-dwim' for ARG explanation." (interactive "*P") (require 'newcomment) (let ( - (comment-start "#") (comment-end "") - ) + (comment-start "#") (comment-end "")) (comment-dwim arg))) (defvar capnp--syntax-table @@ -72,4 +71,9 @@ For detail, see `comment-dwim'." (setq mode-name "capnp") (define-key capnp-mode-map [remap comment-dwim] 'capnp-comment-dwim)) +;;;###autoload +(add-to-list 'auto-mode-alist '("\\.capnp\\'" . capnp-mode)) + (provide 'capnp-mode) +;;; capnp-mode.el ends here + diff --git a/highlighting/qtcreator/capnp.xml b/highlighting/qtcreator/capnp.xml index bb5c7812c4..e675e5b063 100644 --- a/highlighting/qtcreator/capnp.xml +++ b/highlighting/qtcreator/capnp.xml @@ -112,7 +112,7 @@ of these, like "keyword" and "type", could be mapped to dsKeyword and dsDataType, but there's a chance the user has mapped the colors for those things to things that would conflict with the manually-defined colors here, which would probably be even more annoying - than having the colors be inconsitent from other languages. So, I use manual colors for + than having the colors be inconsistent from other languages. So, I use manual colors for everything, except comments, which I figure are less likely to have this problem. --> diff --git a/kjdoc/index.md b/kjdoc/index.md new file mode 100644 index 0000000000..96f0ac4f2b --- /dev/null +++ b/kjdoc/index.md @@ -0,0 +1,67 @@ +# Introducing KJ + +KJ is Modern C++'s missing base library. + +## What's wrong with `std`? + +The C++ language has advanced rapidly over the last decade. However, its standard library (`std`) remains a weak point. Most modern languages ship with libraries that have built-in support for common needs, such as making HTTP requests. `std`, meanwhile, not only lacks HTTP, but doesn't even support basic networking. Developers are forced either to depend on low-level, non-portable OS APIs, or pull in a bunch of third-party dependencies with inconsistent styles and quality. + +Worse, `std` was largely designed before C++ best practices were established. Much of it predates C++11, which changed almost everything about how C++ is written. Some critical parts of `std` -- such as the `iostreams` component -- were designed before anyone really knew how to write quality object-oriented code, and are atrociously bad by modern standards. + +Finally, `std` is designed by committee, which has advantages and disadvantages. On one hand, committees are less likely to make major errors in design. However, they also struggle to make bold decisions, and they move slowly. Committees can also lose touch with real-world concerns, over-engineering features that aren't needed while missing essential basics. + +## How is KJ different? + +KJ was designed and implemented primarily by one developer, Kenton Varda. Every feature was designed to solve a real-world need in a project Kenton was working on -- first [Cap'n Proto](https://capnproto.org), then [Sandstorm](https://sandstorm.io), and more recently, [Cloudflare Workers](https://workers.dev). KJ was designed from the beginning to target exclusively Modern C++ (C++11 and later). + +Since its humble beginnings in 2013, KJ has developed a huge range of practical functionality, including: + +* RAII utilities, especially for memory management +* Basic types and data structures: `Array`, `Maybe`, `OneOf`, `Tuple`, `Function`, `Quantity` (unit analysis), `String`, `Vector`, `HashMap`, `HashSet`, `TreeMap`, `TreeSet`, etc. +* Convenient stringification +* Exception/assertion framework with friggin' stack traces +* Event loop framework with `Promise` API inspired by E (which also inspired JavaScript's `Promise`). +* Threads, fibers, mutexes, lazy initialization +* I/O: Clocks, filesystem, networking +* Protocols: HTTP (client and server), TLS (via OpenSSL/BoringSSL), gzip (via libz) +* Parsers: URL, JSON (using Cap'n Proto), parser combinator framework +* Encodings: UTF-8/16/32, base64, hex, URL encoding, C escapes +* Command-line argument parsing +* Unit testing framework +* And more! + +KJ is not always perfectly organized, and admittedly has some quirks. But, it has proven pragmatic and powerful in real-world applications. + +# Getting KJ + +KJ is bundled with Cap'n Proto -- see [installing Cap'n Proto](https://capnproto.org/install.html). KJ is built as a separate set of libraries, so that you can link against it without Cap'n Proto if desired. + +KJ is officially tested on Linux (GCC and Clang), Windows (Visual Studio, MinGW, and Cygwin), MacOS, and Android. It should additionally be easy to get working on any POSIX platform targeted by GCC or Clang. + +# FAQ + +## What does KJ stand for? + +Nothing. + +The name "KJ" was chosen to be a relatively unusual combination of two letters that is easy to type (on both Qwerty and Dvorak layouts). This is important, because users of KJ will find themselves typing `kj::` very frequently. + +## Why reinvent modern `std` features that are well-designed? + +Some features of KJ appear to replace `std` features that were introduced recently with decent, modern designs. Examples include `kj::Own` vs `std::unique_ptr`, `kj::Maybe` vs `std::optional`, and `kj::Promise` vs `std::task`. + +First, in many cases, the KJ feature actually predates the corresponding `std` feature. `kj::Maybe` was one of the first KJ types, introduced in 2013; `std::optional` arrived in C++17. `kj::Promise` was also introduced in 2013; `std::task` is coming in C++20 (with coroutines). + +Second, consistency. KJ uses somewhat different idioms from `std`, resulting in some friction when trying to use KJ and `std` types together. The most obvious friction is aesthetic (e.g. naming conventions), but some deeper issues exist. For example, KJ tries to treat `const` as transitive, especially so that it can be used to help enforce thread-safety. This can lead to subtle problems (e.g. unexpected compiler errors) with `std` containers not designed with transitive constness in mind. KJ also uses a very different [philosophy around exceptions](../style-guide.md#exceptions) compared to `std`; KJ believes exception-free code is a myth, but `std` sometimes requires it. + +Third, even some modern `std` APIs have design flaws. For example, `std::optional`s can be dereferenced without an explicit null check, resulting in a crash if the value is null -- exactly what this type should have existed to prevent! `kj::Maybe`, in contrast, forces you to write an if/else block or an explicit assertion. For another example, `kj::Own` uses dynamic dispatch for deleters, which allows for lots of useful patterns that `std::unique_ptr`'s static dispatch cannot do. + +## Shouldn't modern software be moving away from memory-unsafe languages? + +Probably! + +Similarly, modern software should also move away from type-unsafe languages. Type-unsafety and memory-unsafety are both responsible for a huge number of security bugs. (Think SQL injection for an example of a security bug resulting from type-unsafety.) + +Hence, all other things being equal, I would suggest Rust for new projects. + +But it's rare that all other things are really equal, and you may have your reasons for using C++. KJ is here to help, not to judge. diff --git a/kjdoc/style-guide.md b/kjdoc/style-guide.md new file mode 120000 index 0000000000..4b6893e8c0 --- /dev/null +++ b/kjdoc/style-guide.md @@ -0,0 +1 @@ +../style-guide.md \ No newline at end of file diff --git a/kjdoc/tour.md b/kjdoc/tour.md new file mode 100644 index 0000000000..922e76941f --- /dev/null +++ b/kjdoc/tour.md @@ -0,0 +1,1061 @@ +--- +title: A tour of KJ +--- + +This page is a tour through the functionality provided by KJ. It is intended for developers new to KJ who want to learn the ropes. + +**This page is not an API reference.** KJ's reference documentation is provided by comments in the headers themselves. Keeping reference docs in the headers makes it easy to find them using your editor's "jump to declaration" hotkey. It also ensures that the documentation is never out-of-sync with the version of KJ you are using. + +Core Programming +====================================================================== + +This section covers core KJ features used throughout nearly all KJ-based code. + +Every KJ developer should familiarize themselves at least with this section. + +## Core Utility Functions + +### Argument-passing: move, copy, forward + +`kj::mv` has exactly the same semantics as `std::move`, but takes fewer keystrokes to type. Since this is used extraordinarily often, saving a few keystrokes really makes a legitimate difference. If you aren't familiar with `std::move`, I recommend reading up on [C++11 move semantics](https://stackoverflow.com/questions/3106110/what-is-move-semantics). + +`kj::cp` is invoked in a similar way to `kj::mv`, but explicitly invokes the copy constructor of its argument, returning said copy. This is occasionally useful when invoking a function that wants an rvalue reference as a parameter, which normally requires pass-by-move, but you really want to pass it a copy. + +`kj::fwd`, is equivalent to `std::forward`. It is used to implement [perfect forwarding](https://en.cppreference.com/w/cpp/utility/forward), that is, forwarding arbitrary arguments from a template function into some other function without understanding their types. + +### Deferring code to scope exit + +This macro declares some code which must execute when exiting the current scope (whether normally or by exception). It is essentially a shortcut for declaring a class with a destructor containing said code, and instantiating that destructor. Example: + +```c++ +void processFile() { + int fd = open("file.txt", O_RDONLY); + KJ_ASSERT(fd >= 0); + + // Make sure file is closed on return. + KJ_DEFER(close(fd)); + + // ... do something with the file ... +} +``` + +You can also pass a multi-line block (in curly braces) as the argument to `KJ_DEFER`. + +There is also a non-macro version, `kj::defer`, which takes a lambda as its argument, and returns an object that invokes that lambda on destruction. The returned object has move semantics. This is convenient when the scope of the deferral isn't necessarily exactly function scope, such as when capturing context in a callback. Example: + +```c++ +kj::Function processFile() { + int fd = open("file.txt", O_RDONLY); + KJ_ASSERT(fd >= 0); + + // Make sure file is closed when the returned function + // is eventually destroyed. + auto deferredClose = kj::defer([fd]() { close(fd); }); + + return [fd, deferredClose = kj::mv(deferredClose)] + (int arg) { + // ... do something with fd and arg ... + } +} +``` + +Sometimes, you want a deferred action to occur only when the scope exits normally via `return`, or only when it exits due to an exception. For those purposes, `KJ_ON_SCOPE_SUCCESS` and `KJ_ON_SCOPE_FAILURE` may be used, with the same syntax as `KJ_DEFER`. + +### Size and range helpers + +`kj::size()` accepts a built-in array or a container as an argument, and returns the number of elements. In the case of a container, the container must implement a `.size()` method. The idea is that you can use this to find out how many iterations a range-based `for` loop on that container would execute. That said, in practice `kj::size` is most commonly used with arrays, as a shortcut for something like `sizeof(array) / sizeof(array[0])`. + +```c++ +int arr[15]; +KJ_ASSERT(kj::size(arr) == 15); +``` + +`kj::range(i, j)` returns an iterable that contains all integers from `i` to `j` (including `i`, but not including `j`). This is typically used in `for` loops: + +```c++ +for (auto i: kj::range(5, 10)) { + KJ_ASSERT(i >= 5 && i < 10); +} +``` + +In the very-common case of iterating from zero, `kj::zeroTo(i)` should be used instead of `kj::range(0, i)`, in order to avoid ambiguity about what type of integer should be generated. + +`kj::indices(container)` is equivalent to `kj::zeroTo(kj::size(container))`. This is extremely convenient when iterating over parallel arrays. + +```c++ +KJ_ASSERT(foo.size() == bar.size()); +for (auto i: kj::indices(foo)) { + foo[i] = bar[i]; +} +``` + +`kj::repeat(value, n)` returns an iterable that acts like an array of size `n` where every element is `value`. This is not often used, but can be convenient for string formatting as well as generating test data. + +### Casting helpers + +`kj::implicitCast(value)` is equivalent to `static_cast(value)`, but will generate a compiler error if `value` cannot be implicitly cast to `T`. For example, `static_cast` can be used for both upcasts (derived type to base type) and downcasts (base type to derived type), but `implicitCast` can only be used for the former. + +`kj::downcast(value)` is equivalent to `static_cast(value)`, except that when compiled in debug mode with RTTI available, a runtime check (`dynamic_cast`) will be performed to verify that `value` really has type `T`. Use this in cases where you are casting a base type to a derived type, and you are confident that the object is actually an instance of the derived type. The debug-mode check will help you catch bugs. + +`kj::dynamicDowncastIfAvailable(value)` is like `dynamic_cast(value)` with two differences. First, it returns `kj::Maybe` instead of `T*`. Second, if the program is compiled without RTTI enabled, the function always returns null. This function is intended to be used to implement optimizations, where the code can do something smarter if `value` happens to be of some specific type -- but if RTTI is not available, it is safe to skip the optimization. See [KJ idiomatic use of dynamic_cast](../style-guide.md#dynamic_cast) for more background. + +### Min/max, numeric limits, and special floats + +`kj::min()` and `kj::max()` return the minimum and maximum of the input arguments, automatically choosing the appropriate return type even if the inputs are of different types. + +`kj::minValue` and `kj::maxValue` are special constants that, when cast to an integer type, become the minimum or maximum value of the respective type. For example: + +```c++ +int16_t i = kj::maxValue; +KJ_ASSERT(i == 32767); +``` + +`kj::inf()` evaluates to floating-point infinity, while `kj::nan()` evaluates to floating-point NaN. `kj::isNaN()` returns true if its argument is NaN. + +### Explicit construction and destruction + +`kj::ctor()` and `kj::dtor()` explicitly invoke a constructor or destructor in a way that is readable and convenient. The first argument is a reference to memory where the object should live. + +These functions should almost never be used in high-level code. They are intended for use in custom memory management, or occasionally with unions that contain non-trivial types (but consider using `kj::OneOf` instead). You must understand C++ memory aliasing rules to use these correctly. + +## Ownership and memory management + +KJ style makes heavy use of [RAII](../style-guide.md#raii-resource-acquisition-is-initialization). KJ-based code should never use `new` and `delete` directly. Instead, use the utilities in this section to manage memory in a RAII way. + +### Owned pointers, heap allocation, and disposers + +`kj::Own` is a pointer to a value of type `T` which is "owned" by the holder. When a `kj::Own` goes out-of-scope, the value it points to will (typically) be destroyed and freed. + +`kj::Own` has move semantics. Thus, when used as a function parameter or return type, `kj::Own` indicates that ownership of the object is being transferred. + +`kj::heap(args...)` allocates an object of type `T` on the heap, passing `args...` to its constructor, and returns a `kj::Own`. This is the most common way to create owned objects. + +However, a `kj::Own` does not necessarily refer to a heap object. A `kj::Own` is actually implemented as a pair of a pointer to the object, and a pointer to a `kj::Disposer` object that knows how to destroy it; `kj::Own`'s destructor invokes the disposer. `kj::Disposer` is an abstract interface with many implementations. `kj::heap` uses an implementation that invokes the object's destructor then frees its underlying space from the heap (like `delete` does), but other implementations exist. Alternative disposers allow an application to control memory allocation more precisely when desired. + +Some example uses of disposers include: + +* `kj::fakeOwn(ref)` returns a `kj::Own` that points to `ref` but doesn't actually destroy it. This is useful when you know for sure that `ref` will outlive the scope of the `kj::Own`, and therefore heap allocation is unnecessary. This is common in cases where, for example, the `kj::Own` is being passed into an object which itself will be destroyed before `ref` becomes invalid. It also makes sense when `ref` is actually a static value or global that lives forever. +* `kj::refcounted(args...)` allocates a `T` which uses reference counting. It returns a `kj::Own` that represents one reference to the object. Additional references can be created by calling `kj::addRef(*ptr)`. The object is destroyed when no more `kj::Own`s exist pointing at it. Note that `T` must be a subclass of `kj::Refcounted`. If references may be shared across threads, then atomic refcounting must be used; use `kj::atomicRefcounted(args...)` and inherit `kj::AtomicRefcounted`. Reference counting should be using sparingly; see [KJ idioms around reference counting](../style-guide.md#reference-counting) for a discussion of when it should be used and why it is designed the way it is. +* `kj::attachRef(ref, args...)` returns a `kj::Own` pointing to `ref` that actually owns `args...`, so that when the `kj::Own` goes out-of-scope, the other arguments are destroyed. Typically these arguments are themselves `kj::Own`s or other pass-by-move values that themselves own the object referenced by `ref`. `kj::attachVal(value, args...)` is similar, where `value` is a pass-by-move value rather than a reference; a copy of it will be allocated on the heap. Finally, `ownPtr.attach(args...)` returns a new `kj::Own` pointing to the same value that `ownPtr` pointed to, but such that `args...` are owned as well and will be destroyed together. Attachments are always destroyed after the thing they are attached to. +* `kj::SpaceFor` contains enough space for a value of type `T`, but does not construct the value until its `construct(args...)` method is called. That method returns an `kj::Own`, whose disposer destroys the value. `kj::SpaceFor` is thus a safer way to perform manual construction compared to invoking `kj::ctor()` and `kj::dtor()`. + +These disposers cover most use cases, but you can also implement your own if desired. `kj::Own` features a constructor overload that lets you pass an arbitrary disposer. + +### Arrays + +`kj::Array` is similar to `kj::Own`, but points to (and owns) an array of `T`s. + +A `kj::Array` can be allocated with `kj::heapArray(size)`, if `T` can be default-constructed. Otherwise, you will need to use a `kj::ArrayBuilder` to build the array. First call `kj::heapArrayBuilder(size)`, then invoke the builder's `add(value)` method to add each element, then finally call its `finish()` method to obtain the completed `kj::Array`. `ArrayBuilder` requires that you know the final size before you start; if you don't, you may want to use `kj::Vector` instead. + +Passing a `kj::Array` implies an ownership transfer. If you merely want to pass a pointer to an array, without transferring ownership, use `kj::ArrayPtr`. This type essentially encapsulates a pointer to the beginning of the array, plus its size. Note that a `kj::ArrayPtr` points to _the underlying memory_ backing a `kj::Array`, not to the `kj::Array` itself; thus, moving a `kj::Array` does NOT invalidate any `kj::ArrayPtr`s already pointing at it. You can also construct a `kj::ArrayPtr` pointing to any C-style array (doesn't have to be a `kj::Array`) using `kj::arrayPtr(ptr, size)` or `kj::arrayPtr(beginPtr, endPtr)`. + +Both `kj::Array` and `kj::ArrayPtr` contain a number of useful methods, like `slice()`. Be sure to check out the class definitions for more details. + +## Strings + +A `kj::String` is a segment of text. By convention, this text is expected to be Unicode encoded in UTF-8. But, `kj::String` itself is not Unicode-aware; it is merely an array of `char`s. + +NUL characters (`'\0'`) are allowed to appear anywhere in a string and do not terminate the string. However, as a convenience, the buffer backing a `kj::String` always has an additional NUL character appended to the end (but not counted in the size). This allows the text in a `kj::String` to be passed to legacy C APIs that use NUL-terminated strings without an extra copy; use the `.cStr()` method to get a `const char*` for such cases. (Of course, keep in mind that if the string contains NUL characters other than at the end, legacy C APIs will interpret the string as truncated at that point.) + +`kj::StringPtr` represents a pointer to a `kj::String`. Similar to `kj::ArrayPtr`, `kj::StringPtr` does not point at the `kj::String` object itself, but at its backing array. Thus, moving a `kj::String` does not invalidate any `kj::StringPtr`s. This is a major difference from `std::string`! Moving an `std::string` invalidates all pointers into its backing buffer (including `std::string_view`s), because `std::string` inlines small strings as an optimization. This optimization may seem clever, but means that `std::string` cannot safely be used as a way to hold and transfer ownership of a text buffer. Doing so can lead to subtle, data-dependent bugs; a program might work fine until someone gives it an unusually small input, at which point it segfaults. `kj::String` foregoes this optimization for simplicity. + +Also similar to `kj::ArrayPtr`, a `kj::StringPtr` does not have to point at a `kj::String`. It can be initialized from a string literal or any C-style NUL-terminated `const char*` without making a copy. Also, KJ defines the special literal suffix `_kj` to write a string literal whose type is implicitly `kj::StringPtr`. + +```c++ +// It's OK to initialize a StringPtr from a classic literal. +// No copy is performed; the StringPtr points directly at +// constant memory. +kj::StringPtr foo = "foo"; + +// But if you add the _kj suffix, then you don't even need +// to declare the type. `bar` will implicitly have type +// kj::StringPtr. Also, this version can be declared +// `constexpr`. +constexpr auto bar = "bar"_kj; +``` + +### Stringification + +To allocate and construct a `kj::String`, use `kj::str(args...)`. Each argument is stringified and the results are concatenated to form the final string. (You can also allocate an uninitialized string buffer with `kj::heapString(size)`.) + +```c++ +kj::String makeGreeting(kj::StringPtr name) { + return kj::str("Hello, ", name, "!"); +} +``` + +KJ knows how to stringify most primitive types as well as many KJ types automatically. Note that integers will be stringified in base 10; if you want hexadecimal, use `kj::hex(i)` as the parameter to `kj::str()`. + +You can additionally extend `kj::str()` to work with your own types by declaring a stringification method using `KJ_STRINGIFY`, like so: + +```c++ +enum MyType { A, B, C }; +kj::StringPtr KJ_STRINGIFY(MyType value) { + switch (value) { + case A: return "A"_kj; + case B: return "B"_kj; + case C: return "C"_kj; + } + KJ_UNREACHABLE; +} +``` + +The `KJ_STRINGIFY` declaration should appear either in the same namespace where the type is defined, or in the global scope. The function can return any random-access iterable sequence of `char`, such as a `kj::String`, `kj::StringPtr`, `kj::ArrayPtr`, etc. As an alternative to `KJ_STRINGIFY`, you can also declare a `toString()` method on your type, with the same return type semantics. + +When constructing very large, complex strings -- for example, when writing a code generator -- consider using `kj::StringTree`, which maintains a tree of strings and only concatenates them at the very end. For example, `kj::strTree(foo, kj::strTree(bar, baz)).flatten()` only performs one concatenation, whereas `kj::str(foo, kj::str(bar, baz))` would perform a redundant intermediate concatenation. + +## Core Utility Types + +### Maybes + +`kj::Maybe` is either `nullptr`, or contains a `T`. In KJ-based code, nullable values should always be expressed using `kj::Maybe`. Primitive pointers should never be null. Use `kj::Maybe` instead of `T*` to express that the pointer/reference can be null. + +In order to dereference a `kj::Maybe`, you must use the `KJ_IF_MAYBE` macro, which behaves like an `if` statement. + +```c++ +kj::Maybe maybeI = 123; +kj::Maybe maybeJ = nullptr; + +KJ_IF_MAYBE(i, maybeI) { + // This block will execute, with `i` being a + // pointer into `maybeI`'s value. In a better world, + // `i` would be a reference rather than a pointer, + // but we couldn't find a way to trick the compiler + // into that. + KJ_ASSERT(*i == 123); +} else { + KJ_FAIL_ASSERT("can't get here"); +} + +KJ_IF_MAYBE(j, maybeJ) { + KJ_FAIL_ASSERT("can't get here"); +} else { + // This block will execute. +} +``` + +Note that `KJ_IF_MAYBE` forces you to think about the null case. This differs from `std::optional`, which can be dereferenced using `*`, resulting in undefined behavior if the value is null. + +Similarly, `map()` and `orDefault()` allow transforming and retrieving the stored value in a safe manner without complex control flows. + +Performance nuts will be interested to know that `kj::Maybe` and `kj::Maybe>` are both optimized such that they take no more space than their underlying pointer type, using a literal null pointer to indicate nullness. For other types of `T`, `kj::Maybe` must maintain an extra boolean and so is somewhat larger than `T`. + +### Variant types + +`kj::OneOf` is a variant type that can be assigned to exactly one of the input types. To unpack the variant, use `KJ_SWITCH_ONEOF`: + +```c++ +void handle(kj::OneOf value) { + KJ_SWITCH_ONEOF(value) { + KJ_CASE_ONEOF(i, int) { + // Note that `i` is an lvalue reference to the content + // of the OneOf. This differs from `KJ_IF_MAYBE` where + // the variable is a pointer. + handleInt(i); + } + KJ_CASE_ONEOF(s, kj::String) { + handleString(s); + } + } +} +``` + +Often, in real-world usage, the type of each variant in a `kj::OneOf` is not sufficient to understand its meaning; sometimes two different variants end up having the same type used for different purposes. In these cases, it would be useful to assign a name to each variant. A common way to do this is to define a custom `struct` type for each variant, and then declare the `kj::OneOf` using those: + +```c++ +struct NotStarted { + kj::String filename; +}; +struct Running { + kj::Own file; +}; +struct Done { + kj::String result; +}; + +typedef kj::OneOf State; +``` + +### Functions + +`kj::Function` represents a callable function with the given signature. A `kj::Function` can be initialized from any callable object, such as a lambda, function pointer, or anything with `operator()`. `kj::Function` is useful when you want to write an API that accepts a lambda callback, without defining the API itself as a template. `kj::Function` supports move semantics. + +`kj::ConstFunction` is like `kj::Function`, but is used to indicate that the function should be safe to call from multiple threads. (See [KJ idioms around constness and thread-safety](../style-guide.md#constness).) + +A special optimization type, `kj::FunctionParam`, is like `kj::Function` but designed to be used specifically as the type of a callback parameter to some other function where that callback is only called synchronously; i.e., the callback won't be called anymore after the outer function returns. Unlike `kj::Function`, a `kj::FunctionParam` can be constructed entirely on the stack, with no heap allocation. + +### Vectors (growable arrays) + +Like `std::vector`, `kj::Vector` is an array that supports appending an element in amortized O(1) time. When the underlying backing array is full, an array of twice the size is allocated and all elements moved. + +### Hash/tree maps/sets... and tables + +`kj::HashMap`, `kj::HashSet`, `kj::TreeMap`, and `kj::TreeSet` do what you'd expect, with modern lambda-oriented interfaces that are less awkward than the corresponding STL types. + +All of these types are actually specific instances of the more-general `kj::Table`. A `kj::Table` can have any number of columns (whereas "sets" have exactly 1 and "maps" have exactly 2), and can maintain indexes on multiple columns at once. Each index can be hash-based, tree-based, or a custom index type that you provide. + +Unlike STL's, KJ's hashtable-based containers iterate in a well-defined deterministic order based on the order of insertion and removals. Deterministic behavior is important for reproducibility, which is important not just for debugging, but also in distributed systems where multiple systems must independently reproduce the same state. KJ's hashtable containers are also faster than `libstdc++`'s in benchmarks. + +KJ's tree-based containers use a b-tree design for better memory locality than the more traditional red-black trees. The implementation is tuned to avoid code bloat by keeping most logic out of templates, though this does make it slightly slower than `libstdc++`'s `map` and `set` in benchmarks. + +`kj::hashCode(params...)` computes a hash across all the inputs, appropriate for use in a hash table. It is extensible in a similar fashion to `kj::str()`, by using `KJ_HASHCODE` or defining a `.hashCode()` method on your custom types. `kj::Table`'s hashtable-based index uses `kj::hashCode` to compute hashes. + +## Debugging and Observability + +KJ believes that there is no such thing as bug-free code. Instead, we must expect that our code will go wrong, and try to extract as much information as possible when it does. To that end, KJ provides powerful assertion macros designed for observability. (Be sure also to read about [KJ's exception philosophy](../style-guide.md#exceptions); this section describes the actual APIs involved.) + +### Assertions + +Let's start with the basic assert: + +```c++ +KJ_ASSERT(foo == bar.baz, "the baz is not foo", bar.name, i); +``` + +When `foo == bar.baz` evaluates false, this line will throw an exception with a description like this: + +``` +src/file.c++:52: failed: expected foo == bar.baz [123 == 321]; the baz is not foo; bar.name = "banana"; i = 5 +stack: libqux.so@0x52134 libqux.so@0x16f582 bin/corge@0x12515 bin/corge@0x5552 +``` + +Notice all the information this contains: + +* The file and line number in the source code where the assertion macro was used. +* The condition which failed. +* The stringified values of the operands to the condition, i.e. `foo` and `bar.baz` (shown in `[]` brackets). +* The values of all other parameters passed to the assertion, i.e. `"the baz is not foo"`, `bar.name`, and `i`. For expressions that aren't just string literals, both the expression and the stringified result of evaluating it are shown. +* A numeric stack trace. If possible, the addresses will be given relative to their respective binary, so that ASLR doesn't make traces useless. The trace can be decoded with tools like `addr2line`. If possible, KJ will also shell out to `addr2line` itself to produce a human-readable trace. + +Note that the work of producing an error description happens only in the case that it's needed. If the condition evaluates true, then that is all the work that is done. + +`KJ_ASSERT` should be used in cases where you are checking conditions that, if they fail, represent a bug in the code where the assert appears. On the other hand, when checking for preconditions -- i.e., bugs in the _caller_ of the code -- use `KJ_REQUIRE` instead: + +```c++ +T& operator[](size_t i) { + KJ_REQUIRE(i < size(), "index out-of-bounds"); + // ... +} +``` + +`KJ_REQUIRE` and `KJ_ASSERT` do exactly the same thing; using one or the other is only a matter of self-documentation. + +`KJ_FAIL_ASSERT(...)` should be used instead of `KJ_ASSERT(false, ...)` when you want a branch always to fail. + +Assertions operate exactly the same in debug and release builds. To express a debug-only assertion, you can use `KJ_DASSERT`. However, we highly recommend letting asserts run in production, as they are frequently an invaluable tool for tracking down bugs that weren't covered in testing. + +### Logging + +The `KJ_LOG` macro can be used to log messages meant for the developer or operator without interrupting control flow. + +```c++ +if (foo.isWrong()) { + KJ_LOG(ERROR, "the foo is wrong", foo); +} +``` + +The first parameter is the log level, which can be `INFO`, `WARNING`, `ERROR`, or `FATAL`. By default, `INFO` logs are discarded, while other levels are displayed. For programs whose main function is based on `kj/main.h`, the `-v` flag can be used to enable `INFO` logging. A `FATAL` log should typically be followed by `abort()` or similar. + +Parameters other than the first are stringified in the same manner as with `KJ_ASSERT`. These parameters will not be evaluated at all, though, if the specified log level is not enabled. + +By default, logs go to standard error. However, you can implement a `kj::ExceptionCallback` (in `kj/exception.h`) to capture logs and customize how they are handled. + +### Debug printing + +Let's face it: "printf() debugging" is easy and effective. KJ embraces this with the `KJ_DBG()` macro. + +```c++ +KJ_DBG("hi", foo, bar, baz.qux) +``` + +`KJ_DBG(...)` is equivalent to `KJ_LOG(DEBUG, ...)` -- logging at the `DEBUG` level, which is always enabled. The dedicated macro exists for brevity when debugging. `KJ_DBG` is intended to be used strictly for temporary debugging code that should never be committed. We recommend setting up commit hooks to reject code that contains invocations of `KJ_DBG`. + +### System call error checking + +KJ includes special variants of its assertion macros that convert traditional C API error conventions into exceptions. + +```c++ +int fd; +KJ_SYSCALL(fd = open(filename, O_RDONLY), "couldn't open the document", filename); +``` + +This macro evaluates the first parameter, which is expected to be a system call. If it returns a negative value, indicating an error, then an exception is thrown. The exception description incorporates a description of the error code communicated by `errno`, as well as the other parameters passed to the macro (stringified in the same manner as other assertion/logging macros do). + +Additionally, `KJ_SYSCALL()` will automatically retry calls that fail with `EINTR`. Because of this, it is important that the expression is idempotent. + +Sometimes, you need to handle certain error codes without throwing. For those cases, use `KJ_SYSCALL_HANDLE_ERRORS`: + +```c++ +int fd; +KJ_SYSCALL_HANDLE_ERRORS(fd = open(filename, O_RDONLY)) { + case ENOENT: + // File didn't exist, return null. + return nullptr; + default: + // Some other error. The error code (from errno) is in a local variable `error`. + // `KJ_FAIL_SYSCALL` expects its second parameter to be this integer error code. + KJ_FAIL_SYSCALL("open()", error, "couldn't open the document", filename); +} +``` + +On Windows, two similar macros are available based on Windows API calling conventions: `KJ_WIN32` works with API functions that return a `BOOLEAN`, `HANDLE`, or pointer type. `KJ_WINSOCK` works with Winsock APIs that return negative values to indicate errors. Some Win32 APIs follow neither of these conventions, in which case you will have to write your own code to check for an error and use `KJ_FAIL_WIN32` to turn it into an exception. + +### Alternate exception types + +As described in [KJ's exception philosophy](../style-guide.md#exceptions), KJ supports a small set of exception types. Regular assertions throw `FAILED` exceptions. `KJ_SYSCALL` usually throws `FAILED`, but identifies certain error codes as `DISCONNECTED` or `OVERLOADED`. For example, `ECONNRESET` is clearly a `DISCONNECTED` exception. + +If you wish to manually construct and throw a different exception type, you may use `KJ_EXCEPTION`: + +```c++ +kj::Exception e = KJ_EXCEPTION(DISCONNECTED, "connection lost", addr); +``` + +### Throwing and catching exceptions + +KJ code usually should not use `throw` or `catch` directly, but rather use KJ's wrappers: + +```c++ +// Throw an exception. +kj::Exception e = ...; +kj::throwFatalException(kj::mv(e)); + +// Run some code catching exceptions. +kj::Maybe maybeException = kj::runCatchingExceptions([&]() { + doSomething(); +}); +KJ_IF_MAYBE(e, maybeException) { + // handle exception +} +``` + +These wrappers perform some extra bookkeeping: +* `kj::runCatchingExceptions()` will catch any kind of exception, whether it derives from `kj::Exception` or not, and will do its best to convert it into a `kj::Exception`. +* `kj::throwFatalException()` and `kj::throwRecoverableException()` invoke the thread's current `kj::ExceptionCallback` to throw the exception, allowing apps to customize how exceptions are handled. The default `ExceptionCallback` makes sure to throw the exception in such a way that it can be understood and caught by code looking for `std::exception`, such as the C++ library's standard termination handler. +* These helpers also work, to some extent, even when compiled with `-fno-exceptions` -- see below. (Note that "fatal" vs. "recoverable" exceptions are only different in this case; when exceptions are enabled, they are handled the same.) + +### Supporting `-fno-exceptions` + +KJ strongly recommends using C++ exceptions. However, exceptions are controversial, and many C++ applications are compiled with exceptions disabled. Some KJ-based libraries (especially Cap'n Proto) would like to accommodate such users. To that end, KJ's exception and assertion infrastructure is designed to degrade gracefully when compiled without exception support. In this case, exceptions are split into two types: + +* Fatal exceptions, when compiled with `-fno-exceptions`, will terminate the program when thrown. +* Recoverable exceptions, when compiled with `-fno-exceptions`, will be recorded on the side. Control flow then continues normally, possibly using a dummy value or skipping code which cannot execute. Later, the application can check if an exception has been raised and handle it. + +`KJ_ASSERT`s (and `KJ_REQUIRE`s) are fatal by default. To make them recoverable, add a "recovery block" after the assert: + +```c++ +kj::StringPtr getItem(int i) { + KJ_REQUIRE(i >= 0 && i < items.size()) { + // This is the recovery block. Recover by simply returning an empty string. + return ""; + } + return items[i]; +} +``` + +When the code above is compiled with exceptions enabled, an out-of-bounds index will result in an exception being thrown. But when compiled with `-fno-exceptions`, the function will store the exception off to the side (in KJ), and then return an empty string. + +A recovery block can indicate that control flow should continue normally even in case of error by using a `break` statement. + +```c++ +void incrementBy(int i) { + KJ_REQUIRE(i >= 0, "negative increments not allowed") { + // Pretend the caller passed `0` and continue. + i = 0; + break; + } + + value += i; +} +``` + +**WARNING:** The recovery block is executed even when exceptions are enabled. The exception is thrown upon exit from the block (even if a `return` or `break` statement is present). Therefore, be careful about side effects in the recovery block. Also, note that both GCC and Clang have a longstanding bug where a returned value's destructor is not called if the return is interrupted by an exception being thrown. Therefore, you must not return a value with a non-trivial destructor from a recovery block. + +There are two ways to handle recoverable exceptions: + +* Use `kj::runCatchingExceptions()`. When compiled with `-fno-exceptions`, this function will arrange for any recoverable exception to be stored off to the side. Upon completion of the given lambda, `kj::runCatchingExceptions()` will return the exception. +* Write a custom `kj::ExceptionCallback`, which can handle exceptions in any way you choose. + +Note that while most features of KJ work with `-fno-exceptions`, some of them have not been carefully written for this case, and may trigger fatal exceptions too easily. People relying on this mode will have to tread carefully. + +### Exceptions in Destructors + +Bugs can occur anywhere -- including in destructors. KJ encourages applications to detect bugs using assertions, which throw exceptions. As a result, exceptions can be thrown in destructors. There is no way around this. You cannot simply declare that destructors shall not have bugs. + +Because of this, KJ recommends that all destructors be declared with `noexcept(false)`, in order to negate C++11's unfortunate decision that destructors should be `noexcept` by default. + +However, this does not solve C++'s Most Unfortunate Decision, namely that throwing an exception from a destructor that was called during an unwind from another exception always terminates the program. It is very common for exceptions to cause "secondary" exceptions during unwind. For example, the destructor of a buffered stream might check whether the buffer has been flushed, and raise an exception if it has not, reasoning that this is a serious bug that could lead to data loss. But if the program is already unwinding due to some other exception, then it is likely that the failure to flush the buffer is because of that other exception. The "secondary" exception might as well be ignored. Terminating the program is the worst possible response. + +To work around the MUD, KJ offers two tools: + +First, during unwind from one exception, KJ will handle all "recoverable" exceptions as if compiled with `-fno-exceptions`, described in the previous section. So, whenever writing assertions in destructors, it is a good idea to give them a recovery block like `{break;}` or `{return;}`. + +```c++ +BufferedStream::~BufferedStream() noexcept(false) { + KJ_REQUIRE(buffer.size() == 0, "buffer was not flushed; possible data loss") { + // Don't throw if we're unwinding! + break; + } +} +``` + +Second, `kj::UnwindDetector` can be used to squelch exceptions during unwind. This is especially helpful in cases where your destructor needs to call complex external code that wasn't written with destructors in mind. Use it like so: + +```c++ +class Transaction { +public: + // ... + +private: + kj::UnwindDetector unwindDetector; + // ... +}; + +Transaction::~Transaction() noexcept(false) { + unwindDetector.catchExceptionsIfUnwinding([&]() { + if (!committed) { + rollback(); + } + }); +} +``` + +Core Systems +====================================================================== + +This section describes KJ APIs that control process execution and low-level interactions with the operating system. Most users of KJ will need to be familiar with most of this section. + +## Threads and Synchronization + +`kj::Thread` creates a thread in which the lambda passed to `kj::Thread`'s constructor will be executed. `kj::Thread`'s destructor waits for the thread to exit before continuing, and rethrows any exception that had been thrown from the thread's main function -- unless the thread's `.detach()` method has been called, in which case `kj::Thread`'s destructor does nothing. + +`kj::MutexGuarded` holds an instance of `T` that is protected by a mutex. In order to access the protected value, you must first create a lock. `.lockExclusive()` returns `kj::Locked` which can be used to access the underlying value. `.lockShared()` returns `kj::Locked`, [using constness to enforce thread-safe read-only access](../style-guide.md#constness) so that multiple threads can take the lock concurrently. In this way, KJ mutexes make it difficult to forget to take a lock before accessing the protected object. + +`kj::Locked` has a method `.wait(cond)` which temporarily releases the lock and waits, taking the lock back as soon as `cond(value)` evaluates true. This provides a much cleaner and more readable interface than traditional conditional variables. + +`kj::Lazy` is an instance of `T` that is constructed on first access in a thread-safe way. + +Macros `KJ_TRACK_LOCK_BLOCKING` and `KJ_SAVE_ACQUIRED_LOCK_INFO` can be used to enable support utilities to implement deadlock detection & analysis. +* `KJ_TRACK_LOCK_BLOCKING`: When the current thread is doing a blocking synchronous KJ operation, that operation is available via `kj::blockedReason()` (intention is for this to be invoked from the signal handler running on the thread that's doing the synchronous operation). +* `KJ_SAVE_ACQUIRED_LOCK_INFO`: When enabled, lock acquisition will save state about the location of the acquired lock. When combined with `KJ_TRACK_LOCK_BLOCKING` this can be particularly helpful because any watchdog can just forward the signal to the thread that's holding the lock. +## Asynchronous Event Loop + +### Promises + +KJ makes asynchronous programming manageable using an API modeled on E-style Promises. E-style Promises were also the inspiration for JavaScript Promises, so modern JavaScript programmers should find KJ Promises familiar, although there are some important differences. + +A `kj::Promise` represents an asynchronous background task that, upon completion, either "resolves" to a value of type `T`, or "rejects" with an exception. + +In the simplest case, a `kj::Promise` can be directly constructed from an instance of `T`: + +```c++ +int i = 123; +kj::Promise promise = i; +``` + +In this case, the promise is immediately resolved to the given value. + +A promise can also immediately reject with an exception: + +```c++ +kj::Exception e = KJ_EXCEPTION(FAILED, "problem"); +kj::Promise promise = kj::mv(e); +``` + +Of course, `Promise`s are much more interesting when they don't complete immediately. + +When a function returns a `Promise`, it means that the function performs some asynchronous operation that will complete in the future. These functions are always non-blocking -- they immediately return a `Promise`. The task completes asynchronously on the event loop. The eventual results of the promise can be obtained using `.then()` to register a callback, or, in certain situations, `.wait()` to synchronously wait. These are described in more detail below. + +### Basic event loop setup + +In order to execute `Promise`-based code, the thread must be running an event loop. Typically, at the top level of the thread, you would do something like: + +```c++ +kj::AsyncIoContext io = kj::setupAsyncIo(); + +kj::AsyncIoProvider& ioProvider = *io.provider; +kj::LowLevelAsyncIoProvider& lowLevelProvider = *io.lowLevelProvider; +kj::WaitScope& waitScope = io.waitScope; +``` + +`kj::setupAsyncIo()` constructs and returns a bunch of objects: + +* A `kj::AsyncIoProvider`, which provides access to a variety of I/O APIs, like timers, pipes, and networking. +* A `kj::LowLevelAsyncIoProvider`, which allows you to wrap existing low-level operating system handles (Unix file descriptors, or Windows `HANDLE`s) in KJ asynchronous interfaces. +* A `kj::WaitScope`, which allows you to perform synchronous waits (see next section). +* OS-specific interfaces for even lower-level access -- see the API definition for more details. + +In order to implement all this, KJ will set up the appropriate OS-specific constructs to handle I/O events on the host platform. For example, on Linux, KJ will use `epoll`, whereas on Windows, it will set up an I/O Completion Port. + +Sometimes, you may need KJ promises to cooperate with some existing event loop, rather than set up its own. For example, you might be using libuv, or Boost.Asio. Usually, a thread can only have one event loop, because it can only wait on one OS event queue (e.g. `epoll`) at a time. To accommodate this, it is possible (though not easy) to adapt KJ to run on top of some other event loop, by creating a custom implementation of `kj::EventPort`. The details of how to do this are beyond the scope of this document. + +Sometimes, you may find that you don't really need to perform operating system I/O at all. For example, a unit test might only need to call some asynchronous functions using mock I/O interfaces, or a thread in a multi-threaded program may only need to exchange events with other threads and not the OS. In these cases, you can create a simple event loop instead: + +```c++ +kj::EventLoop eventLoop; +kj::WaitScope waitScope(eventLoop); +``` + +### Synchronous waits + +In the top level of your program (or thread), the program is allowed to synchronously wait on a promise using the `kj::WaitScope` (see above). + +``` +kj::Timer& timer = io.provider->getTimer(); +kj::Promise promise = timer.afterDelay(5 * kj::SECONDS); +promise.wait(waitScope); // returns after 5 seconds' delay +``` + +`promise.wait()` will run the thread's event loop until the promise completes. It will then return the `Promise`'s result (or throw the `Promise`'s exception). `.wait()` consumes the `Promise`, as if the `Promise` has been moved away. + +Synchronous waits cannot be nested -- i.e. a `.then()` callback (see below) that is called by the event loop itself cannot execute another level of synchronous waits. Hence, synchronous waits generally can only be used at the top level of the thread. The API requires passing a `kj::WaitScope` to `.wait()` as a way to demonstrate statically that the caller is allowed to perform synchronous waits. Any function which wishes to perform synchronous waits must take a `kj::WaitScope&` as a parameter to indicate that it does this. + +Synchronous waits often make sense to use in "client" programs that only have one task to complete before they exit. On the other end of the spectrum, server programs that handle many clients generally must do everything asynchronously. At the top level of a server program, you will typically instruct the event loop to run forever, like so: + +```c++ +// Run event loop forever, do everything asynchronously. +kj::NEVER_DONE.wait(waitScope); +``` + +Libraries should always be asynchronous, so that either kind of program can use them. + +### Asynchronous callbacks + +Similar to JavaScript promises, you may register a callback to call upon completion of a KJ promise using `.then()`: + +```c++ +kj::Promise textPromise = stream.readAllText(); +kj::Promise lineCountPromise = textPromise + .then([](kj::String text) { + int lineCount = 0; + for (char c: text) { + if (c == '\n') { + ++lineCount; + } + } + return lineCount; +}); +``` + +`promise.then()` takes, as its argument, a lambda which transforms the result of the `Promise`. It returns a new `Promise` for the transformed result. We call this lambda a "continuation". + +Calling `.then()`, like `.wait()`, consumes the original promise, as if it were "moved away". Ownership of the original promise is transferred into the new, derived promise. If you want to register multiple continuations on the same promise, you must fork it first (see below). + +If the continuation itself returns another `Promise`, then the `Promise`s become chained. That is, the final type is reduced from `Promise>` to just `Promise`. + +```c++ +kj::Promise> connectPromise = + networkAddress.connect(); +kj::Promise textPromise = connectPromise + .then([](kj::Own stream) { + return stream->readAllText().attach(kj::mv(stream)); +}); +``` + +If a promise rejects (throws an exception), then the exception propagates through `.then()` to the new derived promise, without calling the continuation. If you'd like to actually handle the exception, you may pass a second lambda as the second argument to `.then()`. + +```c++ +kj::Promise promise = networkAddress.connect() + .then([](kj::Own stream) { + return stream->readAllText().attach(kj::mv(stream)); +}, [](kj::Exception&& exception) { + return kj::str("connection error: ", exception); +}); +``` + +You can also use `.catch_(errorHandler)`, which is a shortcut for `.then(identityFunction, errorHandler)`. + +### `kj::evalNow()`, `kj::evalLater()`, and `kj::evalLast()` + +These three functions take a lambda as the parameter, and return the result of evaluating the lambda. They differ in when, exactly, the execution happens. + +```c++ +kj::Promise promise = kj::evalLater([]() { + int i = doSomething(); + return i; +}); +``` + +As with `.then()` continuations, the lambda passed to these functions may itself return a `Promise`. + +`kj::evalNow()` executes the lambda immediately -- before `evalNow()` even returns. The purpose of `evalNow()` is to catch any exceptions thrown and turn them into a rejected promise. This is often a good idea when you don't want the caller to have to handle both synchronous and asynchronous exceptions -- wrapping your whole function in `kj::evalNow()` ensures that all exceptions are delivered asynchronously. + +`kj::evalLater()` executes the lambda on a future turn of the event loop. This is equivalent to `kj::Promise().then()`. + +`kj::evalLast()` arranges for the lambda to be called only after all other work queued to the event loop has completed (but before querying the OS for new I/O events). This can often be useful e.g. for batching. For example, if a program tends to make many small write()s to a socket in rapid succession, you might want to add a layer that collects the writes into a batch, then sends the whole batch in a single write from an `evalLast()`. This way, none of the bytes are significantly delayed, but they can still be coalesced. + +If multiple `evalLast()`s exist at the same time, they will execute in last-in-first-out order. If the first one out schedules more work on the event loop, that work will be completed before the next `evalLast()` executes, and so on. + +### Attachments + +Often, a task represented by a `Promise` will require that some object remains alive until the `Promise` completes. In particular, under KJ conventions, unless documented otherwise, any class method which returns a `Promise` inherently expects that the caller will ensure that the object it was called on will remain alive until the `Promise` completes (or is canceled). Put another way, member function implementations may assume their `this` pointer is valid as long as their returned `Promise` is alive. + +You may use `promise.attach(kj::mv(object))` to give a `Promise` direct ownership of an object that must be kept alive until the promise completes. `.attach()`, like `.then()`, consumes the promise and returns a new one of the same type. + +```c++ +kj::Promise> connectPromise = + networkAddress.connect(); +kj::Promise textPromise = connectPromise + .then([](kj::Own stream) { + // We must attach the stream so that it remains alive until `readAllText()` + // is done. The stream will then be discarded. + return stream->readAllText().attach(kj::mv(stream)); +}); +``` + +Using `.attach()` is semantically equivalent to using `.then()`, passing an identity function as the continuation, while having that function capture ownership of the attached object, i.e.: + +```c++ +// This... +promise.attach(kj::mv(attachment)); +// ...is equivalent to this... +promise.then([a = kj::mv(attachment)](auto x) { return kj::mv(x); }); +``` + +Note that you can use `.attach()` together with `kj::defer()` to construct a "finally" block -- code which will execute after the promise completes (or is canceled). + +```c++ +promise = promise.attach(kj::defer([]() { + // This code will execute when the promise completes or is canceled. +})); +``` + +### Background tasks + +If you construct a `Promise` and then just leave it be without calling `.then()` or `.wait()` to consume it, the task it represents will nevertheless execute when the event loop runs, "in the background". You can call `.then()` or `.wait()` later on, when you're ready. This makes it possible to run multiple concurrent tasks at once. + +Note that, when possible, KJ evaluates continuations lazily. Continuations which merely transform the result (without returning a new `Promise` that might require more waiting) are only evaluated when the final result is actually needed. This is an optimization which allows a long chain of `.then()`s to be executed all at once, rather than turning the event loop for each one. However, it can lead to some confusion when storing an unconsumed `Promise`. For example: + +```c++ +kj::Promise promise = timer.afterDelay(5 * kj::SECONDS) + .then([]() { + // This log line will never be written, because nothing + // is waiting on the final result of the promise. + KJ_LOG(WARNING, "It has been 5 seconds!!!"); +}); +kj::NEVER_DONE.wait(waitScope); +``` + +To solve this, use `.eagerlyEvaluate()`: + +```c++ +kj::Promise promise = timer.afterDelay(5 * kj::SECONDS) + .then([]() { + // This log will correctly be written after 5 seconds. + KJ_LOG(WARNING, "It has been 5 seconds!!!"); +}).eagerlyEvaluate([](kj::Exception&& exception) { + KJ_LOG(ERROR, exception); +}); +kj::NEVER_DONE.wait(waitScope); +``` + +`.eagerlyEvaluate()` takes an error handler callback as its parameter, with the same semantics as `.catch_()` or the second parameter to `.then()`. This is required because otherwise, it is very easy to forget to install an error handler on background tasks, resulting in errors being silently discarded. However, if you are certain that errors will be properly handled elsewhere, you may pass `nullptr` as the parameter to skip error checking -- this is equivalent to passing a callback that merely re-throws the exception. + +If you have lots of background tasks, use `kj::TaskSet` to manage them. Any promise added to a `kj::TaskSet` will be run to completion (with eager evaluation), with any exceptions being reported to a provided error handler callback. + +### Cancellation + +If you destroy a `Promise` before it has completed, any incomplete work will be immediately canceled. + +Upon cancellation, no further continuations are executed at all, not even error handlers. Only destructors are executed. Hence, when there is cleanup that must be performed after a task, it is not sufficient to use `.then()` to perform the cleanup in continuations. You must instead use `.attach()` to attach an object whose destructor performs the cleanup (or perhaps `.attach(kj::defer(...))`, as mentioned earlier). + +Promise cancellation has proven to be an extremely useful feature of KJ promises which is missing in other async frameworks, such as JavaScript's. However, it places new responsibility on the developer. Just as developers who allow exceptions must design their code to be "exception safe", developers using KJ promises must design their code to be "cancellation safe". + +It is especially important to note that once a promise has been canceled, then any references that were received along with the promise may no longer be valid. For example, consider this function: + +``` +kj::Promise write(kj::ArrayPtr data); +``` + +The function receives a pointer to some data owned elsewhere. By KJ convention, the caller must ensure this pointer remains valid until the promise completes _or is canceled_. If the caller decides it needs to free the data early, it may do so as long as it cancels the promise first. This property is important as otherwise it becomes impossible to reason about ownership in complex systems. + +This means that the implementation of `write()` must immediately stop using `data` as soon as cancellation occurs. For example, if `data` has been placed in some sort of queue where some other concurrent task takes items from the queue to write them, then it must be ensured that `data` will be removed from that queue upon cancellation. This "queued writes" pattern has historically been a frequent source of bugs in KJ code, to the point where experienced KJ developers now become immediately suspicious of such queuing. The `kj::AsyncOutputStream` interface explicitly prohibits overlapping calls to `write()` specifically so that the implementation need not worry about maintaining queues. + +### Promise-Fulfiller Pairs and Adapted Promises + +Sometimes, it's difficult to express asynchronous control flow as a simple chain of continuations. For example, imagine a producer-consumer queue, where producers and consumers are executing concurrently on the same event loop. The consumer doesn't directly call the producer, nor vice versa, but the consumer would like to wait for the producer to produce an item for consumption. + +For these situations, you may use a `Promise`-`Fulfiller` pair. + +```c++ +kj::PromiseFulfillerPair paf = kj::newPromiseAndFulfiller(); + +// Consumer waits for the promise. +paf.promise.then([](int i) { ... }); + +// Producer calls the fulfiller to fulfill the promise. +paf.fulfiller->fulfill(123); + +// Producer can also reject the promise. +paf.fulfiller->reject(KJ_EXCEPTION(FAILED, "something went wrong")); +``` + +**WARNING! DANGER!** When using promise-fulfiller pairs, it is very easy to forget about both exception propagation and, more importantly, cancellation-safety. + +* **Exception-safety:** If your code stops early due to an exception, it may forget to invoke the fulfiller. Upon destroying the fulfiller, the consumer end will receive a generic, unhelpful exception, merely saying that the fulfiller was destroyed unfulfilled. To aid in debugging, you should make sure to catch exceptions and call `fulfiller->reject()` to propagate them. +* **Cancellation-safety:** Either the producer or the consumer task could be canceled, and you must consider how this affects the other end. + * **Canceled consumer:** If the consumer is canceled, the producer may waste time producing an item that no one is waiting for. Or, worse, if the consumer has provided references to the producer (for example, a buffer into which results should be written), those references may become invalid upon cancellation, but the producer will continue executing, possibly resulting in a use-after-free. To avoid these problems, the producer can call `fulfiller->isWaiting()` to check if the consumer is still waiting -- this method returns false if either the consumer has been canceled, or if the producer has already fulfilled or rejected the promise previously. However, `isWaiting()` requires polling, which is not ideal. For better control, consider using an adapted promise (see below). + * **Canceled producer:** If the producer is canceled, by default it will probably destroy the fulfiller without fulfilling or reject it. As described previously, the consumer will receive a non-descript exception, which is likely unhelpful for debugging. To avoid this scenario, the producer could perhaps use `.attach(kj::defer(...))` with a lambda that checks `fulfiller->isWaiting()` and rejects it if not. + +Because of the complexity of the above issues, it is generally recommended that you **avoid promise-fulfiller pairs** except in cases where these issues very clearly don't matter (such as unit tests). + +Instead, when cancellation concerns matter, consider using "adapted promises", a more sophisticated alternative. `kj::newAdaptedPromise()` constructs an instance of the class `Adapter` (which you define) encapsulated in a returned `Promise`. `Adapter`'s constructor receives a `kj::PromiseFulfiller&` used to fulfill the promise. The constructor should then register the fulfiller with the desired producer. If the promise is canceled, `Adapter`'s destructor will be invoked, and should un-register the fulfiller. One common technique is for `Adapter` implementations to form a linked list with other `Adapter`s waiting for the same producer. Adapted promises make consumer cancellation much more explicit and easy to handle, at the expense of requiring more code. + +### Loops + +Promises, due to their construction, don't lend themselves easily to classic `for()`/`while()` loops. Instead, loops should be expressed recursively, as in a functional language. For example: + +```c++ +kj::Promise boopEvery5Seconds(kj::Timer& timer) { + return timer.afterDelay(5 * kj::SECONDS).then([&timer]() { + boop(); + // Loop by recursing. + return boopEvery5Seconds(timer); + }); +} +``` + +KJ promises include "tail call optimization" for loops like the one above, so that the promise chain length remains finite no matter how many times the loop iterates. + +**WARNING!** It is very easy to accidentally break tail call optimization, creating a memory leak. Consider the following: + +```c++ +kj::Promise boopEvery5Seconds(kj::Timer& timer) { + // WARNING! MEMORY LEAK! + return timer.afterDelay(5 * kj::SECONDS).then([&timer]() { + boop(); + // Loop by recursing. + return boopEvery5Seconds(timer); + }).catch_([](kj::Exception&& exception) { + // Oh no, an error! Log it and end the loop. + KJ_LOG(ERROR, exception); + kj::throwFatalException(kj::mv(exception)); + }); +} +``` + +The problem in this example is that the recursive call is _not_ a tail call, due to the `.catch_()` appended to the end. Every time around the loop, a new `.catch_()` is added to the promise chain. If an exception were thrown, that exception would end up being logged many times -- once for each time the loop has repeated so far. Or if the loop iterated enough times, and the top promise was then canceled, the chain could be so long that the destructors overflow the stack. + +In this case, the best fix is to pull the `.catch_()` out of the loop entirely: + +```c++ +kj::Promise boopEvery5Seconds(kj::Timer& timer) { + return boopEvery5SecondsLoop(timer) + .catch_([](kj::Exception&& exception) { + // Oh no, an error! Log it and end the loop. + KJ_LOG(ERROR, exception); + kj::throwFatalException(kj::mv(exception)); + }) +} + +kj::Promise boopEvery5SecondsLoop(kj::Timer& timer) { + // No memory leaks now! + return timer.afterDelay(5 * kj::SECONDS).then([&timer]() { + boop(); + // Loop by recursing. + return boopEvery5SecondsLoop(timer); + }); +} +``` + +Another possible fix would be to make sure the recursive continuation and the error handler are passed to the same `.then()` invocation: + +```c++ +kj::Promise boopEvery5Seconds(kj::Timer& timer) { + // No more memory leaks, but hard to reason about. + return timer.afterDelay(5 * kj::SECONDS).then([&timer]() { + boop(); + }).then([&timer]() { + // Loop by recursing. + return boopEvery5Seconds(timer); + }, [](kj::Exception&& exception) { + // Oh no, an error! Log it and end the loop. + KJ_LOG(ERROR, exception); + kj::throwFatalException(kj::mv(exception)); + }); +} +``` + +Notice that in this second case, the error handler is scoped so that it does _not_ catch exceptions thrown by the recursive call; it only catches exceptions from `boop()`. This solves the problem, but it's a bit trickier to understand and to ensure that exceptions can't accidentally slip past the error handler. + +### Forking and splitting promises + +As mentioned above, `.then()` and similar functions consume the promise on which they are called, so they can only be called once. But what if you want to start multiple tasks using the result of a promise? You could solve this in a convoluted way using adapted promises, but KJ has a built-in solution: `.fork()` + +```c++ +kj::Promise promise = ...; +kj::ForkedPromise forked = promise.fork(); +kj::Promise branch1 = promise.addBranch(); +kj::Promise branch2 = promise.addBranch(); +kj::Promise branch3 = promise.addBranch(); +``` + +A forked promise can have any number of "branches" which represent different consumers waiting for the same result. + +Forked promises use reference counting. The `ForkedPromise` itself, and each branch created from it, each represent a reference to the original promise. The original promise will only be canceled if all branches are canceled and the `ForkedPromise` itself is destroyed. + +Forked promises require that the result type has a copy constructor, so that it can be copied to each branch. (Regular promises only require the result type to be movable, not copyable.) Or, alternatively, if the result type is `kj::Own` -- which is never copyable -- then `T` must have a method `kj::Own T::addRef()`; this method will be invoked to create each branch. Typically, `addRef()` would be implemented using reference counting. + +Sometimes, the copyable requirement of `.fork()` can be burdensome and unnecessary. If the result type has multiple components, and each branch really only needs one of the components, then being able to copy (or refcount) is unnecessary. In these cases, you can use `.split()` instead. `.split()` converts a promise for a `kj::Tuple` into a `kj::Tuple` of promises. That is: + +```c++ +kj::Promise, kj::String>> promise = ...; +kj::Tuple>, kj::Promise> promises = promise.split(); +``` + +### Joining promises + +The opposite of forking promises is joining promises. There are two types of joins: +* **Exclusive** joins wait for any one input promise to complete, then cancel the rest, returning the result of the promise that completed. +* **Inclusive** joins wait for all input promises to complete, and render all of the results. + +For an exclusive join, use `promise.exclusiveJoin(kj::mv(otherPromise))`. The two promises must return the same type. The result is a promise that returns whichever result is produced first, and cancels the other promise at that time. (To exclusively join more than two promises, call `.exclusiveJoin()` multiple times in a chain.) + +To perform an inclusive join, use `kj::joinPromises()` or `kj::joinPromisesFailFast()`. These turn a `kj::Array>` into a `kj::Promise>`. However, note that `kj::joinPromises()` has a couple common gotchas: +* Trailing continuations on the promises passed to `kj::joinPromises()` are evaluated lazily after all the promises become ready. Use `.eagerlyEvaluate()` on each one to force trailing continuations to happen eagerly. (See earlier discussion under "Background Tasks".) +* If any promise in the array rejects, the exception will be held until all other promises have completed (or rejected), and only then will the exception propagate. In practice we've found that most uses of `kj::joinPromises()` would prefer "exclusive" or "fail-fast" behavior in the case of an exception. + +`kj::joinPromisesFailFast()` addresses the gotchas described above: promise continuations are evaluated eagerly, and if any promise results in an exception, the join promise is immediately rejected with that exception. + +### Threads + +The KJ async framework is designed around single-threaded event loops. However, you can have multiple threads, with each running its own loop. + +All KJ async objects, unless specifically documented otherwise, are intrinsically tied to the thread and event loop on which they were created. These objects must not be accessed from any other thread. + +To communicate between threads, you may use `kj::Executor`. Each thread (that has an event loop) may call `kj::getCurrentThreadExecutor()` to get a reference to its own `Executor`. That reference may then be shared with other threads. The other threads can use the methods of `Executor` to queue functions to execute on the owning thread's event loop. + +The threads which call an `Executor` do not have to have KJ event loops themselves. Thus, you can use an `Executor` to signal a KJ event loop thread from a non-KJ thread. + +### Fibers + +Fibers allow code to be written in a synchronous / blocking style while running inside the KJ event loop, by executing the code on an alternate call stack. The code running on this alternate stack is given a special `kj::WaitScope&`, which it can pass to `promise.wait()` to perform synchronous waits. When such a `.wait()` is invoked, the thread switches back to the main call stack and continues running the event loop there. When the waited promise resolves, execution switches back to the alternate call stack and `.wait()` returns (or throws). + +```c++ +constexpr size_t STACK_SIZE = 65536; +kj::Promise promise = + kj::startFiber(STACK_SIZE, [](kj::WaitScope& waitScope) { + int i = someAsyncFunc().wait(waitScope); + i += anotherAsyncFunc().wait(waitScope); + return i; +}); +``` + +**CAUTION:** Fibers produce attractive-looking code, but have serious drawbacks. Every fiber must allocate a new call stack, which is typically rather large. The above example allocates a 64kb stack, which is the _minimum_ supported size. Some programs and libraries expect to be able to allocate megabytes of data on the stack. On modern Linux systems, a default stack size of 8MB is typical. Stack space is allocated lazily on page faults, but just setting up the memory mapping is much more expensive than a typical `malloc()`. If you create lots of fibers, you should use `kj::FiberPool` to reduce allocation costs -- but while this reduces allocation overhead, it will increase memory usage. + +Because of this, fibers should not be used just to make code look nice (C++20's `co_await`, described below, is a better way to do that). Instead, the main use case for fibers is to be able to call into existing libraries that are not designed to operate in an asynchronous way. For example, say you find a library that performs stream I/O, and lets you provide your own `read()`/`write()` implementations, but expects those implementations to operate in a blocking fashion. With fibers, you can use such a library within the asynchronous KJ event loop. + +### Coroutines + +C++20 brings us coroutines, which, like fibers, allow code to be written in a synchronous / blocking style while running inside the KJ event loop. Coroutines accomplish this with a different strategy than fibers: instead of running code on an alternate stack and switching stacks on suspension, coroutines save local variables and temporary objects in a heap-allocated "coroutine frame" and always unwind the stack on suspension. + +A C++ function is a KJ coroutine if it follows these two rules: +- The function returns a `kj::Promise`. +- The function uses a `co_await` or `co_return` keyword in its implementation. + +```c++ +kj::Promise aCoroutine() { + int i = co_await someAsyncFunc(); + i += co_await anotherAsyncFunc(); + co_return i; +}); + +// Call like any regular promise-returning function. +auto promise = aCoroutine(); +``` + +The promise returned by a coroutine owns the coroutine frame. If you destroy the promise, any objects alive in the frame will be destroyed, and the frame freed, thus cancellation works exactly as you'd expect. + +There are some caveats one should be aware of while writing coroutines: +- Lambda captures **do not** live inside of the coroutine frame, meaning lambda objects must outlive any coroutine Promises they return, or else the coroutine will encounter dangling references to captured objects. This is a defect in the C++ standard: https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rcoro-capture. To safely use a capturing lambda as a coroutine, first wrap it using `kj::coCapture([captures]() { ... })`, then invoke that object. +- Holding a mutex lock across a `co_await` is almost always a bad idea, with essentially the same problems as holding a lock while calling `promise.wait(waitScope)`. This would cause the coroutine to hold the lock for however many turns of the event loop is required to drive the coroutine to release the lock; if I/O is involved, this could cause significant problems. Additionally, a reentrant call to the coroutine on the same thread would deadlock. Instead, if a coroutine must temporarily hold a lock, always keep the lock in a new lexical scope without any `co_await`. +- Attempting to define (and use) a variable-length array will cause a compile error, because the size of coroutine frames must be knowable at compile-time. The error message that clang emits for this, "Coroutines cannot handle non static allocas yet", suggests this may be relaxed in the future. + +As of this writing, KJ supports C++20 coroutines and Coroutines TS coroutines, the latter being an experimental precursor to C++20 coroutines. They are functionally the same thing, but enabled with different compiler/linker flags: + +- Enable C++20 coroutines by requesting that language standard from your compiler. +- Enable Coroutines TS coroutines with `-fcoroutines-ts` in C++17 clang, and `/await` in MSVC. + +KJ prefers C++20 coroutines when both implementations are available. + +### Unit testing tips + +When unit-testing promise APIs, two tricky challenges frequently arise: + +* Testing that a promise has completed when it is supposed to. You can use `promise.wait()`, but if the promise has not completed as expected, then the test may simply hang. This can be frustrating to debug. +* Testing that a promise has not completed prematurely. You obviously can't use `promise.wait()`, because you _expect_ the promise has not completed, and therefore this would hang. You might try using `.then()` with a continuation that sets a flag, but if the flag is not set, it's hard to tell whether this is because the promise really has not completed, or merely because the event loop hasn't yet called the `.then()` continuation. + +To solve these problems, you can use `promise.poll(waitScope)`. This function runs the event loop until either the promise completes, or there is nothing left to do except to wait. This includes running any continuations in the queue as well as checking for I/O events from the operating system, repeatedly, until nothing is left. The only thing `.poll()` will not do is block. `.poll()` returns true if the promise has completed, false if it hasn't. + +```c++ +// In a unit test... +kj::Promise promise = waitForBoop(); + +// The promise should not be done yet because we haven't booped yet. +KJ_ASSERT(!promise.poll(waitScope)); + +boop(); + +// Assert the promise is done, to make sure wait() won't hang! +KJ_ASSERT(promise.poll(waitScope)); + +promise.wait(waitScope); +``` + +Sometimes, you may need to ensure that some promise has completed that you don't have a reference to, so you can observe that some side effect has occurred. You can use `waitScope.poll()` to flush the event loop without waiting for a specific promise to complete. + +## System I/O + +### Async I/O + +On top of KJ's async framework (described earlier), KJ provides asynchronous APIs for byte streams, networking, and timers. + +As mentioned previously, `kj::setupAsyncIo()` allocates an appropriate OS-specific event queue (such as `epoll` on Linux), returning implementations of `kj::AsyncIoProvider` and `kj::LowLevelAsyncIoProvider` implemented in terms of that queue. `kj::AsyncIoProvider` provides an OS-independent API for byte streams, networking, and timers. `kj::LowLevelAsyncIoProvider` allows native OS handles (file descriptors on Unix, `HANDLE`s on Windows) to be wrapped in KJ byte stream APIs, like `kj::AsyncIoStream`. + +Please refer to the API reference (the header files) for details on these APIs. + +### Synchronous I/O + +Although most complex KJ applications use async I/O, sometimes you want something a little simpler. + +`kj/io.h` provides some more basic, synchronous streaming interfaces, like `kj::InputStream` and `kj::OutputStream`. Implementations are provided on top of file descriptors and Windows `HANDLE`s. + +Additionally, the important utility class `kj::AutoCloseFd` (and `kj::AutoCloseHandle` for Windows) can be found here. This is an RAII wrapper around a file descriptor (or `HANDLE`), which you will likely want to use any time you are manipulating raw file descriptors (or `HANDLE`s) in KJ code. + +### Filesystem + +KJ provides an advanced, cross-platform filesystem API in `kj/filesystem.h`. Features include: + +* Paths represented using `kj::Path`. In addition to providing common-sense path parsing and manipulation functions, this class is designed to defend against path injection attacks. +* All interfaces are abstract, allowing multiple implementations. +* An in-memory implementation is provided, useful in particular for mocking the filesystem in unit tests. +* On Unix, disk `kj::Directory` objects are backed by open file descriptors and use the `openat()` family of system calls. +* Makes it easy to use atomic replacement when writing new files -- and even whole directories. +* Symlinks, hard links, listing directories, recursive delete, recursive create parents, recursive copy directory, memory mapping, and unnamed temporary files are all exposed and easy to use. +* Sparse files ("hole punching"), copy-on-write file cloning (`FICLONE`, `FICLONERANGE`), `sendfile()`-based copying, `renameat2()` atomic replacements, and more will automatically be used when available. + +See the API reference (header file) for details. + +### Clocks and time + +KJ provides a time library in `kj/time.h` which uses the type system to enforce unit safety. + +`kj::Duration` represents a length of time, such as a number of seconds. Multiply an integer by `kj::SECONDS`, `kj::MINUTES`, `kj::NANOSECONDS`, etc. to get a `kj::Duration` value. Divide by the appropriate constant to get an integer. + +`kj::Date` represents a point in time in the real world. `kj::UNIX_EPOCH` represents January 1st, 1970, 00:00 UTC. Other dates can be constructed by adding a `kj::Duration` to `kj::UNIX_EPOCH`. Taking the difference between to `kj::Date`s produces a `kj::Duration`. + +`kj::TimePoint` represents a time point measured against an unspecified origin time. This is typically used with monotonic clocks that don't necessarily reflect calendar time. Unlike `kj::Date`, there is no implicit guarantee that two `kj::TimePoint`s are measured against the same origin and are therefore comparable; it is up to the application to track which clock any particular `kj::TimePoint` came from. + +`kj::Clock` is a simple interface whose `now()` method returns the current `kj::Date`. `kj::MonotonicClock` is a similar interface returning a `kj::TimePoint`, but with the guarantee that times returned always increase (whereas a `kj::Clock` might go "back in time" if the user manually modifies their system clock). + +`kj::systemCoarseCalendarClock()`, `kj::systemPreciseCalendarClock()`, `kj::systemCoarseMonotonicClock()`, `kj::systemPreciseMonotonicClock()` are global functions that return implementations of `kj::Clock` or `kJ::MonotonicClock` based on system time. + +`kj::Timer` provides an async (promise-based) interface to wait for a specified time to pass. A `kj::Timer` is provided via `kj::AsyncIoProvider`, constructed using `kj::setupAsyncIo()` (see earlier discussion on async I/O). + +## Program Harness + +TODO: kj::Main, unit test framework + +Libraries +====================================================================== + +TODO: parser combinator framework, HTTP, TLS, URL, encoding, JSON diff --git a/release.sh b/release.sh index 79ec998534..4225682956 100755 --- a/release.sh +++ b/release.sh @@ -2,7 +2,7 @@ set -euo pipefail -if [ "$1" != "package" ]; then +if [ "$1" != "package" ] && [ "$1" != "bump-major" ]; then if (grep -r KJ_DBG c++/src | egrep -v '/debug(-test)?[.]' | grep -v 'See KJ_DBG\.$'); then echo '*** Error: There are instances of KJ_DBG in the code.' >&2 exit 1 @@ -20,7 +20,7 @@ doit() { } get_version() { - local VERSION=$(grep AC_INIT c++/configure.ac | sed -e 's/^[^]]*],\[\([^]]*\)].*$/\1/g') + local VERSION=$(grep '^AC_INIT' c++/configure.ac | sed -e 's/^[^]]*],\[\([^]]*\)].*$/\1/g') if [[ ! "$VERSION" =~ $1 ]]; then echo "Couldn't parse version: $VERSION" >&2 exit 1 @@ -50,7 +50,7 @@ update_version() { c++/src/capnp/common.h local NEW_COMBINED=$(( ${NEW_ARR[0]} * 1000000 + ${NEW_ARR[1]} * 1000 + ${NEW_ARR[2]:-0 })) - doit sed -i -re "s/^#if CAPNP_VERSION != [0-9]*\$/#if CAPNP_VERSION != $NEW_COMBINED/g" \ + doit sed -i -re "s/^#elif CAPNP_VERSION != [0-9]*\$/#elif CAPNP_VERSION != $NEW_COMBINED/g" \ c++/src/*/*.capnp.h c++/src/*/*/*.capnp.h doit git commit -a -m "Set $BRANCH_DESC version to $NEW." @@ -120,7 +120,7 @@ done_banner() { y | Y ) doit git push origin $PUSH doit gce-ss copy-files capnproto-c++-$VERSION.tar.gz capnproto-c++-win32-$VERSION.zip \ - fe:/var/www/capnproto.org + alpha2:/var/www/capnproto.org if [ "$FINAL" = yes ]; then echo "=========================================================================" @@ -146,6 +146,14 @@ done_banner() { BRANCH=$(git rev-parse --abbrev-ref HEAD) case "${1-}:$BRANCH" in + bump-major:* ) + echo "Bump major version number on HEAD." + HEAD_VERSION=$(get_version '^[0-9]+[.][0-9]+-dev$') + OLD_MAJOR=$(echo $HEAD_VERSION | cut -d. -f1) + NEW_VERSION=$(( OLD_MAJOR + 1 )).0-dev + update_version $HEAD_VERSION $NEW_VERSION "mainline" + ;; + # ====================================================================================== candidate:master ) echo "New major release." @@ -177,7 +185,7 @@ case "${1-}:$BRANCH" in declare -a VERSION_ARR=(${RELEASE_VERSION//./ }) NEXT_VERSION=${VERSION_ARR[0]}.$((VERSION_ARR[1] + 1)) - update_version $HEAD_VERSION $NEXT_VERSION-dev "mainlaine" + update_version $HEAD_VERSION $NEXT_VERSION-dev "mainline" done_banner $RELEASE_VERSION-rc1 "master release-$RELEASE_VERSION" no ;; diff --git a/security-advisories/2015-03-02-0-c++-integer-overflow.md b/security-advisories/2015-03-02-0-c++-integer-overflow.md index d25b2a19fc..300647e2fe 100644 --- a/security-advisories/2015-03-02-0-c++-integer-overflow.md +++ b/security-advisories/2015-03-02-0-c++-integer-overflow.md @@ -35,7 +35,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.1.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/f343f0dbd0a2e87f17cd74f14186ed73e3fbdbfa +[0]: https://github.com/capnproto/capnproto/commit/f343f0dbd0a2e87f17cd74f14186ed73e3fbdbfa Details ======= @@ -97,6 +97,6 @@ following preventative measures going forward: I am pleased that measures 1, 2, and 3 all detected this bug, suggesting that they have a high probability of catching any similar bugs. -[1]: https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-all-cpu-amplification.md -[2]: https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md +[1]: https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-all-cpu-amplification.md +[2]: https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-1-c++-integer-underflow.md [3]: https://capnproto.org/news/2015-03-02-security-advisory-and-integer-overflow-protection.html diff --git a/security-advisories/2015-03-02-1-c++-integer-underflow.md b/security-advisories/2015-03-02-1-c++-integer-underflow.md index 970f8b9aec..06a3cd2f40 100644 --- a/security-advisories/2015-03-02-1-c++-integer-underflow.md +++ b/security-advisories/2015-03-02-1-c++-integer-underflow.md @@ -37,7 +37,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.1.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/26bcceda72372211063d62aab7e45665faa83633 +[0]: https://github.com/capnproto/capnproto/commit/26bcceda72372211063d62aab7e45665faa83633 Details ======= @@ -106,5 +106,5 @@ cleanup, but [check the Cap'n Proto blog for an in-depth discussion][2]. This problem is also caught by capnp/fuzz-test.c++, which *has* been merged into master but likely doesn't have as broad coverage. -[1]: https://github.com/sandstorm-io/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md +[1]: https://github.com/capnproto/capnproto/tree/master/security-advisories/2015-03-02-0-c++-integer-overflow.md [2]: https://capnproto.org/news/2015-03-02-security-advisory-and-integer-overflow-protection.html diff --git a/security-advisories/2015-03-02-2-all-cpu-amplification.md b/security-advisories/2015-03-02-2-all-cpu-amplification.md index 94ad336128..1bc4bccd2a 100644 --- a/security-advisories/2015-03-02-2-all-cpu-amplification.md +++ b/security-advisories/2015-03-02-2-all-cpu-amplification.md @@ -35,7 +35,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.1.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d +[0]: https://github.com/capnproto/capnproto/commit/104870608fde3c698483fdef6b97f093fc15685d Details ======= diff --git a/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md b/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md index aee7f1782c..bd25698d5e 100644 --- a/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md +++ b/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md @@ -37,7 +37,7 @@ Fixed in - Unix: https://capnproto.org/capnproto-c++-0.4.1.2.tar.gz - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/80149744bdafa3ad4eedc83f8ab675e27baee868 +[0]: https://github.com/capnproto/capnproto/commit/80149744bdafa3ad4eedc83f8ab675e27baee868 Details ======= @@ -55,7 +55,7 @@ loop that doesn't call any application code. Only CPU time is possibly consumed, not RAM or other resources. However, it is still possible to create significant delays for the receiver with a specially-crafted message. -[1]: https://github.com/sandstorm-io/capnproto/blob/master/security-advisories/2015-03-02-2-all-cpu-amplification.md +[1]: https://github.com/capnproto/capnproto/blob/master/security-advisories/2015-03-02-2-all-cpu-amplification.md Preventative measures ===================== diff --git a/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md b/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md index 683b8e3d0d..49221fc26b 100644 --- a/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md +++ b/security-advisories/2017-04-17-0-apple-clang-elides-bounds-check.md @@ -42,7 +42,7 @@ Fixed in - Windows: https://capnproto.org/capnproto-c++-win32-0.5.3.1.zip - release 0.6 (future) -[0]: https://github.com/sandstorm-io/capnproto/commit/52bc956459a5e83d7c31be95763ff6399e064ae4 +[0]: https://github.com/capnproto/capnproto/commit/52bc956459a5e83d7c31be95763ff6399e064ae4 Details ======= @@ -144,4 +144,4 @@ technically-correct solution has been implemented in the next commit, extensive refactoring, it is not appropriate for cherry-picking, and will only land in versions 0.6 and up. -[2]: https://github.com/sandstorm-io/capnproto/commit/2ca8e41140ebc618b8fb314b393b0a507568cf21 +[2]: https://github.com/capnproto/capnproto/commit/2ca8e41140ebc618b8fb314b393b0a507568cf21 diff --git a/security-advisories/2022-11-30-0-pointer-list-bounds.md b/security-advisories/2022-11-30-0-pointer-list-bounds.md new file mode 100644 index 0000000000..50605ce195 --- /dev/null +++ b/security-advisories/2022-11-30-0-pointer-list-bounds.md @@ -0,0 +1,127 @@ +Problem +======= + +Out-of-bounds read due to logic error handling list-of-list. + +Discovered by +============= + +David Renshaw <dwrenshaw@gmail.com>, the maintainer of Cap'n Proto's Rust +implementation, which is affected by the same bug. David discovered this bug +while running his own fuzzer. + +Announced +========= + +2022-11-30 + +CVE +=== + +CVE-2022-46149 + +Impact +====== + +- Remotely segfault a peer by sending it a malicious message, if the victim + performs certain actions on a list-of-pointer type. +- Possible exfiltration of memory, if the victim performs additional certain + actions on a list-of-pointer type. +- To be vulnerable, an application must perform a specific sequence of actions, + described below. At present, **we are not aware of any vulnerable + application**, but we advise updating regardless. + +Fixed in +======== + +Unfortunately, the bug is present in inlined code, therefore the fix will +require rebuilding dependent applications. + +C++ fix: + +- git commit [25d34c67863fd960af34fc4f82a7ca3362ee74b9][0] +- release 0.11 (future) +- release 0.10.3: + - Unix: https://capnproto.org/capnproto-c++-0.10.3.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.10.3.zip +- release 0.9.2: + - Unix: https://capnproto.org/capnproto-c++-0.9.2.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.9.2.zip +- release 0.8.1: + - Unix: https://capnproto.org/capnproto-c++-0.8.1.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.8.1.zip +- release 0.7.1: + - Unix: https://capnproto.org/capnproto-c++-0.7.1.tar.gz + - Windows: https://capnproto.org/capnproto-c++-win32-0.7.1.zip + +Rust fix: + +- `capnp` crate version `0.15.2`, `0.14.11`, or `0.13.7`. + +[0]: https://github.com/capnproto/capnproto/commit/25d34c67863fd960af34fc4f82a7ca3362ee74b9 + +Details +======= + +A specially-crafted pointer could escape bounds checking by exploiting +inconsistent handling of pointers when a list-of-structs is downgraded to a +list-of-pointers. + +For an in-depth explanation of how this bug works, see [David Renshaw's +blog post][1]. This details below focus only on determining whether an +application is vulnerable. + +In order to be vulnerable, an application must have certain properties. + +First, the application must accept messages with a schema in which a field has +list-of-pointer type. This includes `List(Text)`, `List(Data)`, +`List(List(T))`, or `List(C)` where `C` is an interface type. In the following +discussion, we will assume this field is named `foo`. + +Second, the application must accept a message of this schema from a malicious +source, where the attacker can maliciously encode the pointer representing the +field `foo`. + +Third, the application must call `getFoo()` to obtain a `List::Reader` for +the field, and then use it in one of the following two ways: + +1. Pass it as the parameter to another message's `setFoo()`, thus copying the + field into a new message. Note that copying the parent struct as a whole + will *not* trigger the bug; the bug only occurs if the specific field `foo` + is get/set on its own. + +2. Convert it into `AnyList::Reader`, and then attempt to access it through + that. This is much less likely; very few apps use the `AnyList` API. + +The dynamic API equivalents of these actions (`capnp/dynamic.h`) are also +affected. + +If the application does these steps, the attacker may be able to cause the +Cap'n Proto implementation to read beyond the end of the message. This could +induce a segmentation fault. Or, worse, data that happened to be in memory +immediately after the message might be returned as if it were part of the +message. In the latter case, if the application then forwards that data back +to the attacker or sends it to another third party, this could result in +exfiltration of secrets. + +Any exfiltration of data would have the following limitations: + +* The attacker could exfiltrate no more than 512 KiB of memory immediately + following the message buffer. + * The attacker chooses in advance how far past the end of the message to + read. + * The attacker's message itself must be larger than the exfiltrated data. + Note that a sufficiently large message buffer will likely be allocated + using mmap() in which case the attack will likely segfault. +* The attack can only work if the 8 bytes immediately following the + exfiltrated data contains a valid in-bounds Cap'n Proto pointer. The + easiest way to achieve this is if the pointer is null, i.e. 8 bytes of zero. + * The attacker must specify exactly how much data to exfiltrate, so must + guess exactly where such a valid pointer will exist. + * If the exfiltrated data is not followed by a valid pointer, the attack + will throw an exception. If an application has chosen to ignore exceptions + (e.g. by compiling with `-fno-exceptions` and not registering an + alternative exception callback) then the attack may be able to proceed + anyway. + +[1]: https://dwrensha.github.io/capnproto-rust/2022/11/30/out_of_bounds_memory_access_bug.html diff --git a/security-advisories/README.md b/security-advisories/README.md index b6490a0138..ce6b322261 100644 --- a/security-advisories/README.md +++ b/security-advisories/README.md @@ -4,8 +4,6 @@ This directory contains security advisories issued for Cap'n Proto. Each advisory explains not just the bug that was fixed, but measures we are taking to avoid the class of bugs in the future. -Note that Cap'n Proto has not yet undergone formal security review and therefore should not yet be trusted for reading possibly-malicious input. Even so, Cap'n Proto intends to be secure and we treat security bugs no less seriously than we would had security review already taken place. - ## Reporting Bugs Please report security bugs to [security@sandstorm.io](mailto:security@sandstorm.io). diff --git a/style-guide.md b/style-guide.md index f85b12a5cd..3dc663333d 100644 --- a/style-guide.md +++ b/style-guide.md @@ -67,7 +67,7 @@ KJ code is RAII-strict. Whenever it is the case that "this block of code cannot Use the macros `KJ_DEFER`, `KJ_ON_SCOPE_SUCCESS`, and `KJ_ON_SCOPE_FAILURE` to easily specify some code that must be executed on exit from the current scope, without the need to define a whole class with a destructor. -Be careful when writing complicated destructors. If a destructor performs multiple cleanup actions, you generally need to make sure that the latter actions occur even if the former ones throw an exception. For this reason, a destructor should generally perform no more than one cleanup action. If you need to clean up multiple things, have you class contain multiple members representing the different things that need cleanup, each with its own destructor. This way, if one member's destructor throws, the others still run. +Be careful when writing complicated destructors. If a destructor performs multiple cleanup actions, you generally need to make sure that the latter actions occur even if the former ones throw an exception. For this reason, a destructor should generally perform no more than one cleanup action. If you need to clean up multiple things, have your class contain multiple members representing the different things that need cleanup, each with its own destructor. This way, if one member's destructor throws, the others still run. ### Ownership @@ -109,7 +109,7 @@ Keep in mind that atomic (thread-safe) reference counting can be extremely slow. A "singleton" is any mutable object or value that is globally accessible. "Globally accessible" means that the object is declared as a global variable or static member variable, or that the object can be found by following pointers from such variables. -Never use singletons. Singletons cause invisible and unexpected dependencies between components of your software that appear unrelated. Worse, the assumption that "there should only be one of this object per process" is almost always wrong, but its wrongness only becomes apparent after so much code uses the singleton that it is infeasible to change. Singleton interfaces ofter turn into unusable monstrosities in an attempt to work around the fact that they should never have been a singleton in the first place. +Never use singletons. Singletons cause invisible and unexpected dependencies between components of your software that appear unrelated. Worse, the assumption that "there should only be one of this object per process" is almost always wrong, but its wrongness only becomes apparent after so much code uses the singleton that it is infeasible to change. Singleton interfaces often turn into unusable monstrosities in an attempt to work around the fact that they should never have been a singleton in the first place. See ["Singletons Considered Harmful"](http://www.object-oriented-security.org/lets-argue/singletons) for a complete discussion. @@ -429,7 +429,7 @@ We use: * Clang for compiling. * `KJ_DBG()` for simple debugging. * Valgrind for complicated debugging. -* [Ekam](https://github.com/sandstorm-io/ekam) for a build system. +* [Ekam](https://github.com/capnproto/ekam) for a build system. * Git for version control. ## Irrelevant formatting rules @@ -449,7 +449,7 @@ As a code reviewer, when you see a violation of formatting rules, think carefull **Rationale:** There has never been broad agreement on C++ naming style. The closest we have is the C++ standard library. Unfortunately, the C++ standard library made the awful decision of naming types and values in the same style, losing a highly useful visual cue that makes programming more pleasant, and preventing variables from being named after their type (which in many contexts is perfectly appropriate). -Meanwhile, the Java style, which KJ emulates, has been broadly adopted to varying degrees in other languages, from Javascript to Haskell. Using a similar style in KJ code makes it less jarring to switch between C++ and those other languages. Being consistent with Javascript is especially useful because it is the one language that everyone pretty much has to use, due to its use in the web platform. +Meanwhile, the Java style, which KJ emulates, has been broadly adopted to varying degrees in other languages, from JavaScript to Haskell. Using a similar style in KJ code makes it less jarring to switch between C++ and those other languages. Being consistent with JavaScript is especially useful because it is the one language that everyone pretty much has to use, due to its use in the web platform. There has also never been any agreement on C++ file extensions, for some reason. The extension `.c++`, though not widely used, is accepted by all reasonable tools and is clearly the most precise choice. @@ -458,7 +458,9 @@ There has also never been any agreement on C++ file extensions, for some reason. * Indents are two spaces. * Never use tabs. * Maximum line length is 100 characters. -* Indent a continuation line by four spaces, *or* line them up nicely with the previous line if it makes it easier to read. +* Indent continuation lines for braced init lists by two spaces. +* Indent all other continuation lines by four spaces. +* Alternatively, line up continuation lines with previous lines if it makes them easier to read. * Place a space between a keyword and an open parenthesis, e.g.: `if (foo)` * Do not place a space between a function name and an open parenthesis, e.g.: `foo(bar)` * Place an opening brace at the end of the statement which initiates the block, not on its own line. @@ -469,6 +471,7 @@ There has also never been any agreement on C++ file extensions, for some reason. * Statements inside a `namespace` are **not** indented unless the namespace is a short block that is just forward-declaring things at the top of a file. * Set your editor to strip trailing whitespace on save, otherwise other people who use this setting will see spurious diffs when they edit a file after you. +
if (foo) { bar(); @@ -567,8 +570,7 @@ Headers: // // Licensed under the Whatever License blah blah no warranties. - #ifndef HEADER_PATH_FILENAME_H_ - #define HEADER_PATH_FILENAME_H_ + #pragma once // Documentation for file. #include @@ -585,8 +587,6 @@ Headers: } // namespace myproject - #endif // HEADER_PATH_FILENAME_H_ - Source code: // Project Name - Project brief description diff --git a/super-test.sh b/super-test.sh index 78e3df2f81..8d63659d09 100755 --- a/super-test.sh +++ b/super-test.sh @@ -7,10 +7,32 @@ doit() { "$@" } +function test_samples() { + echo "@@@@ ./addressbook (in various configurations)" + ./addressbook write | ./addressbook read + ./addressbook dwrite | ./addressbook dread + rm -f /tmp/capnp-calculator-example-$$ + ./calculator-server unix:/tmp/capnp-calculator-example-$$ & + local SERVER_PID=$! + sleep 1 + ./calculator-client unix:/tmp/capnp-calculator-example-$$ + # `kill %./calculator-server` doesn't seem to work on recent Cygwins, but we can kill by PID. + kill -9 $SERVER_PID + # This `fg` will fail if bash happens to have already noticed the quit and reaped the process + # before `fg` is invoked, so in that case we just proceed. + fg %./calculator-server || true + rm -f /tmp/capnp-calculator-example-$$ +} + QUICK= +CPP_FEATURES= +EXTRA_LIBS= PARALLEL=$(nproc 2>/dev/null || echo 1) +# Have automake dump test failure to stdout. Important for CI. +export VERBOSE=true + while [ $# -gt 0 ]; do case "$1" in -j* ) @@ -21,6 +43,30 @@ while [ $# -gt 0 ]; do quick ) QUICK=quick ;; + cpp-features ) + if [ "$#" -lt 2 ] || [ -n "$CPP_FEATURES" ]; then + echo "usage: $0 cpp-features CPP_DEFINES" >&2 + echo "e.g. $0 cpp-features '-DSOME_VAR=5 -DSOME_OTHER_VAR=6'" >&2 + if [ -n "$CPP_FEATURES" ]; then + echo "cpp-features provided multiple times" >&2 + fi + exit 1 + fi + CPP_FEATURES="$2" + shift + ;; + extra-libs ) + if [ "$#" -lt 2 ] || [ -n "$EXTRA_LIBS" ]; then + echo "usage: $0 extra-libs EXTRA_LIBS" >&2 + echo "e.g. $0 extra-libs '-lrt'" >&2 + if [ -n "$EXTRA_LIBS" ]; then + echo "extra-libs provided multiple times" >&2 + fi + exit 1 + fi + EXTRA_LIBS="$2" + shift + ;; caffeinate ) # Re-run preventing sleep. shift @@ -75,17 +121,19 @@ while [ $# -gt 0 ]; do export CXX="$2" shift ;; - clang ) - export CXX=clang++ - ;; - gcc-4.9 ) - export CXX=g++-4.9 + clang* ) + # Need to set CC as well for configure to handle -fcoroutines-ts. + export CC=clang${1#clang} + export CXX=clang++${1#clang} + if [ "$1" != "clang-5.0" ]; then + export LIB_FUZZING_ENGINE=-fsanitize=fuzzer + fi ;; - gcc-4.8 ) - export CXX=g++-4.8 + gcc* ) + export CXX=g++${1#gcc} ;; - gcc-4.7 ) - export CXX=g++-4.7 + g++* ) + export CXX=$1 ;; mingw ) if [ "$#" -ne 2 ]; then @@ -100,7 +148,7 @@ while [ $# -gt 0 ]; do export WINEPATH='Z:\usr\'"$CROSS_HOST"'\lib;Z:\usr\lib\gcc\'"$CROSS_HOST"'\6.3-win32;Z:'"$PWD"'\.libs' - doit ./configure --host="$CROSS_HOST" --disable-shared CXXFLAGS='-static-libgcc -static-libstdc++' + doit ./configure --host="$CROSS_HOST" --disable-shared CXXFLAGS="-static-libgcc -static-libstdc++ $CPP_FEATURES" LIBS="$EXTRA_LIBS" doit make -j$PARALLEL check doit make distclean @@ -112,14 +160,17 @@ while [ $# -gt 0 ]; do # - Download command-line tools: https://developer.android.com/studio/index.html#command-tools # - Run $SDK_HOME/tools/bin/sdkmanager platform-tools 'platforms;android-25' 'system-images;android-25;google_apis;armeabi-v7a' emulator 'build-tools;25.0.2' ndk-bundle # - Run $SDK_HOME/tools/bin/avdmanager create avd -n capnp -k 'system-images;android-25;google_apis;armeabi-v7a' -b google_apis/armeabi-v7a - # - Run $SDK_HOME/ndk-bundle/build/tools/make_standalone_toolchain.py --arch arm --api 24 --install-dir $TOOLCHAIN_HOME if [ "$#" -ne 4 ]; then - echo "usage: $0 android SDK_HOME TOOLCHAIN_HOME CROSS_HOST" >&2 + echo "usage: $0 android SDK_HOME CROSS_HOST COMPILER_PREFIX" >&2 + echo + echo "SDK_HOME: Location where android-sdk is installed." >&2 + echo "CROSS_HOST: E.g. arm-linux-androideabi" >&2 + echo "COMPILER_PREFIX: E.g. armv7a-linux-androideabi24" >&2 exit 1 fi SDK_HOME=$2 - TOOLCHAIN_HOME=$3 - CROSS_HOST=$4 + CROSS_HOST=$3 + COMPILER_PREFIX=$4 cd c++ test -e configure || doit autoreconf -i @@ -130,9 +181,9 @@ while [ $# -gt 0 ]; do cp capnp capnp-host cp capnpc-c++ capnpc-c++-host - export PATH="$TOOLCHAIN_HOME/bin:$PATH" + export PATH="$SDK_HOME/ndk-bundle/toolchains/llvm/prebuilt/linux-x86_64/bin:$PATH" doit make distclean - doit ./configure --host="$CROSS_HOST" --with-external-capnp --disable-shared CXXFLAGS='-pie -fPIE' CAPNP=./capnp-host CAPNPC_CXX=./capnpc-c++-host + doit ./configure --host="$CROSS_HOST" CC="$COMPILER_PREFIX-clang" CXX="$COMPILER_PREFIX-clang++" --with-external-capnp --disable-shared CXXFLAGS="-fPIE $CPP_FEATURES" LDFLAGS='-pie' LIBS="-static-libstdc++ -static-libgcc -ldl $EXTRA_LIBS" CAPNP=./capnp-host CAPNPC_CXX=./capnpc-c++-host doit make -j$PARALLEL doit make -j$PARALLEL capnp-test @@ -162,6 +213,75 @@ while [ $# -gt 0 ]; do cd cmake-build doit cmake -G "Unix Makefiles" .. doit make -j$PARALLEL check + exit 0 + ;; + cmake-package ) + # Test that a particular configuration of Cap'n Proto can be discovered and configured against + # by a CMake project using the find_package() command. This is currently implemented by + # building the samples against the desired configuration. + # + # Takes one argument, the build configuration, which must be one of: + # + # autotools-shared + # autotools-static + # cmake-shared + # cmake-static + + if [ "$#" -ne 2 ]; then + echo "usage: $0 cmake-package CONFIGURATION" >&2 + echo " where CONFIGURATION is one of {autotools,cmake}-{static,shared}" >&2 + exit 1 + fi + + CONFIGURATION=$2 + WORKSPACE=$(pwd)/cmake-package/$CONFIGURATION + SOURCE_DIR=$(pwd)/c++ + + rm -rf $WORKSPACE + mkdir -p $WORKSPACE/{build,build-samples,inst} + + # Configure + cd $WORKSPACE/build + case "$CONFIGURATION" in + autotools-shared ) + autoreconf -i $SOURCE_DIR + doit $SOURCE_DIR/configure --prefix="$WORKSPACE/inst" --disable-static + ;; + autotools-static ) + autoreconf -i $SOURCE_DIR + doit $SOURCE_DIR/configure --prefix="$WORKSPACE/inst" --disable-shared + ;; + cmake-shared ) + doit cmake $SOURCE_DIR -G "Unix Makefiles" -DCMAKE_INSTALL_PREFIX="$WORKSPACE/inst" \ + -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=ON + # The CMake build does not currently set the rpath of the capnp compiler tools. + export LD_LIBRARY_PATH="$WORKSPACE/inst/lib" + ;; + cmake-static ) + doit cmake $SOURCE_DIR -G "Unix Makefiles" -DCMAKE_INSTALL_PREFIX="$WORKSPACE/inst" \ + -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=OFF + ;; + * ) + echo "Unrecognized cmake-package CONFIGURATION argument, must be {autotools,cmake}-{static,shared}" >&2 + exit 1 + ;; + esac + + # Build and install + doit make -j$PARALLEL install + + # Configure, build, and execute the samples. + cd $WORKSPACE/build-samples + doit cmake $SOURCE_DIR/samples -G "Unix Makefiles" -DCMAKE_PREFIX_PATH="$WORKSPACE/inst" \ + -DCAPNPC_FLAGS=--no-standard-import -DCAPNPC_IMPORT_DIRS="$WORKSPACE/inst/include" + doit make -j$PARALLEL + + test_samples + + echo "=========================================================================" + echo "Cap'n Proto ($CONFIGURATION) installs a working CMake config package." + echo "=========================================================================" + exit 0 ;; exotic ) @@ -181,6 +301,13 @@ while [ $# -gt 0 ]; do echo "CMake" echo "=========================================================================" "$0" cmake + echo "=========================================================================" + echo "CMake config packages" + echo "=========================================================================" + "$0" cmake-package autotools-shared + "$0" cmake-package autotools-static + "$0" cmake-package cmake-shared + "$0" cmake-package cmake-static exit 0 ;; clean ) @@ -223,7 +350,14 @@ done # sign-compare warnings than probably all other warnings combined and I've never seen it flag a # real problem. Disable unused parameters because it's stupidly noisy and never a real problem. # Enable expensive release-gating tests. -export CXXFLAGS="-O2 -DDEBUG -Wall -Wextra -Werror -Wno-strict-aliasing -Wno-sign-compare -Wno-unused-parameter -DCAPNP_EXPENSIVE_TESTS=1" +export CXXFLAGS="-O2 -DDEBUG -Wall -Wextra -Werror -Wno-strict-aliasing -Wno-sign-compare -Wno-unused-parameter -DCAPNP_EXPENSIVE_TESTS=1 ${CPP_FEATURES}" +export LIBS="$EXTRA_LIBS" + +if [ "${CXX:-}" != "g++-5" ]; then + # This warning flag is missing on g++-5 but available on all other GCC/Clang versions we target + # in CI. + export CXXFLAGS="$CXXFLAGS -Wimplicit-fallthrough" +fi STAGING=$PWD/tmp-staging @@ -244,7 +378,20 @@ echo "Building c++" echo "=========================================================================" # Apple now aliases gcc to clang, so probe to find out what compiler we're really using. -if (${CXX:-g++} -dM -E -x c++ /dev/null 2>&1 | grep -q '__clang__'); then +# +# NOTE: You might be tempted to use `grep -q` here instead of sending output to /dev/null. However, +# we cannot, because `grep -q` exits immediately upon seeing a match. If it exits too soon, the +# first stage of the pipeline gets killed, and the whole expression is considered to have failed +# since we are running bash with the `pipefail` option enabled. +# FUN STORY: We used to use grep -q. One day, we found that Clang 9 when running under GitHub +# Actions was detected as *not* Clang. But if we ran it twice, it would succeed on the second +# try. It turns out that under previous versions of Clang, the `__clang__` define was pretty +# close to the end of the list, so it always managed to write the whole list before `grep -q` +# exited. But under Clang 9, there's a bunch more defines after this one, giving more time for +# `grep -q` to exit and break everything. But if the compiler had executed once recently then +# the second run would go faster due to caching (I guess) and manage to get all the data out +# to the buffer in time. +if (${CXX:-g++} -dM -E -x c++ /dev/null 2>&1 | grep '__clang__' > /dev/null); then IS_CLANG=yes DISABLE_OPTIMIZATION_IF_GCC= else @@ -255,22 +402,47 @@ fi if [ $IS_CLANG = yes ]; then # Don't fail out on this ridiculous "argument unused during compilation" warning. export CXXFLAGS="$CXXFLAGS -Wno-error=unused-command-line-argument" + + # Enable coroutines if supported. + if [ "${CXX#*-}" -ge 14 ] 2>/dev/null; then + # Somewhere between version 10 and 14, Clang started supporting coroutines as a C++20 feature, + # and started issuing deprecating warnings for -fcoroutines-ts. (I'm not sure which version it + # was exactly since our CI jumped from 10 to 14, so I'm somewhat arbitrarily choosing 14 as the + # cutoff.) + export CXXFLAGS="$CXXFLAGS -std=c++20 -stdlib=libc++" + export LDFLAGS="-stdlib=libc++" + + # TODO(someday): On Ubuntu 22.04, clang-14 with -stdlib=libc++ fails to link with libfuzzer, + # which looks like it might itself be linked against libstdc++? Need to investigate. + unset LIB_FUZZING_ENGINE + elif [ "${CXX#*-}" -ge 10 ] 2>/dev/null; then + # At the moment, only our clang-10 CI run seems to like -fcoroutines-ts. Earlier versions seem + # to have a library misconfiguration causing ./configure to result in the following error: + # conftest.cpp:12:12: fatal error: 'initializer_list' file not found + # #include + # Let's use any clang version >= 10 so that if we move to a newer version, we'll get additional + # coverage by default. + export CXXFLAGS="$CXXFLAGS -std=gnu++17 -stdlib=libc++ -fcoroutines-ts" + export LDFLAGS="-fcoroutines-ts -stdlib=libc++" + fi else # GCC emits uninitialized warnings all over and they seem bogus. We use valgrind to test for # uninitialized memory usage later on. GCC 4 also emits strange bogus warnings with # -Wstrict-overflow, so we disable it. CXXFLAGS="$CXXFLAGS -Wno-maybe-uninitialized -Wno-strict-overflow" + + # TODO(someday): Enable coroutines in g++ if supported. fi cd c++ doit autoreconf -i -doit ./configure --prefix="$STAGING" +doit ./configure --prefix="$STAGING" || (cat config.log && exit 1) doit make -j$PARALLEL check if [ $IS_CLANG = no ]; then # Verify that generated code compiles with pedantic warnings. Make sure to treat capnp headers # as system headers so warnings in them are ignored. - doit ${CXX:-g++} -isystem src -std=c++11 -fno-permissive -pedantic -Wall -Wextra -Werror \ + doit ${CXX:-g++} -isystem src -std=c++14 -fno-permissive -pedantic -Wall -Wextra -Werror \ -c src/capnp/test.capnp.c++ -o /dev/null fi @@ -286,27 +458,30 @@ test "x$(which capnpc-c++)" = "x$STAGING/bin/capnpc-c++" cd samples doit capnp compile -oc++ addressbook.capnp -I"$STAGING"/include --no-standard-import -doit ${CXX:-g++} -std=c++11 addressbook.c++ addressbook.capnp.c++ -o addressbook \ +doit ${CXX:-g++} -std=c++14 addressbook.c++ addressbook.capnp.c++ -o addressbook \ $CXXFLAGS $(pkg-config --cflags --libs capnp) -echo "@@@@ ./addressbook (in various configurations)" -./addressbook write | ./addressbook read -./addressbook dwrite | ./addressbook dread -rm addressbook addressbook.capnp.c++ addressbook.capnp.h doit capnp compile -oc++ calculator.capnp -I"$STAGING"/include --no-standard-import -doit ${CXX:-g++} -std=c++11 calculator-client.c++ calculator.capnp.c++ -o calculator-client \ +doit ${CXX:-g++} -std=c++14 calculator-client.c++ calculator.capnp.c++ -o calculator-client \ $CXXFLAGS $(pkg-config --cflags --libs capnp-rpc) -doit ${CXX:-g++} -std=c++11 calculator-server.c++ calculator.capnp.c++ -o calculator-server \ +doit ${CXX:-g++} -std=c++14 calculator-server.c++ calculator.capnp.c++ -o calculator-server \ $CXXFLAGS $(pkg-config --cflags --libs capnp-rpc) -rm -f /tmp/capnp-calculator-example-$$ -./calculator-server unix:/tmp/capnp-calculator-example-$$ & -sleep 0.1 -./calculator-client unix:/tmp/capnp-calculator-example-$$ -kill %+ -wait %+ || true -rm calculator-client calculator-server calculator.capnp.c++ calculator.capnp.h /tmp/capnp-calculator-example-$$ -cd .. +test_samples +rm addressbook addressbook.capnp.c++ addressbook.capnp.h +rm calculator-client calculator-server calculator.capnp.c++ calculator.capnp.h + +rm -rf cmake-build +mkdir cmake-build +cd cmake-build + +doit cmake .. -G "Unix Makefiles" -DCMAKE_PREFIX_PATH="$STAGING" \ + -DCAPNPC_FLAGS=--no-standard-import -DCAPNPC_IMPORT_DIRS="$STAGING/include" +doit make -j$PARALLEL + +test_samples +cd ../.. +rm -rf samples/cmake-build if [ "$QUICK" = quick ]; then doit make maintainer-clean @@ -315,16 +490,7 @@ if [ "$QUICK" = quick ]; then fi echo "=========================================================================" -echo "Testing --with-external-capnp" -echo "=========================================================================" - -doit make distclean -doit ./configure --prefix="$STAGING" --disable-shared \ - --with-external-capnp CAPNP=$STAGING/bin/capnp -doit make -j$PARALLEL check - -echo "=========================================================================" -echo "Testing --disable-reflection" +echo "Testing --with-external-capnp and --disable-reflection" echo "=========================================================================" doit make distclean @@ -334,32 +500,20 @@ doit make -j$PARALLEL check doit make distclean # Test 32-bit build now while we have $STAGING available for cross-compiling. -if [ "x`uname -m`" = "xx86_64" ]; then +# +# Cygwin64 can cross-compile to Cygwin32 but can't actually run the cross-compiled binaries. Let's +# just skip this test on Cygwin since it's so slow and honestly no one cares. +# +# MacOS apparently no longer distributes 32-bit standard libraries. OK fine let's restrict this to +# Linux. +if [ "x`uname -m`" = "xx86_64" ] && [ "x`uname`" = xLinux ]; then echo "=========================================================================" echo "Testing 32-bit build" echo "=========================================================================" - if [[ "`uname`" =~ CYGWIN ]]; then - # It's just not possible to run cygwin32 binaries from within cygwin64. - - # Build as if we are cross-compiling, using the capnp we installed to $STAGING. - doit ./configure --prefix="$STAGING" --disable-shared --host=i686-pc-cygwin \ - --with-external-capnp CAPNP=$STAGING/bin/capnp - doit make -j$PARALLEL - doit make -j$PARALLEL capnp-test.exe - - # Expect a cygwin32 sshd to be listening at localhost port 2222, and use it - # to run the tests. - doit scp -P 2222 capnp-test.exe localhost:~/tmp-capnp-test.exe - doit ssh -p 2222 localhost './tmp-capnp-test.exe && rm tmp-capnp-test.exe' - - doit make distclean - - elif [ "x${CXX:-g++}" != "xg++-4.8" ]; then - doit ./configure CXX="${CXX:-g++} -m32" --disable-shared - doit make -j$PARALLEL check - doit make distclean - fi + doit ./configure CXX="${CXX:-g++} -m32" CXXFLAGS="$CXXFLAGS ${ADDL_M32_FLAGS:-}" --disable-shared + doit make -j$PARALLEL check + doit make distclean fi echo "=========================================================================" @@ -396,32 +550,23 @@ echo "=========================================================================" # is inlined in hundreds of other places without issue, so I have no idea how to narrow down the # bug. Clang works fine. So, for now, we disable optimizations on GCC for -fno-exceptions tests. -doit ./configure --disable-shared CXXFLAGS="$CXXFLAGS -fno-rtti" -doit make -j$PARALLEL check -doit make distclean -doit ./configure --disable-shared CXXFLAGS="$CXXFLAGS -fno-exceptions $DISABLE_OPTIMIZATION_IF_GCC" -doit make -j$PARALLEL check -doit make distclean doit ./configure --disable-shared CXXFLAGS="$CXXFLAGS -fno-rtti -fno-exceptions $DISABLE_OPTIMIZATION_IF_GCC" doit make -j$PARALLEL check -# Valgrind is currently "experimental and mostly broken" on OSX and fails to run the full test -# suite, but I have it installed because it did manage to help me track down a bug or two. Anyway, -# skip it on OSX for now. -if [ "x`uname`" != xDarwin ] && which valgrind > /dev/null; then +if [ "x`uname`" = xLinux ]; then doit make distclean echo "=========================================================================" echo "Testing with valgrind" echo "=========================================================================" - doit ./configure --disable-shared CXXFLAGS="-g" + doit ./configure --disable-shared CXXFLAGS="-g $CPP_FEATURES" doit make -j$PARALLEL doit make -j$PARALLEL capnp-test # Running the fuzz tests under Valgrind is a great thing to do -- but it takes # some 40 minutes. So, it needs to be done as a separate step of the release # process, perhaps along with the AFL tests. - CAPNP_SKIP_FUZZ_TEST=1 doit valgrind --leak-check=full --track-fds=yes --error-exitcode=1 ./capnp-test + CAPNP_SKIP_FUZZ_TEST=1 doit valgrind --leak-check=full --track-fds=yes --error-exitcode=1 --child-silent-after-fork=yes --sim-hints=lax-ioctls --suppressions=valgrind.supp ./capnp-test fi doit make maintainer-clean