diff --git a/.asf.yaml b/.asf.yaml index 22042b355b2fa..3935a525ff3c4 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,6 +31,8 @@ github: merge: false squash: true rebase: true + ghp_branch: master + ghp_path: /docs notifications: pullrequests: reviews@spark.apache.org diff --git a/.github/labeler.yml b/.github/labeler.yml index a74b4ab5d5a33..6617acbf9187e 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -26,16 +26,14 @@ INFRA: '.asf.yaml', '.gitattributes', '.gitignore', - 'dev/merge_spark_pr.py', - 'dev/run-tests-jenkins*' + 'dev/merge_spark_pr.py' ] BUILD: - changed-files: - all-globs-to-any-file: [ 'dev/**/*', - '!dev/merge_spark_pr.py', - '!dev/run-tests-jenkins*' + '!dev/merge_spark_pr.py' ] - any-glob-to-any-file: [ 'build/**/*', @@ -199,6 +197,7 @@ YARN: KUBERNETES: - changed-files: - any-glob-to-any-file: [ + 'bin/docker-image-tool.sh', 'resource-managers/kubernetes/**/*' ] @@ -225,7 +224,7 @@ DEPLOY: CONNECT: - changed-files: - any-glob-to-any-file: [ - 'connect/**/*', + 'sql/connect/**/*', 'connector/connect/**/*', 'python/pyspark/sql/**/connect/**/*', 'python/pyspark/ml/**/connect/**/*' diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 576f64f3a0869..3117872e21680 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -58,6 +58,12 @@ jobs: outputs: required: ${{ steps.set-outputs.outputs.required }} image_url: ${{ steps.infra-image-outputs.outputs.image_url }} + image_docs_url: ${{ steps.infra-image-docs-outputs.outputs.image_docs_url }} + image_docs_url_link: ${{ steps.infra-image-link.outputs.image_docs_url_link }} + image_lint_url: ${{ steps.infra-image-lint-outputs.outputs.image_lint_url }} + image_lint_url_link: ${{ steps.infra-image-link.outputs.image_lint_url_link }} + image_sparkr_url: ${{ steps.infra-image-sparkr-outputs.outputs.image_sparkr_url }} + image_sparkr_url_link: ${{ steps.infra-image-link.outputs.image_sparkr_url_link }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -134,6 +140,44 @@ jobs: IMG_NAME="apache-spark-ci-image:${{ inputs.branch }}-${{ github.run_id }}" IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" echo "image_url=$IMG_URL" >> $GITHUB_OUTPUT + - name: Generate infra image URL (Documentation) + id: infra-image-docs-outputs + run: | + # Convert to lowercase to meet Docker repo name requirement + REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]') + IMG_NAME="apache-spark-ci-image-docs:${{ inputs.branch }}-${{ github.run_id }}" + IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" + echo "image_docs_url=$IMG_URL" >> $GITHUB_OUTPUT + - name: Generate infra image URL (Linter) + id: infra-image-lint-outputs + run: | + # Convert to lowercase to meet Docker repo name requirement + REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]') + IMG_NAME="apache-spark-ci-image-lint:${{ inputs.branch }}-${{ github.run_id }}" + IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" + echo "image_lint_url=$IMG_URL" >> $GITHUB_OUTPUT + - name: Generate infra image URL (SparkR) + id: infra-image-sparkr-outputs + run: | + # Convert to lowercase to meet Docker repo name requirement + REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]') + IMG_NAME="apache-spark-ci-image-sparkr:${{ inputs.branch }}-${{ github.run_id }}" + IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" + echo "image_sparkr_url=$IMG_URL" >> $GITHUB_OUTPUT + - name: Link the docker images + id: infra-image-link + run: | + # Set the image URL for job 
"docs" + # Should delete the link and directly use image_docs_url after SPARK 3.x EOL + if [[ "${{ inputs.branch }}" == 'branch-3.5' ]]; then + echo "image_docs_url_link=${{ steps.infra-image-outputs.outputs.image_url }}" >> $GITHUB_OUTPUT + echo "image_lint_url_link=${{ steps.infra-image-outputs.outputs.image_url }}" >> $GITHUB_OUTPUT + echo "image_sparkr_url_link=${{ steps.infra-image-outputs.outputs.image_url }}" >> $GITHUB_OUTPUT + else + echo "image_docs_url_link=${{ steps.infra-image-docs-outputs.outputs.image_docs_url }}" >> $GITHUB_OUTPUT + echo "image_lint_url_link=${{ steps.infra-image-lint-outputs.outputs.image_lint_url }}" >> $GITHUB_OUTPUT + echo "image_sparkr_url_link=${{ steps.infra-image-sparkr-outputs.outputs.image_sparkr_url }}" >> $GITHUB_OUTPUT + fi # Build: build Spark and run the tests for specified modules. build: @@ -264,20 +308,20 @@ jobs: with: distribution: zulu java-version: ${{ matrix.java }} - - name: Install Python 3.9 + - name: Install Python 3.11 uses: actions/setup-python@v5 # We should install one Python that is higher than 3+ for SQL and Yarn because: # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils. # - Yarn has a Python specific test too, for example, YarnClusterSuite. if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') with: - python-version: '3.9' + python-version: '3.11' architecture: x64 - - name: Install Python packages (Python 3.9) + - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') run: | - python3.9 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1' - python3.9 -m pip list + python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.28.3' + python3.11 -m pip list # Run the tests. 
- name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -304,7 +348,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} - path: "**/target/unit-tests.log" + path: "**/target/*.log" infra-image: name: "Base image build" @@ -343,7 +387,7 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Build and push id: docker_build - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: ./dev/infra/ push: true @@ -351,6 +395,40 @@ jobs: ${{ needs.precondition.outputs.image_url }} # Use the infra image cache to speed up cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ inputs.branch }} + - name: Build and push (Documentation) + if: hashFiles('dev/spark-test-image/docs/Dockerfile') != '' + id: docker_build_docs + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/docs/ + push: true + tags: | + ${{ needs.precondition.outputs.image_docs_url }} + # Use the infra image cache to speed up + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ inputs.branch }} + - name: Build and push (Linter) + if: hashFiles('dev/spark-test-image/lint/Dockerfile') != '' + id: docker_build_lint + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/lint/ + push: true + tags: | + ${{ needs.precondition.outputs.image_lint_url }} + # Use the infra image cache to speed up + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ inputs.branch }} + - name: Build and push (SparkR) + if: hashFiles('dev/spark-test-image/sparkr/Dockerfile') != '' + id: docker_build_sparkr + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/sparkr/ + push: true + tags: | + ${{ needs.precondition.outputs.image_sparkr_url }} + # Use the infra image cache to speed up + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-sparkr-cache:${{ inputs.branch }} + pyspark: needs: [precondition, infra-image] @@ -509,7 +587,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 180 container: - image: ${{ needs.precondition.outputs.image_url }} + image: ${{ needs.precondition.outputs.image_sparkr_url_link }} env: HADOOP_PROFILE: ${{ inputs.hadoop }} HIVE_PROFILE: hive2.3 @@ -606,18 +684,22 @@ jobs: - name: Breaking change detection against branch-3.5 uses: bufbuild/buf-breaking-action@v1 with: - input: connect/common/src/main + input: sql/connect/common/src/main against: 'https://github.com/apache/spark.git#branch=branch-3.5,subdir=connector/connect/common/src/main' - - name: Install Python 3.9 + - name: Install Python 3.11 uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.11' - name: Install dependencies for Python CodeGen check run: | - python3.9 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' - python3.9 -m pip list - - name: Python CodeGen check + python3.11 -m pip install 'black==23.9.1' 'protobuf==5.28.3' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' + python3.11 -m pip list + - name: Python CodeGen check for branch-3.5 + if: inputs.branch == 'branch-3.5' run: ./dev/connect-check-protos.py + - name: Python CodeGen check + if: inputs.branch != 'branch-3.5' + run: ./dev/check-protos.py # Static analysis lint: @@ -635,7 +717,7 @@ jobs: PYSPARK_PYTHON: python3.9 GITHUB_PREV_SHA: ${{ github.event.before }} container: - image: 
${{ needs.precondition.outputs.image_url }} + image: ${{ needs.precondition.outputs.image_lint_url_link }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -702,13 +784,6 @@ jobs: run: ./dev/lint-java - name: Spark connect jvm client mima check run: ./dev/connect-jvm-client-mima-check - - name: Install Python linter dependencies for branch-3.4 - if: inputs.branch == 'branch-3.4' - run: | - # SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578 - # Should delete this section after SPARK 3.4 EOL. - python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' - python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python linter dependencies for branch-3.5 if: inputs.branch == 'branch-3.5' run: | @@ -716,18 +791,8 @@ jobs: # Should delete this section after SPARK 3.5 EOL. python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.982' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.56.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - - name: Install Python dependencies for python linter and documentation generation - if: inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5' - run: | - # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 - # See 'ipython_genutils' in SPARK-38517 - # See 'docutils<0.18.0' in SPARK-39421 - python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ - 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ - 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ - 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' - python3.9 -m pip list + - name: List Python packages + run: python3.9 -m pip list - name: Python linter run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python # Should delete this section after SPARK 3.5 EOL. @@ -745,16 +810,16 @@ jobs: if: inputs.branch == 'branch-3.5' run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi # Should delete this section after SPARK 3.5 EOL. - - name: Install JavaScript linter dependencies for branch-3.4, branch-3.5 - if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' + - name: Install JavaScript linter dependencies for branch-3.5 + if: inputs.branch == 'branch-3.5' run: | apt update apt-get install -y nodejs npm - name: JS linter run: ./dev/lint-js # Should delete this section after SPARK 3.5 EOL. 
- - name: Install R linter dependencies for branch-3.4, branch-3.5 - if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' + - name: Install R linter dependencies for branch-3.5 + if: inputs.branch == 'branch-3.5' run: | apt update apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev \ @@ -783,7 +848,7 @@ jobs: PYSPARK_PYTHON: python3.9 GITHUB_PREV_SHA: ${{ github.event.before }} container: - image: ${{ needs.precondition.outputs.image_url }} + image: ${{ needs.precondition.outputs.image_docs_url_link }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -833,20 +898,8 @@ jobs: with: distribution: zulu java-version: ${{ inputs.java }} - - name: Install Python dependencies for python linter and documentation generation - if: inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5' - run: | - # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 - # See 'ipython_genutils' in SPARK-38517 - # See 'docutils<0.18.0' in SPARK-39421 - python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ - 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ - 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ - 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' - python3.9 -m pip list - - name: Install dependencies for documentation generation for branch-3.4, branch-3.5 - if: inputs.branch == 'branch-3.4' || inputs.branch == 'branch-3.5' + - name: Install dependencies for documentation generation for branch-3.5 + if: inputs.branch == 'branch-3.5' run: | # pandoc is required to generate PySpark APIs as well in nbsphinx. apt-get update -y @@ -860,6 +913,8 @@ jobs: python3.9 -m pip install ipython_genutils # See SPARK-38517 python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 + - name: List Python packages + run: python3.9 -m pip list - name: Install dependencies for documentation generation run: | # Keep the version of Bundler here in sync with the following locations: @@ -867,20 +922,26 @@ jobs: # - docs/README.md gem install bundler -v 2.4.22 cd docs - bundle install + bundle install --retry=100 - name: Run documentation build run: | - # We need this link because the jekyll build calls `python`. - ln -s "$(which python3.9)" "/usr/local/bin/python" + # We need this link to make sure `python3` points to `python3.9` which contains the prerequisite packages. + ln -s "$(which python3.9)" "/usr/local/bin/python3" # Build docs first with SKIP_API to ensure they are buildable without requiring any # language docs to be built beforehand. - cd docs; SKIP_API=1 bundle exec jekyll build; cd .. + cd docs; SKIP_ERRORDOC=1 SKIP_API=1 bundle exec jekyll build; cd .. 
if [ -f "./dev/is-changed.py" ]; then # Skip PySpark and SparkR docs while keeping Scala/Java/SQL docs pyspark_modules=`cd dev && python3.9 -c "import sparktestsupport.modules as m; print(','.join(m.name for m in m.all_modules if m.name.startswith('pyspark')))"` if [ `./dev/is-changed.py -m $pyspark_modules` = false ]; then export SKIP_PYTHONDOC=1; fi if [ `./dev/is-changed.py -m sparkr` = false ]; then export SKIP_RDOC=1; fi fi + # Print the values of environment variables `SKIP_ERRORDOC`, `SKIP_SCALADOC`, `SKIP_PYTHONDOC`, `SKIP_RDOC` and `SKIP_SQLDOC` + echo "SKIP_ERRORDOC: $SKIP_ERRORDOC" + echo "SKIP_SCALADOC: $SKIP_SCALADOC" + echo "SKIP_PYTHONDOC: $SKIP_PYTHONDOC" + echo "SKIP_RDOC: $SKIP_RDOC" + echo "SKIP_SQLDOC: $SKIP_SQLDOC" cd docs bundle exec jekyll build - name: Tar documentation @@ -1106,14 +1167,16 @@ jobs: with: distribution: zulu java-version: ${{ inputs.java }} - - name: start minikube + - name: Install R run: | - # See more in "Installation" https://minikube.sigs.k8s.io/docs/start/ - curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64 - sudo install minikube-linux-amd64 /usr/local/bin/minikube - rm minikube-linux-amd64 + sudo apt update + sudo apt-get install r-base + - name: Start Minikube + uses: medyagh/setup-minikube@v0.0.18 + with: # Github Action limit cpu:2, memory: 6947MB, limit to 2U6G for better resource statistic - minikube start --cpus 2 --memory 6144 + cpus: 2 + memory: 6144m - name: Print K8S pods and nodes info run: | kubectl get pods -A @@ -1126,11 +1189,10 @@ jobs: export PVC_TESTS_VM_PATH=$PVC_TMP_DIR minikube mount ${PVC_TESTS_HOST_PATH}:${PVC_TESTS_VM_PATH} --gid=0 --uid=185 & kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.2/installer/volcano-development.yaml || true - if [[ "${{ inputs.branch }}" == 'branch-3.5' || "${{ inputs.branch }}" == 'branch-3.4' ]]; then + if [[ "${{ inputs.branch }}" == 'branch-3.5' ]]; then kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.7.0/installer/volcano-development.yaml || true else - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.2/installer/volcano-development.yaml || true + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.9.0/installer/volcano-development.yaml || true fi eval $(minikube docker-env) build/sbt -Phadoop-3 -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml index 49b2e2e80d9ac..a6beacedeebd4 100644 --- a/.github/workflows/build_infra_images_cache.yml +++ b/.github/workflows/build_infra_images_cache.yml @@ -27,6 +27,9 @@ on: - 'branch-*' paths: - 'dev/infra/Dockerfile' + - 'dev/spark-test-image/docs/Dockerfile' + - 'dev/spark-test-image/lint/Dockerfile' + - 'dev/spark-test-image/sparkr/Dockerfile' - '.github/workflows/build_infra_images_cache.yml' # Create infra image when cutting down branches/tags create: @@ -51,7 +54,7 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: Build and push id: docker_build - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: ./dev/infra/ push: true @@ -60,3 +63,42 @@ jobs: cache-to: 
type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }},mode=max - name: Image digest run: echo ${{ steps.docker_build.outputs.digest }} + - name: Build and push (Documentation) + if: hashFiles('dev/spark-test-image/docs/Dockerfile') != '' + id: docker_build_docs + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/docs/ + push: true + tags: ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }}-static + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }} + cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ github.ref_name }},mode=max + - name: Image digest (Documentation) + if: hashFiles('dev/spark-test-image/docs/Dockerfile') != '' + run: echo ${{ steps.docker_build_docs.outputs.digest }} + - name: Build and push (Linter) + if: hashFiles('dev/spark-test-image/lint/Dockerfile') != '' + id: docker_build_lint + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/lint/ + push: true + tags: ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ github.ref_name }}-static + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ github.ref_name }} + cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ github.ref_name }},mode=max + - name: Image digest (Linter) + if: hashFiles('dev/spark-test-image/lint/Dockerfile') != '' + run: echo ${{ steps.docker_build_lint.outputs.digest }} + - name: Build and push (SparkR) + if: hashFiles('dev/spark-test-image/sparkr/Dockerfile') != '' + id: docker_build_sparkr + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/sparkr/ + push: true + tags: ghcr.io/apache/spark/apache-spark-github-action-image-sparkr-cache:${{ github.ref_name }}-static + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-sparkr-cache:${{ github.ref_name }} + cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-sparkr-cache:${{ github.ref_name }},mode=max + - name: Image digest (SparkR) + if: hashFiles('dev/spark-test-image/sparkr/Dockerfile') != '' + run: echo ${{ steps.docker_build_sparkr.outputs.digest }} diff --git a/.github/workflows/build_maven.yml b/.github/workflows/build_maven.yml index c3a23c02f6a61..b5546c61eb11b 100644 --- a/.github/workflows/build_maven.yml +++ b/.github/workflows/build_maven.yml @@ -30,8 +30,3 @@ jobs: name: Run uses: ./.github/workflows/maven_test.yml if: github.repository == 'apache/spark' - with: - envs: >- - { - "SKIP_SPARK_RELEASE_VERSIONS": "3.4.2" - } diff --git a/.github/workflows/build_maven_java21_macos14.yml b/.github/workflows/build_maven_java21_macos15.yml similarity index 92% rename from .github/workflows/build_maven_java21_macos14.yml rename to .github/workflows/build_maven_java21_macos15.yml index fb5e609f4eae0..cc6d0ea4e90da 100644 --- a/.github/workflows/build_maven_java21_macos14.yml +++ b/.github/workflows/build_maven_java21_macos15.yml @@ -17,7 +17,7 @@ # under the License. 
# -name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, macos-14)" +name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, MacOS-15)" on: schedule: @@ -32,7 +32,7 @@ jobs: if: github.repository == 'apache/spark' with: java: 21 - os: macos-14 + os: macos-15 envs: >- { "OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES" diff --git a/.github/workflows/build_python_3.11_macos.yml b/.github/workflows/build_python_3.11_macos.yml new file mode 100644 index 0000000000000..4caae55b5fea8 --- /dev/null +++ b/.github/workflows/build_python_3.11_macos.yml @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "Build / Python-only (master, Python 3.11, MacOS)" + +on: + schedule: + - cron: '0 21 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/python_macos_test.yml + if: github.repository == 'apache/spark' diff --git a/.github/workflows/build_branch34.yml b/.github/workflows/build_python_3.13.yml similarity index 70% rename from .github/workflows/build_branch34.yml rename to .github/workflows/build_python_3.13.yml index deb6c42407970..6f67cf383584f 100644 --- a/.github/workflows/build_branch34.yml +++ b/.github/workflows/build_python_3.13.yml @@ -17,11 +17,11 @@ # under the License. # -name: "Build (branch-3.4, Scala 2.13, Hadoop 3, JDK 8)" +name: "Build / Python-only (master, Python 3.13)" on: schedule: - - cron: '0 9 * * *' + - cron: '0 20 * * *' jobs: run-build: @@ -31,21 +31,15 @@ jobs: uses: ./.github/workflows/build_and_test.yml if: github.repository == 'apache/spark' with: - java: 8 - branch: branch-3.4 + java: 17 + branch: master hadoop: hadoop3 envs: >- { - "SCALA_PROFILE": "scala2.13", - "PYTHON_TO_TEST": "", - "ORACLE_DOCKER_IMAGE_NAME": "gvenzl/oracle-xe:21.3.0" + "PYTHON_TO_TEST": "python3.13" } jobs: >- { - "build": "true", - "sparkr": "true", - "tpcds-1g": "true", - "docker-integration-tests": "true", - "k8s-integration-tests": "true", - "lint" : "true" + "pyspark": "true", + "pyspark-pandas": "true" } diff --git a/.github/workflows/build_branch34_python.yml b/.github/workflows/build_python_3.9.yml similarity index 88% rename from .github/workflows/build_branch34_python.yml rename to .github/workflows/build_python_3.9.yml index c109ba2dc7922..b2401fcf2aa14 100644 --- a/.github/workflows/build_branch34_python.yml +++ b/.github/workflows/build_python_3.9.yml @@ -17,11 +17,11 @@ # under the License. 
# -name: "Build / Python-only (branch-3.4)" +name: "Build / Python-only (master, Python 3.9)" on: schedule: - - cron: '0 9 * * *' + - cron: '0 21 * * *' jobs: run-build: @@ -31,12 +31,12 @@ jobs: uses: ./.github/workflows/build_and_test.yml if: github.repository == 'apache/spark' with: - java: 8 - branch: branch-3.4 + java: 17 + branch: master hadoop: hadoop3 envs: >- { - "PYTHON_TO_TEST": "" + "PYTHON_TO_TEST": "python3.9" } jobs: >- { diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 8458cdf771b19..471ad31279da4 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -71,7 +71,7 @@ jobs: python packaging/connect/setup.py sdist cd dist pip install pyspark*connect-*.tar.gz - pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting + pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8' - name: Run tests env: SPARK_TESTING: 1 @@ -84,7 +84,7 @@ jobs: # Start a Spark Connect server for local PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh \ --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" \ - --jars "`find connect/server/target -name spark-connect-*SNAPSHOT.jar`,`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" + --jars "`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" # Remove Py4J and PySpark zipped library to make sure there is no JVM connection mv python/lib lib.back @@ -93,7 +93,7 @@ jobs: # Several tests related to catalog requires to run them sequencially, e.g., writing a table in a listener. ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect,pyspark-ml-connect # None of tests are dependent on each other in Pandas API on Spark so run them in parallel - ./python/run-tests --parallelism=4 --python-executables=python3 --modules pyspark-pandas-connect-part0,pyspark-pandas-connect-part1,pyspark-pandas-connect-part2,pyspark-pandas-connect-part3 + ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-pandas-connect-part0,pyspark-pandas-connect-part1,pyspark-pandas-connect-part2,pyspark-pandas-connect-part3 # Stop Spark Connect server. 
./sbin/stop-connect-server.sh @@ -104,7 +104,7 @@ jobs: PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh \ --master "local-cluster[2, 4, 1024]" \ --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" \ - --jars "`find connect/server/target -name spark-connect-*SNAPSHOT.jar`,`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" + --jars "`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" # Remove Py4J and PySpark zipped library to make sure there is no JVM connection mv python/lib lib.back diff --git a/.github/workflows/build_python_connect35.yml b/.github/workflows/build_python_connect35.yml index b00fdddb4b0e0..ad250d95fb844 100644 --- a/.github/workflows/build_python_connect35.yml +++ b/.github/workflows/build_python_connect35.yml @@ -70,7 +70,7 @@ jobs: pip install 'numpy==1.25.1' 'pyarrow==12.0.1' 'pandas<=2.0.3' scipy unittest-xml-reporting plotly>=4.8 'mlflow>=2.3.1' coverage 'matplotlib==3.7.2' openpyxl 'memory-profiler==0.60.0' 'scikit-learn==1.1.*' # Add Python deps for Spark Connect. - pip install 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 'protobuf==3.20.3' 'googleapis-common-protos==1.56.4' 'graphviz==0.20.3' + pip install 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.28.3' 'googleapis-common-protos==1.65.0' 'graphviz==0.20.3' # Add torch as a testing dependency for TorchDistributor pip install 'torch==2.0.1' 'torchvision==0.15.2' torcheval @@ -87,7 +87,7 @@ jobs: # Start a Spark Connect server for local PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh \ --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" \ - --jars "`find connect/server/target -name spark-connect-*SNAPSHOT.jar`,`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" + --jars "`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" # Checkout to branch-3.5 to use the tests in branch-3.5. cd .. @@ -98,7 +98,7 @@ jobs: # Run branch-3.5 tests ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect # None of tests are dependent on each other in Pandas API on Spark so run them in parallel - ./python/run-tests --parallelism=4 --python-executables=python3 --modules pyspark-pandas-connect,pyspark-pandas-slow-connect + ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-pandas-connect,pyspark-pandas-slow-connect - name: Upload test results to report if: always() uses: actions/upload-artifact@v4 diff --git a/.github/workflows/build_python_pypy3.9.yml b/.github/workflows/build_python_pypy3.10.yml similarity index 96% rename from .github/workflows/build_python_pypy3.9.yml rename to .github/workflows/build_python_pypy3.10.yml index e05071ef034a0..163af2f4aec8b 100644 --- a/.github/workflows/build_python_pypy3.9.yml +++ b/.github/workflows/build_python_pypy3.10.yml @@ -17,7 +17,7 @@ # under the License. 
# -name: "Build / Python-only (master, PyPy 3.9)" +name: "Build / Python-only (master, PyPy 3.10)" on: schedule: diff --git a/.github/workflows/build_sparkr_window.yml b/.github/workflows/build_sparkr_window.yml index cf879d9ebd306..b97251a461715 100644 --- a/.github/workflows/build_sparkr_window.yml +++ b/.github/workflows/build_sparkr_window.yml @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # -name: "Build / SparkR-only (master, 4.4.1, windows-2022)" +name: "Build / SparkR-only (master, 4.4.2, windows-2022)" on: schedule: @@ -50,10 +50,10 @@ jobs: with: distribution: zulu java-version: 17 - - name: Install R 4.4.1 + - name: Install R 4.4.2 uses: r-lib/actions/setup-r@v2 with: - r-version: 4.4.1 + r-version: 4.4.2 - name: Install R dependencies run: | Rscript -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow', 'xml2'), repos='https://cloud.r-project.org/')" @@ -85,6 +85,7 @@ jobs: shell: cmd env: NOT_CRAN: true + SPARKR_SUPPRESS_DEPRECATION_WARNING: 1 # See SPARK-27848. Currently installing some dependent packages causes # "(converted from warning) unable to identify current timezone 'C':" for an unknown reason. # This environment variable works around to test SparkR against a higher version. diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index fa30bd3abc8a5..6965fb4968af3 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -40,7 +40,7 @@ on: description: OS to run this build. required: false type: string - default: ubuntu-22.04 + default: ubuntu-latest envs: description: Additional environment variables to set when running the tests. Should be in JSON format. required: false @@ -178,7 +178,7 @@ jobs: - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect') run: | - python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1' + python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.28.3' python3.11 -m pip list # Run the tests. 
- name: Run tests @@ -194,7 +194,7 @@ jobs: if [[ "$INCLUDED_TAGS" != "" ]]; then ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pjvm-profiler -Pspark-ganglia-lgpl -Pkinesis-asl -Djava.version=${JAVA_VERSION/-ea} -Dtest.include.tags="$INCLUDED_TAGS" test -fae elif [[ "$MODULES_TO_TEST" == "connect" ]]; then - ./build/mvn $MAVEN_CLI_OPTS -Dtest.exclude.tags="$EXCLUDED_TAGS" -Djava.version=${JAVA_VERSION/-ea} -pl connector/connect/client/jvm,connect/common,connect/server test -fae + ./build/mvn $MAVEN_CLI_OPTS -Dtest.exclude.tags="$EXCLUDED_TAGS" -Djava.version=${JAVA_VERSION/-ea} -pl connector/connect/client/jvm,sql/connect/common,sql/connect/server test -fae elif [[ "$EXCLUDED_TAGS" != "" ]]; then ./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pjvm-profiler -Pspark-ganglia-lgpl -Pkinesis-asl -Djava.version=${JAVA_VERSION/-ea} -Dtest.exclude.tags="$EXCLUDED_TAGS" test -fae elif [[ "$MODULES_TO_TEST" == *"sql#hive-thriftserver"* ]]; then diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml new file mode 100644 index 0000000000000..8729012c2b8d2 --- /dev/null +++ b/.github/workflows/pages.yml @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: GitHub Pages deployment + +on: + push: + branches: + - master + +concurrency: + group: 'docs preview' + cancel-in-progress: false + +jobs: + docs: + name: Build and deploy documentation + runs-on: ubuntu-latest + permissions: + id-token: write + pages: write + environment: + name: github-pages # https://github.com/actions/deploy-pages/issues/271 + env: + SPARK_TESTING: 1 # Reduce some noise in the logs + RELEASE_VERSION: 'In-Progress' + if: github.repository == 'apache/spark' + steps: + - name: Checkout Spark repository + uses: actions/checkout@v4 + with: + repository: apache/spark + ref: 'master' + - name: Install Java 17 + uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: 17 + - name: Install Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + architecture: x64 + cache: 'pip' + - name: Install Python dependencies + run: | + pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.3' 'plotly>=4.8' 'docutils<0.18.0' \ + 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' + - name: Install Ruby for documentation generation + uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.3' + bundler-cache: true + - name: Install Pandoc + run: | + sudo apt-get update -y + sudo apt-get install pandoc + - name: Install dependencies for documentation generation + run: | + cd docs + gem install bundler -v 2.4.22 -n /usr/local/bin + bundle install --retry=100 + - name: Run documentation build + run: | + sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml + sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py + cd docs + SKIP_RDOC=1 bundle exec jekyll build + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: 'docs/_site' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index 1b5bd0ba61288..a5854d96a4d1a 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -28,7 +28,7 @@ on: description: 'list of branches to publish (JSON)' required: true # keep in sync with default value of strategy matrix 'branch' - default: '["master", "branch-3.5", "branch-3.4"]' + default: '["master", "branch-3.5"]' jobs: publish-snapshot: @@ -38,7 +38,7 @@ jobs: fail-fast: false matrix: # keep in sync with default value of workflow_dispatch input 'branch' - branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5", "branch-3.4"]' ) }} + branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5"]' ) }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -52,13 +52,13 @@ jobs: restore-keys: | 
snapshot-maven- - name: Install Java 8 for branch-3.x - if: matrix.branch == 'branch-3.5' || matrix.branch == 'branch-3.4' + if: matrix.branch == 'branch-3.5' uses: actions/setup-java@v4 with: distribution: temurin java-version: 8 - name: Install Java 17 - if: matrix.branch != 'branch-3.5' && matrix.branch != 'branch-3.4' + if: matrix.branch != 'branch-3.5' uses: actions/setup-java@v4 with: distribution: temurin diff --git a/.github/workflows/python_macos_test.yml b/.github/workflows/python_macos_test.yml new file mode 100644 index 0000000000000..cca133dab541a --- /dev/null +++ b/.github/workflows/python_macos_test.yml @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: Build and test PySpark on macOS + +on: + workflow_call: + inputs: + java: + required: false + type: string + default: 17 + python: + required: false + type: string + default: 3.11 + branch: + description: Branch to run the build against + required: false + type: string + default: master + hadoop: + description: Hadoop version to run with. HADOOP_PROFILE environment variable should accept it. + required: false + type: string + default: hadoop3 + envs: + description: Additional environment variables to set when running the tests. Should be in JSON format. + required: false + type: string + default: '{}' +jobs: + build: + name: "PySpark test on macos: ${{ matrix.modules }}" + runs-on: macos-15 + strategy: + fail-fast: false + matrix: + java: + - ${{ inputs.java }} + python: + - ${{inputs.python}} + modules: + - >- + pyspark-sql, pyspark-resource, pyspark-testing + - >- + pyspark-core, pyspark-errors, pyspark-streaming + - >- + pyspark-mllib, pyspark-ml, pyspark-ml-connect + - >- + pyspark-connect + - >- + pyspark-pandas + - >- + pyspark-pandas-slow + - >- + pyspark-pandas-connect-part0 + - >- + pyspark-pandas-connect-part1 + - >- + pyspark-pandas-connect-part2 + - >- + pyspark-pandas-connect-part3 + env: + MODULES_TO_TEST: ${{ matrix.modules }} + PYTHON_TO_TEST: python${{inputs.python}} + HADOOP_PROFILE: ${{ inputs.hadoop }} + HIVE_PROFILE: hive2.3 + # GitHub Actions' default miniconda to use in pip packaging test. 
+ CONDA_PREFIX: /usr/share/miniconda + GITHUB_PREV_SHA: ${{ github.event.before }} + SPARK_LOCAL_IP: localhost + SKIP_UNIDOC: true + SKIP_MIMA: true + SKIP_PACKAGING: true + METASPACE_SIZE: 1g + BRANCH: ${{ inputs.branch }} + steps: + - name: Checkout Spark repository + uses: actions/checkout@v4 + # In order to fetch changed files + with: + fetch-depth: 0 + repository: apache/spark + ref: ${{ inputs.branch }} + - name: Sync the current branch with the latest in Apache Spark + if: github.repository != 'apache/spark' + run: | + echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV + git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/} + git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD + git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty + # Cache local repositories. Note that GitHub Actions cache has a 10G limit. + - name: Cache SBT and Maven + uses: actions/cache@v4 + with: + path: | + build/apache-maven-* + build/*.jar + ~/.sbt + key: build-${{ hashFiles('**/pom.xml', 'project/build.properties', 'build/mvn', 'build/sbt', 'build/sbt-launch-lib.bash', 'build/spark-build-info') }} + restore-keys: | + build- + - name: Cache Coursier local repository + uses: actions/cache@v4 + with: + path: ~/.cache/coursier + key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} + restore-keys: | + pyspark-coursier- + - name: Install Java ${{ matrix.java }} + uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: ${{ matrix.java }} + - name: Install Python packages (Python ${{matrix.python}}) + run: | + python${{matrix.python}} -m pip install --ignore-installed 'blinker>=1.6.2' + python${{matrix.python}} -m pip install --ignore-installed 'six==1.16.0' + python${{matrix.python}} -m pip install numpy 'pyarrow>=15.0.0' 'six==1.16.0' 'pandas==2.2.3' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' unittest-xml-reporting && \ + python${{matrix.python}} -m pip install 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.28.3' 'googleapis-common-protos==1.65.0' 'graphviz==0.20.3' && \ + python${{matrix.python}} -m pip cache purge && \ + python${{matrix.python}} -m pip list + # Run the tests. + - name: Run tests + env: ${{ fromJSON(inputs.envs) }} + run: | + if [[ "$MODULES_TO_TEST" == *"pyspark-errors"* ]]; then + export SKIP_PACKAGING=false + echo "Python Packaging Tests Enabled!" 
+ fi + ./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST" --python-executables "$PYTHON_TO_TEST" + - name: Upload test results to report + env: ${{ fromJSON(inputs.envs) }} + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results-${{ matrix.modules }}--${{ matrix.java }}-${{ inputs.hadoop }}-hive2.3-${{ env.PYTHON_TO_TEST }} + path: "**/target/test-reports/*.xml" + - name: Upload unit tests log files + env: ${{ fromJSON(inputs.envs) }} + if: ${{ !success() }} + uses: actions/upload-artifact@v4 + with: + name: unit-tests-log-${{ matrix.modules }}--${{ matrix.java }}-${{ inputs.hadoop }}-hive2.3-${{ env.PYTHON_TO_TEST }} + path: "**/target/unit-tests.log" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index f270673844551..e2db95083efea 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -25,6 +25,7 @@ on: jobs: stale: + if: github.repository == 'apache/spark' runs-on: ubuntu-latest steps: - uses: actions/stale@c201d45ef4b0ccbd3bb0616f93bae13e73d0a080 # pin@v1.1.0 diff --git a/.github/workflows/test_report.yml b/.github/workflows/test_report.yml index c6225e6a1abe5..9ab69af42c818 100644 --- a/.github/workflows/test_report.yml +++ b/.github/workflows/test_report.yml @@ -30,14 +30,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Download test results to report - uses: dawidd6/action-download-artifact@09385b76de790122f4da9c82b17bccf858b9557c # pin@v2 + uses: dawidd6/action-download-artifact@bf251b5aa9c2f7eeb574a96ee720e24f801b7c11 # pin @v6 with: github_token: ${{ secrets.GITHUB_TOKEN }} workflow: ${{ github.event.workflow_run.workflow_id }} commit: ${{ github.event.workflow_run.head_commit.id }} workflow_conclusion: completed - name: Publish test report - uses: scacap/action-surefire-report@482f012643ed0560e23ef605a79e8e87ca081648 # pin@v1 + uses: scacap/action-surefire-report@a2911bd1a4412ec18dde2d93b1758b3e56d2a880 # pin @v1.8.0 with: check_name: Report test results github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 787eb6180c35c..0a4138ec26948 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.swp *~ .java-version +.python-version .DS_Store .ammonite .bloop @@ -26,6 +27,7 @@ .scala_dependencies .settings .vscode +artifacts/ /lib/ R-unit-tests.log R/unit-tests.out diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/LICENSE-binary b/LICENSE-binary index b6971798e5577..40d28fbe71e6b 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -267,6 +267,7 @@ io.fabric8:kubernetes-model-scheduling io.fabric8:kubernetes-model-storageclass io.fabric8:zjsonpatch io.github.java-diff-utils:java-diff-utils +io.jsonwebtoken:jjwt-api io.netty:netty-all io.netty:netty-buffer io.netty:netty-codec @@ -401,7 +402,6 @@ org.xerial.snappy:snappy-java org.yaml:snakeyaml oro:oro stax:stax-api -xerces:xercesImpl core/src/main/java/org/apache/spark/util/collection/TimSort.java core/src/main/resources/org/apache/spark/ui/static/bootstrap* @@ -421,6 +421,11 @@ Python Software Foundation License python/pyspark/loose_version.py +BSD 0-Clause +------------ +org.tukaani:xz + + BSD 2-Clause ------------ com.github.luben:zstd-jni @@ -507,7 +512,6 @@ Eclipse Distribution License (EDL) 1.0 com.sun.istack:istack-commons-runtime jakarta.xml.bind:jakarta.xml.bind-api org.glassfish.jaxb:jaxb-runtime -org.glassfish.jaxb:txw2 Eclipse Public License (EPL) 2.0 -------------------------------- @@ -520,12 +524,6 @@ org.glassfish.hk2:hk2-locator 
org.glassfish.hk2:hk2-utils org.glassfish.hk2:osgi-resource-locator - -Public Domain -------------- -org.tukaani:xz - - Creative Commons CC0 1.0 Universal Public Domain Dedication ----------------------------------------------------------- (see LICENSE-CC0.txt) diff --git a/NOTICE-binary b/NOTICE-binary index c4cfe0e9f8b31..3f36596b9d6d6 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -448,27 +448,6 @@ which has the following notices: * Alec Wysoker * Performance and memory usage improvement -The binary distribution of this product bundles binaries of -Xerces2 Java Parser 2.9.1, -which has the following notices: - * ========================================================================= - == NOTICE file corresponding to section 4(d) of the Apache License, == - == Version 2.0, in this case for the Apache Xerces Java distribution. == - ========================================================================= - - Apache Xerces Java - Copyright 1999-2007 The Apache Software Foundation - - This product includes software developed at - The Apache Software Foundation (http://www.apache.org/). - - Portions of this software were originally based on the following: - - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. - - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. - - voluntary contributions made by Paul Eng on behalf of the - Apache Software Foundation that were originally developed at iClick, Inc., - software copyright (c) 1999. - Apache Commons Collections Copyright 2001-2015 The Apache Software Foundation diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f7dd261c10fd2..49000c62d1063 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -57,6 +57,7 @@ Collate: 'types.R' 'utils.R' 'window.R' + 'zzz.R' RoxygenNote: 7.1.2 VignetteBuilder: knitr NeedsCompilation: no diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b91124f96a6fa..9c825a99be180 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3965,19 +3965,11 @@ setMethod("row_number", #' yields unresolved \code{a.b.c} #' @return Column object wrapping JVM UnresolvedNamedLambdaVariable #' @keywords internal -unresolved_named_lambda_var <- function(...) { - jc <- newJObject( - "org.apache.spark.sql.Column", - newJObject( - "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", - lapply(list(...), function(x) { - handledCallJStatic( - "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", - "freshVarName", - x) - }) - ) - ) +unresolved_named_lambda_var <- function(name) { + jc <- handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "unresolvedNamedLambdaVariable", + name) column(jc) } @@ -3990,7 +3982,6 @@ unresolved_named_lambda_var <- function(...) 
{ #' @return JVM \code{LambdaFunction} object #' @keywords internal create_lambda <- function(fun) { - as_jexpr <- function(x) callJMethod(x@jc, "expr") # Process function arguments parameters <- formals(fun) @@ -4011,22 +4002,18 @@ create_lambda <- function(fun) { stopifnot(class(result) == "Column") # Convert both Columns to Scala expressions - jexpr <- as_jexpr(result) - jargs <- handledCallJStatic( "org.apache.spark.api.python.PythonUtils", "toSeq", - handledCallJStatic( - "java.util.Arrays", "asList", lapply(args, as_jexpr) - ) + handledCallJStatic("java.util.Arrays", "asList", lapply(args, function(x) { x@jc })) ) # Create Scala LambdaFunction - newJObject( - "org.apache.spark.sql.catalyst.expressions.LambdaFunction", - jexpr, - jargs, - FALSE + handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "lambdaFunction", + result@jc, + jargs ) } @@ -4039,20 +4026,18 @@ create_lambda <- function(fun) { #' @return a \code{Column} representing name applied to cols with funs #' @keywords internal invoke_higher_order_function <- function(name, cols, funs) { - as_jexpr <- function(x) { + as_col <- function(x) { if (class(x) == "character") { x <- column(x) } - callJMethod(x@jc, "expr") + x@jc } - - jexpr <- do.call(newJObject, c( - paste("org.apache.spark.sql.catalyst.expressions", name, sep = "."), - lapply(cols, as_jexpr), - lapply(funs, create_lambda) - )) - - column(newJObject("org.apache.spark.sql.Column", jexpr)) + jcol <- handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "fn", + name, + c(lapply(cols, as_col), lapply(funs, create_lambda))) # check varargs invocation + column(jcol) } #' @details @@ -4068,7 +4053,7 @@ setMethod("array_aggregate", signature(x = "characterOrColumn", initialValue = "Column", merge = "function"), function(x, initialValue, merge, finish = NULL) { invoke_higher_order_function( - "ArrayAggregate", + "aggregate", cols = list(x, initialValue), funs = if (is.null(finish)) { list(merge) @@ -4129,7 +4114,7 @@ setMethod("array_exists", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayExists", + "exists", cols = list(x), funs = list(f) ) @@ -4145,7 +4130,7 @@ setMethod("array_filter", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayFilter", + "filter", cols = list(x), funs = list(f) ) @@ -4161,7 +4146,7 @@ setMethod("array_forall", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayForAll", + "forall", cols = list(x), funs = list(f) ) @@ -4291,7 +4276,7 @@ setMethod("array_sort", column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc)) } else { invoke_higher_order_function( - "ArraySort", + "array_sort", cols = list(x), funs = list(comparator) ) @@ -4309,7 +4294,7 @@ setMethod("array_transform", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayTransform", + "transform", cols = list(x), funs = list(f) ) @@ -4374,7 +4359,7 @@ setMethod("arrays_zip_with", signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), function(x, y, f) { invoke_higher_order_function( - "ZipWith", + "zip_with", cols = list(x, y), funs = list(f) ) @@ -4447,7 +4432,7 @@ setMethod("map_filter", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "MapFilter", + "map_filter", cols = list(x), funs = list(f)) }) @@ -4504,7 +4489,7 @@ 
setMethod("transform_keys", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "TransformKeys", + "transform_keys", cols = list(x), funs = list(f) ) @@ -4521,7 +4506,7 @@ setMethod("transform_values", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "TransformValues", + "transform_values", cols = list(x), funs = list(f) ) @@ -4552,7 +4537,7 @@ setMethod("map_zip_with", signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), function(x, y, f) { invoke_higher_order_function( - "MapZipWith", + "map_zip_with", cols = list(x, y), funs = list(f) ) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 0be7e5da24d23..1b5faad376eaa 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -403,7 +403,6 @@ sparkR.session <- function( sparkPackages = "", enableHiveSupport = TRUE, ...) { - sparkConfigMap <- convertNamedListToEnv(sparkConfig) namedParams <- list(...) if (length(namedParams) > 0) { diff --git a/connector/docker/spark-test/base/Dockerfile b/R/pkg/R/zzz.R similarity index 64% rename from connector/docker/spark-test/base/Dockerfile rename to R/pkg/R/zzz.R index c397abc211e24..947bd543b75e0 100644 --- a/connector/docker/spark-test/base/Dockerfile +++ b/R/pkg/R/zzz.R @@ -14,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# zzz.R - package startup message -FROM ubuntu:20.04 - -# Upgrade package index -# install a few other useful packages plus Open Java 17 -# Remove unneeded /var/lib/apt/lists/* after install to reduce the -# docker image size (by ~30MB) -RUN apt-get update && \ - apt-get install -y less openjdk-17-jre-headless iproute2 vim-tiny sudo openssh-server && \ - rm -rf /var/lib/apt/lists/* - -ENV SPARK_HOME /opt/spark +.onAttach <- function(...) { + if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { + packageStartupMessage( + paste0( + "Warning: ", + "SparkR is deprecated in Apache Spark 4.0.0 and will be removed in a future release. ", + "To continue using Spark in R, we recommend using sparklyr instead: ", + "https://spark.posit.co/get-started/" + ) + ) + } +} diff --git a/R/pkg/README.md b/R/pkg/README.md index da9f042b4fded..c05a75812245c 100644 --- a/R/pkg/README.md +++ b/R/pkg/README.md @@ -1,4 +1,4 @@ -# R on Spark +# R on Spark (deprecated) SparkR is an R package that provides a light-weight frontend to use Spark from R. 
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index c44924e55087f..c93b92edbff8e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -4152,7 +4152,8 @@ test_that("catalog APIs, listTables, getTable, listColumns, listFunctions, funct c <- listColumns("cars") expect_equal(nrow(c), 2) expect_equal(colnames(c), - c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) + c("name", "description", "dataType", "nullable", "isPartition", "isBucket", + "isCluster")) expect_equal(collect(c)[[1]][[1]], "speed") expect_error(listColumns("zxwtyswklpf", "default"), "[TABLE_OR_VIEW_NOT_FOUND]*`spark_catalog`.`default`.`zxwtyswklpf`*") diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 88114f8bd82b8..74734746c129f 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -147,8 +147,7 @@ test_that("Unsupported operation", { # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), - paste0(".*(start : analysis error - Complete output mode not supported when there ", - "are no streaming aggregations on streaming DataFrames/Datasets).*")) + ".*analysis error.*complete.*not supported.*no streaming aggregations*") }) test_that("Terminated by error", { diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 4a510763afb6c..338b74110fb65 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -52,6 +52,8 @@ old_java_opt <- Sys.getenv("_JAVA_OPTIONS") Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ``` +SparkR is deprecated from Apache Spark 4.0.0 and will be removed in a future version. + ## Overview SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](https://spark.apache.org/mllib/). 
diff --git a/R/run-tests.sh b/R/run-tests.sh index 90a60eda03871..3a90b44c2b659 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -30,9 +30,9 @@ if [[ $(echo $SPARK_AVRO_JAR_PATH | wc -l) -eq 1 ]]; then fi if [ -z "$SPARK_JARS" ]; then - SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE + SPARKR_SUPPRESS_DEPRECATION_WARNING=1 SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE else - SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --jars $SPARK_JARS --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE + SPARKR_SUPPRESS_DEPRECATION_WARNING=1 SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --jars $SPARK_JARS --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE fi FAILED=$((PIPESTATUS[0]||$FAILED)) diff --git a/README.md b/README.md index b9a20075f6a17..552b71215cb92 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ rich set of higher-level tools including Spark SQL for SQL and DataFrames, pandas API on Spark for pandas workloads, MLlib for machine learning, GraphX for graph processing, and Structured Streaming for stream processing. - +- Official version: +- Development version: [![GitHub Actions Build](https://github.com/apache/spark/actions/workflows/build_main.yml/badge.svg)](https://github.com/apache/spark/actions/workflows/build_main.yml) [![PySpark Coverage](https://codecov.io/gh/apache/spark/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/spark) diff --git a/assembly/README b/assembly/README index ad1305c5b4d56..10c8254ae153e 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. 
To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=3.4.0 + -Dhadoop.version=3.4.1 diff --git a/assembly/pom.xml b/assembly/pom.xml index 9377849cf1cdc..17bb81fa023ba 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -109,11 +109,27 @@ ${project.version} provided + + + org.apache.spark + spark-connect-client-jvm_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + + provided + com.google.guava @@ -159,6 +175,44 @@ + + + org.codehaus.mojo + exec-maven-plugin + + + copy-connect-client-repl-jars + package + + exec + + + cp + + -r + ${basedir}/../connector/connect/client/jvm/target/connect-repl + ${basedir}/target/scala-${scala.binary.version}/jars/ + + + + + copy-connect-client-jar + package + + exec + + + cp + + ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${project.version}.jar + ${basedir}/target/scala-${scala.binary.version}/jars/connect-repl + + + + + @@ -291,6 +345,14 @@ + + + jjwt + + compile + + + diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a137a2fba52ee..efbd63a3037d6 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -261,18 +261,18 @@ Examples: $0 -m -t testing build - Build PySpark docker image - $0 -r docker.io/myrepo -t v3.4.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + $0 -r docker.io/myrepo -t v4.0.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build - - Build and push image with tag "v3.4.0" to docker.io/myrepo - $0 -r docker.io/myrepo -t v3.4.0 build - $0 -r docker.io/myrepo -t v3.4.0 push + - Build and push image with tag "v4.0.0" to docker.io/myrepo + $0 -r docker.io/myrepo -t v4.0.0 build + $0 -r docker.io/myrepo -t v4.0.0 push - - Build and push Java11-based image with tag "v3.4.0" to docker.io/myrepo - $0 -r docker.io/myrepo -t v3.4.0 -b java_image_tag=11-jre build - $0 -r docker.io/myrepo -t v3.4.0 push + - Build and push Java17-based image with tag "v4.0.0" to docker.io/myrepo + $0 -r docker.io/myrepo -t v4.0.0 -b java_image_tag=17 build + $0 -r docker.io/myrepo -t v4.0.0 push - Build and push image for multiple archs to docker.io/myrepo - $0 -r docker.io/myrepo -t v3.4.0 -X build + $0 -r docker.io/myrepo -t v4.0.0 -X build # Note: buildx, which does cross building, needs to do the push during build # So there is no separate push step with -X diff --git a/bin/spark-shell b/bin/spark-shell index e920137974980..8411158187260 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -34,7 +34,7 @@ fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options] -Scala REPL options: +Scala REPL options, Spark Classic only: -I preload , enforcing line-by-line interpretation" # SPARK-4161: scala does not assume use of the java classpath, @@ -45,6 +45,9 @@ Scala REPL options: SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Dscala.usejavacp=true" function main() { + export SPARK_SCALA_SHELL=1 + # In case of Spark Connect shell, the main class (and resource) is replaced in + # SparkSubmitCommandBuilder. if $cygwin; then # Workaround for issue involving JLine and Cygwin # (see http://sourceforge.net/p/jline/bugs/40/). 
diff --git a/binder/Dockerfile b/binder/Dockerfile new file mode 100644 index 0000000000000..6e3dd9155fb7a --- /dev/null +++ b/binder/Dockerfile @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM python:3.10-slim +# install the notebook package +RUN pip install --no-cache notebook jupyterlab + +# create user with a home directory +ARG NB_USER +ARG NB_UID +ENV USER ${NB_USER} +ENV HOME /home/${NB_USER} + +RUN adduser --disabled-password \ + --gecos "Default user" \ + --uid ${NB_UID} \ + ${NB_USER} +WORKDIR ${HOME} +USER ${USER} + +# Make sure the contents of our repo are in ${HOME} +COPY . ${HOME} +USER root +RUN chown -R ${NB_UID} ${HOME} +RUN apt-get update && apt-get install -y openjdk-17-jre git coreutils +USER ${NB_USER} + +RUN binder/postBuild + diff --git a/binder/apt.txt b/binder/apt.txt deleted file mode 100644 index aa441a15db3bd..0000000000000 --- a/binder/apt.txt +++ /dev/null @@ -1,2 +0,0 @@ -openjdk-17-jre -git diff --git a/binder/postBuild b/binder/postBuild index 26e23b5a940ae..0b326f62e8f51 100644 --- a/binder/postBuild +++ b/binder/postBuild @@ -26,7 +26,7 @@ set -o pipefail set -e VERSION=$(python -c "exec(open('python/pyspark/version.py').read()); print(__version__)") -TAG=$(git describe --tags --exact-match 2>/dev/null) +TAG=$(git describe --tags --exact-match 2> /dev/null || true) # If a commit is tagged, exactly specified version of pyspark should be installed to avoid # a kind of accident that an old version of pyspark is installed in the live notebook environment. diff --git a/build/mvn b/build/mvn index 3735461562e54..fef589fc03476 100755 --- a/build/mvn +++ b/build/mvn @@ -56,9 +56,9 @@ install_app() { local binary="${_DIR}/$6" local remote_tarball="${mirror_host}/${url_path}${url_query}" local local_checksum="${local_tarball}.${checksum_suffix}" - local remote_checksum="https://archive.apache.org/dist/${url_path}.${checksum_suffix}" + local remote_checksum="${mirror_host}/${url_path}.${checksum_suffix}${url_query}" - local curl_opts="--silent --show-error -L" + local curl_opts="--retry 3 --silent --show-error -L" local wget_opts="--no-verbose" if [ ! -f "$binary" ]; then @@ -115,6 +115,10 @@ function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }' # install maven under the build/ folder if needed. 
install_mvn() { local MVN_VERSION=`grep "" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn" + if [ -f "$MVN_BIN" ]; then + return + fi MVN_BIN="$(command -v mvn)" if [ "$MVN_BIN" ]; then local MVN_DETECTED_VERSION="$(mvn --version | head -n1 | awk '{print $3}')" diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 13a9d89f4705c..7f8d6c58aec7e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -255,7 +255,8 @@ public Iterator iterator() { iteratorTracker.add(new WeakReference<>(it)); return it; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } }; diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index 69757fdc65d68..29ed37ffa44e5 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -127,7 +127,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; @@ -151,7 +151,8 @@ public T next() { next = null; return ret; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java index dc7ad0be5c007..4bc2b233fe12d 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java @@ -287,7 +287,8 @@ public Iterator iterator() { iteratorTracker.add(new WeakReference<>(it)); return it; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } }; diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java index a98b0482e35cc..e350ddc2d445a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java @@ -113,7 +113,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; @@ -137,7 +137,8 @@ public T next() { next = null; return ret; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index cdb5bd72158a1..cbe4836b58da5 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -194,6 +194,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java 
b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 4c144a73a9299..a9df47645d36f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -290,9 +290,11 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); } catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); + Throwables.throwIfUnchecked(e.getCause()); + throw new RuntimeException(e.getCause()); } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index e1f19f956cc0a..d64b8c8f838e9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -342,7 +342,8 @@ public void operationComplete(final Future handshakeFuture) { logger.error("Exception while bootstrapping client after {} ms", e, MDC.of(LogKeys.BOOTSTRAP_TIME$.MODULE$, bootstrapTimeMs)); client.close(); - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } long postBootstrap = System.nanoTime(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 08e2c084fe67b..2e9ccd0e0ad21 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -22,7 +22,6 @@ import java.security.GeneralSecurityException; import java.util.concurrent.TimeoutException; -import com.google.common.base.Throwables; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -80,7 +79,7 @@ public void doBootstrap(TransportClient client, Channel channel) { doSparkAuth(client, channel); client.setClientId(appId); } catch (GeneralSecurityException | IOException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } catch (RuntimeException e) { // There isn't a good exception that can be caught here to know whether it's really // OK to switch back to SASL (because the server doesn't speak the new protocol). 
So diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 65367743e24f9..087e3d21e22bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -132,7 +132,8 @@ protected boolean doAuthChallenge( try { engine.close(); } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 355c552720185..33494aee4444d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -17,32 +17,12 @@ package org.apache.spark.network.crypto; -import com.google.common.annotations.VisibleForTesting; -import com.google.crypto.tink.subtle.Hex; -import com.google.crypto.tink.subtle.Hkdf; import io.netty.channel.Channel; -import javax.crypto.spec.SecretKeySpec; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; interface TransportCipher { String getKeyId() throws GeneralSecurityException; void addToChannel(Channel channel) throws IOException, GeneralSecurityException; } - -class TransportCipherUtil { - /* - * This method is used for testing to verify key derivation. - */ - @VisibleForTesting - static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { - byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", - key.getEncoded(), - null, - "keyID".getBytes(StandardCharsets.UTF_8), - 32); - return Hex.encode(keyIdBytes); - } -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipherUtil.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipherUtil.java new file mode 100644 index 0000000000000..1df2732f240cc --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipherUtil.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import javax.crypto.spec.SecretKeySpec; + +import com.google.common.annotations.VisibleForTesting; +import com.google.crypto.tink.subtle.Hex; +import com.google.crypto.tink.subtle.Hkdf; + +class TransportCipherUtil { + /** + * This method is used for testing to verify key derivation. + */ + @VisibleForTesting + static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { + byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", + key.getEncoded(), + null, + "keyID".getBytes(StandardCharsets.UTF_8), + 32); + return Hex.encode(keyIdBytes); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 3600c1045dbf4..a61b1c3c0c416 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -29,7 +29,6 @@ import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.apache.spark.internal.SparkLogger; @@ -62,7 +61,7 @@ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, bool this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, saslProps, new ClientCallbackHandler()); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -72,7 +71,7 @@ public synchronized byte[] firstToken() { try { return saslClient.evaluateChallenge(new byte[0]); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } else { return new byte[0]; @@ -98,7 +97,7 @@ public synchronized byte[] response(byte[] token) { try { return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index b897650afe832..f32fd5145c7c5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -31,7 +31,6 @@ import java.util.Map; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -94,7 +93,7 @@ public SparkSaslServer( this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, new DigestCallbackHandler()); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -119,7 +118,7 @@ public synchronized byte[] response(byte[] token) { try { return saslServer != null ? 
saslServer.evaluateResponse(token) : new byte[0]; } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java index 5796e34a6f05e..2ac549775449a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffledb; -import com.google.common.base.Throwables; - import java.io.IOException; import java.util.Map; import java.util.NoSuchElementException; @@ -47,7 +45,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java index d33895d6c2d62..2737ab8ed754c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java @@ -19,7 +19,6 @@ import java.io.IOException; -import com.google.common.base.Throwables; import org.rocksdb.RocksDBException; /** @@ -37,7 +36,7 @@ public void put(byte[] key, byte[] value) { try { db.put(key, value); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -46,7 +45,7 @@ public byte[] get(byte[] key) { try { return db.get(key); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -55,7 +54,7 @@ public void delete(byte[] key) { try { db.delete(key); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java index 78562f91a4b75..829a7ded6330b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java @@ -22,7 +22,6 @@ import java.util.Map; import java.util.NoSuchElementException; -import com.google.common.base.Throwables; import org.rocksdb.RocksIterator; /** @@ -52,7 +51,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 0f7036ef746cc..49e6e08476151 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -113,6 +113,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + commons-io commons-io diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index f9c0c60c2f2c6..62fcda701d948 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -19,10 +19,7 @@ import java.io.*; import java.util.concurrent.TimeUnit; -import java.util.zip.Adler32; -import java.util.zip.CRC32; -import java.util.zip.CheckedInputStream; -import java.util.zip.Checksum; +import java.util.zip.*; import com.google.common.io.ByteStreams; @@ -66,6 +63,13 @@ private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) { } } + case "CRC32C" -> { + checksums = new CRC32C[num]; + for (int i = 0; i < num; i++) { + checksums[i] = new CRC32C(); + } + } + default -> throw new UnsupportedOperationException( "Unsupported shuffle checksum algorithm: " + algorithm); } diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a5ef9847859a7..cf15301273303 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -104,6 +104,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index b9868ca665a65..97c8bbe562aff 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.util; import com.ibm.icu.lang.UCharacter; +import com.ibm.icu.lang.UProperty; import com.ibm.icu.text.BreakIterator; import com.ibm.icu.text.Collator; import com.ibm.icu.text.RuleBasedCollator; @@ -26,15 +27,15 @@ import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.Platform.copyMemory; import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType; import java.text.CharacterIterator; import java.text.StringCharacterIterator; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Map; /** @@ -48,19 +49,28 @@ public class CollationAwareUTF8String { */ private static final int MATCH_NOT_FOUND = -1; + /** + * `COMBINED_ASCII_SMALL_I_COMBINING_DOT` is an internal representation of the combined + * lowercase code point for ASCII lowercase letter i with an additional combining dot character + * (U+0307). This integer value is not a valid code point itself, but rather an artificial code + * point marker used to represent the two lowercase characters that are the result of converting + * the uppercase Turkish dotted letter I with a combining dot character (U+0130) to lowercase. + */ + private static final int COMBINED_ASCII_SMALL_I_COMBINING_DOT = + SpecialCodePointConstants.ASCII_SMALL_I << 16 | SpecialCodePointConstants.COMBINING_DOT; + /** * Returns whether the target string starts with the specified prefix, starting from the * specified position (0-based index referring to character position in UTF8String), with respect - * to the UTF8_LCASE collation. The method assumes that the prefix is already lowercased - * prior to method call to avoid the overhead of calling .toLowerCase() multiple times on the - * same prefix string. + * to the UTF8_LCASE collation. 
The method assumes that the prefix is already lowercased prior + * to method call to avoid the overhead of lowercasing the same prefix string multiple times. * * @param target the string to be searched in * @param lowercasePattern the string to be searched for * @param startPos the start position for searching (in the target string) * @return whether the target string starts with the specified prefix in UTF8_LCASE */ - public static boolean lowercaseMatchFrom( + private static boolean lowercaseMatchFrom( final UTF8String target, final UTF8String lowercasePattern, int startPos) { @@ -86,12 +96,44 @@ private static int lowercaseMatchLengthFrom( final UTF8String lowercasePattern, int startPos) { assert startPos >= 0; - for (int len = 0; len <= target.numChars() - startPos; ++len) { - if (target.substring(startPos, startPos + len).toLowerCase().equals(lowercasePattern)) { - return len; + // Use code point iterators for efficient string search. + Iterator targetIterator = target.codePointIterator(); + Iterator patternIterator = lowercasePattern.codePointIterator(); + // Skip to startPos in the target string. + for (int i = 0; i < startPos; ++i) { + if (targetIterator.hasNext()) { + targetIterator.next(); + } else { + return MATCH_NOT_FOUND; } } - return MATCH_NOT_FOUND; + // Compare the characters in the target and pattern strings. + int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { + if (codePointBuffer != -1) { + targetCodePoint = codePointBuffer; + codePointBuffer = -1; + } else { + // Use buffered lowercase code point iteration to handle one-to-many case mappings. + targetCodePoint = getLowercaseCodePoint(targetIterator.next()); + if (targetCodePoint == COMBINED_ASCII_SMALL_I_COMBINING_DOT) { + targetCodePoint = SpecialCodePointConstants.ASCII_SMALL_I; + codePointBuffer = SpecialCodePointConstants.COMBINING_DOT; + } + ++matchLength; + } + patternCodePoint = patternIterator.next(); + if (targetCodePoint != patternCodePoint) { + return MATCH_NOT_FOUND; + } + } + // If the pattern string has more characters, or the match is found at the middle of a + // character that maps to multiple characters in lowercase, then match is not found. + if (patternIterator.hasNext() || codePointBuffer != -1) { + return MATCH_NOT_FOUND; + } + // If all characters are equal, return the length of the match in the target string. + return matchLength; } /** @@ -123,15 +165,14 @@ private static int lowercaseFind( * Returns whether the target string ends with the specified suffix, ending at the specified * position (0-based index referring to character position in UTF8String), with respect to the * UTF8_LCASE collation. The method assumes that the suffix is already lowercased prior - * to method call to avoid the overhead of calling .toLowerCase() multiple times on the same - * suffix string. + * to method call to avoid the overhead of lowercasing the same suffix string multiple times. 
* * @param target the string to be searched in * @param lowercasePattern the string to be searched for * @param endPos the end position for searching (in the target string) * @return whether the target string ends with the specified suffix in lowercase */ - public static boolean lowercaseMatchUntil( + private static boolean lowercaseMatchUntil( final UTF8String target, final UTF8String lowercasePattern, int endPos) { @@ -156,13 +197,45 @@ private static int lowercaseMatchLengthUntil( final UTF8String target, final UTF8String lowercasePattern, int endPos) { - assert endPos <= target.numChars(); - for (int len = 0; len <= endPos; ++len) { - if (target.substring(endPos - len, endPos).toLowerCase().equals(lowercasePattern)) { - return len; + assert endPos >= 0; + // Use code point iterators for efficient string search. + Iterator targetIterator = target.reverseCodePointIterator(); + Iterator patternIterator = lowercasePattern.reverseCodePointIterator(); + // Skip to startPos in the target string. + for (int i = endPos; i < target.numChars(); ++i) { + if (targetIterator.hasNext()) { + targetIterator.next(); + } else { + return MATCH_NOT_FOUND; } } - return MATCH_NOT_FOUND; + // Compare the characters in the target and pattern strings. + int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { + if (codePointBuffer != -1) { + targetCodePoint = codePointBuffer; + codePointBuffer = -1; + } else { + // Use buffered lowercase code point iteration to handle one-to-many case mappings. + targetCodePoint = getLowercaseCodePoint(targetIterator.next()); + if (targetCodePoint == COMBINED_ASCII_SMALL_I_COMBINING_DOT) { + targetCodePoint = SpecialCodePointConstants.COMBINING_DOT; + codePointBuffer = SpecialCodePointConstants.ASCII_SMALL_I; + } + ++matchLength; + } + patternCodePoint = patternIterator.next(); + if (targetCodePoint != patternCodePoint) { + return MATCH_NOT_FOUND; + } + } + // If the pattern string has more characters, or the match is found at the middle of a + // character that maps to multiple characters in lowercase, then match is not found. + if (patternIterator.hasNext() || codePointBuffer != -1) { + return MATCH_NOT_FOUND; + } + // If all characters are equal, return the length of the match in the target string. + return matchLength; } /** @@ -191,10 +264,9 @@ private static int lowercaseRFind( } /** - * Lowercase UTF8String comparison used for UTF8_LCASE collation. While the default - * UTF8String comparison is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase()), this - * method uses code points to compare the strings in a case-insensitive manner using ICU rules, - * as well as handling special rules for one-to-many case mappings (see: lowerCaseCodePoints). + * Lowercase UTF8String comparison used for UTF8_LCASE collation. This method uses lowercased + * code points to compare the strings in a case-insensitive manner using ICU rules, taking into + * account special rules for one-to-many case mappings (see: lowerCaseCodePoints). * * @param left The first UTF8String to compare. * @param right The second UTF8String to compare. 
@@ -238,115 +310,82 @@ private static int compareLowerCaseSlow(final UTF8String left, final UTF8String return lowerCaseCodePoints(left).binaryCompare(lowerCaseCodePoints(right)); } - /* + /** * Performs string replacement for ICU collations by searching for instances of the search - * string in the `src` string, with respect to the specified collation, and then replacing + * string in the `target` string, with respect to the specified collation, and then replacing * them with the replace string. The method returns a new UTF8String with all instances of the * search string replaced using the replace string. Similar to UTF8String.findInSet behavior - * used for UTF8_BINARY, the method returns the `src` string if the `search` string is empty. + * used for UTF8_BINARY, the method returns the `target` string if the `search` string is empty. * - * @param src the string to be searched in + * @param target the string to be searched in * @param search the string to be searched for * @param replace the string to be used as replacement * @param collationId the collation ID to use for string search * @return the position of the first occurrence of `match` in `set` */ - public static UTF8String replace(final UTF8String src, final UTF8String search, + public static UTF8String replace(final UTF8String target, final UTF8String search, final UTF8String replace, final int collationId) { // This collation aware implementation is based on existing implementation on UTF8String - if (src.numBytes() == 0 || search.numBytes() == 0) { - return src; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); - - // Find the first occurrence of the search string. - int end = stringSearch.next(); - if (end == StringSearch.DONE) { - // Search string was not found, so string is unchanged. - return src; + if (target.numBytes() == 0 || search.numBytes() == 0) { + return target; } - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } + String targetStr = target.toValidString(); + String searchStr = search.toValidString(); + StringSearch stringSearch = CollationFactory.getStringSearch(targetStr, searchStr, collationId); - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. 
- int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); - while (end != StringSearch.DONE) { - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); - buf.append(replace); - - // Move byteStart to the beginning of the current match - byteStart = byteEnd; - int cs = c; - // Move cs to the end of the current match - // This is necessary because the search string may contain 'multi-character' characters - while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { - byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); - cs += 1; - } - // Go to next match - end = stringSearch.next(); - // Update byte positions - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } + StringBuilder sb = new StringBuilder(); + int start = 0; + int matchStart = stringSearch.first(); + while (matchStart != StringSearch.DONE) { + sb.append(targetStr, start, matchStart); + sb.append(replace.toValidString()); + start = matchStart + stringSearch.getMatchLength(); + matchStart = stringSearch.next(); } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, - src.numBytes() - byteStart); - return buf.build(); + sb.append(targetStr, start, targetStr.length()); + return UTF8String.fromString(sb.toString()); } - /* + /** * Performs string replacement for UTF8_LCASE collation by searching for instances of the search - * string in the src string, with respect to lowercased string versions, and then replacing + * string in the target string, with respect to lowercased string versions, and then replacing * them with the replace string. The method returns a new UTF8String with all instances of the * search string replaced using the replace string. Similar to UTF8String.findInSet behavior - * used for UTF8_BINARY, the method returns the `src` string if the `search` string is empty. + * used for UTF8_BINARY, the method returns the `target` string if the `search` string is empty. * - * @param src the string to be searched in + * @param target the string to be searched in * @param search the string to be searched for * @param replace the string to be used as replacement - * @param collationId the collation ID to use for string search * @return the position of the first occurrence of `match` in `set` */ - public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, + public static UTF8String lowercaseReplace(final UTF8String target, final UTF8String search, final UTF8String replace) { - if (src.numBytes() == 0 || search.numBytes() == 0) { - return src; + if (target.numBytes() == 0 || search.numBytes() == 0) { + return target; } - // TODO(SPARK-48725): Use lowerCaseCodePoints instead of UTF8String.toLowerCase. - UTF8String lowercaseSearch = search.toLowerCase(); + UTF8String lowercaseSearch = lowerCaseCodePoints(search); int start = 0; - int end = lowercaseFind(src, lowercaseSearch, start); + int end = lowercaseFind(target, lowercaseSearch, start); if (end == -1) { // Search string was not found, so string is unchanged. - return src; + return target; } // At least one match was found. Estimate space needed for result. // The 16x multiplier here is chosen to match commons-lang3's implementation. 
int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + final UTF8StringBuilder buf = new UTF8StringBuilder(target.numBytes() + increase); while (end != -1) { - buf.append(src.substring(start, end)); + buf.append(target.substring(start, end)); buf.append(replace); // Update character positions - start = end + lowercaseMatchLengthFrom(src, lowercaseSearch, end); - end = lowercaseFind(src, lowercaseSearch, start); + start = end + lowercaseMatchLengthFrom(target, lowercaseSearch, end); + end = lowercaseFind(target, lowercaseSearch, start); } - buf.append(src.substring(start, src.numChars())); + buf.append(target.substring(start, target.numChars())); return buf.build(); } @@ -433,28 +472,16 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col */ private static void appendLowercaseCodePoint(final int codePoint, final StringBuilder sb) { int lowercaseCodePoint = getLowercaseCodePoint(codePoint); - if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) { + if (lowercaseCodePoint == COMBINED_ASCII_SMALL_I_COMBINING_DOT) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. - sb.appendCodePoint(0x0069); - sb.appendCodePoint(0x0307); + sb.appendCodePoint(SpecialCodePointConstants.ASCII_SMALL_I); + sb.appendCodePoint(SpecialCodePointConstants.COMBINING_DOT); } else { // All other characters should follow context-unaware ICU single-code point case mapping. sb.appendCodePoint(lowercaseCodePoint); } } - /** - * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal representation of the combined lowercase - * code point for ASCII lowercase letter i with an additional combining dot character (U+0307). - * This integer value is not a valid code point itself, but rather an artificial code point - * marker used to represent the two lowercase characters that are the result of converting the - * uppercase Turkish dotted letter I with a combining dot character (U+0130) to lowercase. - */ - private static final int CODE_POINT_LOWERCASE_I = 0x69; - private static final int CODE_POINT_COMBINING_DOT = 0x307; - private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT = - CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT; - /** * Returns the lowercase version of the provided code point, with special handling for * one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and @@ -462,15 +489,15 @@ private static void appendLowercaseCodePoint(final int codePoint, final StringBu * the position in the string relative to other characters in lowercase). */ private static int getLowercaseCodePoint(final int codePoint) { - if (codePoint == 0x0130) { + if (codePoint == SpecialCodePointConstants.CAPITAL_I_WITH_DOT_ABOVE) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. - return CODE_POINT_COMBINED_LOWERCASE_I_DOT; + return COMBINED_ASCII_SMALL_I_COMBINING_DOT; } - else if (codePoint == 0x03C2) { + else if (codePoint == SpecialCodePointConstants.GREEK_FINAL_SIGMA) { // Greek final and non-final letter sigma should be mapped the same. This is achieved by // mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in the `else` branch. 
- return 0x03C3; + return SpecialCodePointConstants.GREEK_SMALL_SIGMA; } else { // All other characters should follow context-unaware ICU single-code point case mapping. @@ -522,6 +549,152 @@ public static UTF8String toTitleCase(final UTF8String target, final int collatio BreakIterator.getWordInstance(locale))); } + /** + * This 'HashMap' is introduced as a performance speedup. Since title-casing a codepoint can + * result in more than a single codepoint, for correctness, we would use + * 'UCharacter.toTitleCase(String)' which returns a 'String'. If we use + * 'UCharacter.toTitleCase(int)' (the version of the same function which converts a single + * codepoint to its title-case codepoint), it would be faster than the previously mentioned + * version, but the problem here is that we don't handle when title-casing a codepoint yields more + * than 1 codepoint. Since there are only 48 codepoints that are mapped to more than 1 codepoint + * when title-cased, they are precalculated here, so that the faster function for title-casing + * could be used in combination with this 'HashMap' in the method 'appendCodepointToTitleCase'. + */ + private static final HashMap codepointOneToManyTitleCaseLookupTable = + new HashMap<>(){{ + StringBuilder sb = new StringBuilder(); + for (int i = Character.MIN_CODE_POINT; i <= Character.MAX_CODE_POINT; ++i) { + sb.appendCodePoint(i); + String titleCase = UCharacter.toTitleCase(sb.toString(), null); + if (titleCase.codePointCount(0, titleCase.length()) > 1) { + put(i, titleCase); + } + sb.setLength(0); + } + }}; + + /** + * Title-casing a string using ICU case mappings. Iterates over the string and title-cases + * the first character in each word, and lowercases every other character. Handles lowercasing + * capital Greek letter sigma ('Σ') separately, taking into account if it should be a small final + * Greek sigma ('ς') or small non-final Greek sigma ('σ'). Words are separated by ASCII + * space(\u0020). + * + * @param source UTF8String to be title cased + * @return title cased source + */ + public static UTF8String toTitleCaseICU(UTF8String source) { + // In the default UTF8String implementation, `toLowerCase` method implicitly does UTF8String + // validation (replacing invalid UTF-8 byte sequences with Unicode replacement character + // U+FFFD), but now we have to do the validation manually. + source = source.makeValid(); + + // Building the title cased source with 'sb'. + UTF8StringBuilder sb = new UTF8StringBuilder(); + + // 'isNewWord' is true if the current character is the beginning of a word, false otherwise. + boolean isNewWord = true; + // We are maintaining if the current character is preceded by a cased letter. + // This is used when lowercasing capital Greek letter sigma ('Σ'), to figure out if it should be + // lowercased into σ or ς. + boolean precededByCasedLetter = false; + + // 'offset' is a byte offset in source's byte array pointing to the beginning of the character + // that we need to process next. + int offset = 0; + int len = source.numBytes(); + + while (offset < len) { + // We will actually call 'codePointFrom()' 2 times for each character in the worst case (once + // here, and once in 'followedByCasedLetter'). Example of a string where we call it 2 times + // for almost every character is 'ΣΣΣΣΣ' (a string consisting only of Greek capital sigma) + // and 'Σ`````' (a string consisting of a Greek capital sigma, followed by case-ignorable + // characters). 
+ int codepoint = source.codePointFrom(offset); + // Appending the correctly cased character onto 'sb'. + appendTitleCasedCodepoint(sb, codepoint, isNewWord, precededByCasedLetter, source, offset); + // Updating 'isNewWord', 'precededByCasedLetter' and 'offset' to be ready for the next + // character that we will process. + isNewWord = (codepoint == SpecialCodePointConstants.ASCII_SPACE); + if (!UCharacter.hasBinaryProperty(codepoint, UProperty.CASE_IGNORABLE)) { + precededByCasedLetter = UCharacter.hasBinaryProperty(codepoint, UProperty.CASED); + } + offset += UTF8String.numBytesForFirstByte(source.getByte(offset)); + } + return sb.build(); + } + + private static void appendTitleCasedCodepoint( + UTF8StringBuilder sb, + int codepoint, + boolean isAfterAsciiSpace, + boolean precededByCasedLetter, + UTF8String source, + int offset) { + if (isAfterAsciiSpace) { + // Title-casing a character if it is in the beginning of a new word. + appendCodepointToTitleCase(sb, codepoint); + return; + } + if (codepoint == SpecialCodePointConstants.GREEK_CAPITAL_SIGMA) { + // Handling capital Greek letter sigma ('Σ'). + appendLowerCasedGreekCapitalSigma(sb, precededByCasedLetter, source, offset); + return; + } + // If it's not the beginning of a word, or a capital Greek letter sigma ('Σ'), we lowercase the + // character. We specially handle 'CAPITAL_I_WITH_DOT_ABOVE'. + if (codepoint == SpecialCodePointConstants.CAPITAL_I_WITH_DOT_ABOVE) { + sb.appendCodePoint(SpecialCodePointConstants.ASCII_SMALL_I); + sb.appendCodePoint(SpecialCodePointConstants.COMBINING_DOT); + return; + } + sb.appendCodePoint(UCharacter.toLowerCase(codepoint)); + } + + private static void appendLowerCasedGreekCapitalSigma( + UTF8StringBuilder sb, + boolean precededByCasedLetter, + UTF8String source, + int offset) { + int codepoint = (!followedByCasedLetter(source, offset) && precededByCasedLetter) + ? SpecialCodePointConstants.GREEK_FINAL_SIGMA + : SpecialCodePointConstants.GREEK_SMALL_SIGMA; + sb.appendCodePoint(codepoint); + } + + /** + * Checks if the character beginning at 'offset'(in 'sources' byte array) is followed by a cased + * letter. + */ + private static boolean followedByCasedLetter(UTF8String source, int offset) { + // Moving the offset one character forward, so we could start the linear search from there. + offset += UTF8String.numBytesForFirstByte(source.getByte(offset)); + int len = source.numBytes(); + + while (offset < len) { + int codepoint = source.codePointFrom(offset); + + if (UCharacter.hasBinaryProperty(codepoint, UProperty.CASE_IGNORABLE)) { + offset += UTF8String.numBytesForFirstByte(source.getByte(offset)); + continue; + } + return UCharacter.hasBinaryProperty(codepoint, UProperty.CASED); + } + return false; + } + + /** + * Appends title-case of a single character to a 'StringBuilder' using the ICU root locale rules. + */ + private static void appendCodepointToTitleCase(UTF8StringBuilder sb, int codepoint) { + String toTitleCase = codepointOneToManyTitleCaseLookupTable.get(codepoint); + if (toTitleCase == null) { + sb.appendCodePoint(UCharacter.toTitleCase(codepoint)); + } else { + sb.append(toTitleCase); + } + } + /* * Returns the position of the first occurrence of the match string in the set string, * counting ASCII commas as delimiters. 
The match string is compared in a collation-aware manner, @@ -559,6 +732,58 @@ public static int findInSet(final UTF8String match, final UTF8String set, int co return 0; } + /** + * Checks whether the target string contains the pattern string, with respect to the UTF8_LCASE + * collation. This method generally works with respect to code-point based comparison logic. + * + * @param target the string to be searched in + * @param pattern the string to be searched for + * @return whether the target string contains the pattern string + */ + public static boolean lowercaseContains(final UTF8String target, final UTF8String pattern) { + // Fast path for ASCII-only strings. + if (target.isFullAscii() && pattern.isFullAscii()) { + return target.toLowerCase().contains(pattern.toLowerCase()); + } + // Slow path for non-ASCII strings. + return CollationAwareUTF8String.lowercaseIndexOfSlow(target, pattern, 0) >= 0; + } + + /** + * Checks whether the target string starts with the pattern string, with respect to the UTF8_LCASE + * collation. This method generally works with respect to code-point based comparison logic. + * + * @param target the string to be searched in + * @param pattern the string to be searched for + * @return whether the target string starts with the pattern string + */ + public static boolean lowercaseStartsWith(final UTF8String target, final UTF8String pattern) { + // Fast path for ASCII-only strings. + if (target.isFullAscii() && pattern.isFullAscii()) { + return target.toLowerCase().startsWith(pattern.toLowerCase()); + } + // Slow path for non-ASCII strings. + return CollationAwareUTF8String.lowercaseMatchFrom(target, lowerCaseCodePointsSlow(pattern), 0); + } + + /** + * Checks whether the target string ends with the pattern string, with respect to the UTF8_LCASE + * collation. This method generally works with respect to code-point based comparison logic. + * + * @param target the string to be searched in + * @param pattern the string to be searched for + * @return whether the target string ends with the pattern string + */ + public static boolean lowercaseEndsWith(final UTF8String target, final UTF8String pattern) { + // Fast path for ASCII-only strings. + if (target.isFullAscii() && pattern.isFullAscii()) { + return target.toLowerCase().endsWith(pattern.toLowerCase()); + } + // Slow path for non-ASCII strings. 
+ return CollationAwareUTF8String.lowercaseMatchUntil(target, lowerCaseCodePointsSlow(pattern), + target.numChars()); + } + /** * Returns the position of the first occurrence of the pattern string in the target string, * starting from the specified position (0-based index referring to character position in @@ -573,30 +798,76 @@ public static int findInSet(final UTF8String match, final UTF8String set, int co public static int lowercaseIndexOf(final UTF8String target, final UTF8String pattern, final int start) { if (pattern.numChars() == 0) return target.indexOfEmpty(start); - return lowercaseFind(target, pattern.toLowerCase(), start); + if (target.isFullAscii() && pattern.isFullAscii()) { + return target.toLowerCase().indexOf(pattern.toLowerCase(), start); + } + return lowercaseIndexOfSlow(target, pattern, start); + } + + private static int lowercaseIndexOfSlow(final UTF8String target, final UTF8String pattern, + final int start) { + return lowercaseFind(target, lowerCaseCodePoints(pattern), start); } public static int indexOf(final UTF8String target, final UTF8String pattern, final int start, final int collationId) { if (pattern.numBytes() == 0) return target.indexOfEmpty(start); if (target.numBytes() == 0) return MATCH_NOT_FOUND; - - StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); - stringSearch.setIndex(start); - - return stringSearch.next(); + // Initialize the string search with respect to the specified ICU collation. + String targetStr = target.toValidString(); + String patternStr = pattern.toValidString(); + // Check if `start` is out of bounds. The provided offset `start` is given in number of + // codepoints, so a simple `targetStr.length` check is not sufficient here. This check is + // needed because `String.offsetByCodePoints` throws an `IndexOutOfBoundsException` + // exception when the offset is out of bounds. + if (targetStr.codePointCount(0, targetStr.length()) <= start) return MATCH_NOT_FOUND; + StringSearch stringSearch = + CollationFactory.getStringSearch(targetStr, patternStr, collationId); + stringSearch.setOverlapping(true); + // Start the search from `start`-th code point (NOT necessarily from the `start`-th character). + int startIndex = targetStr.offsetByCodePoints(0, start); + stringSearch.setIndex(startIndex); + // Perform the search and return the next result, starting from the specified position. + int searchIndex = stringSearch.next(); + if (searchIndex == StringSearch.DONE) { + return MATCH_NOT_FOUND; + } + // Convert the search index from character count to code point count. 
+ int indexOf = targetStr.codePointCount(0, searchIndex); + if (indexOf < start) { + return MATCH_NOT_FOUND; + } + return indexOf; } - public static int find(UTF8String target, UTF8String pattern, int start, - int collationId) { - assert (pattern.numBytes() > 0); - - StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); - // Set search start position (start from character at start position) - stringSearch.setIndex(target.bytePosToChar(start)); + private static int findIndex(final StringSearch stringSearch, int count) { + assert(count >= 0); + int index = 0; + while (count > 0) { + int nextIndex = stringSearch.next(); + if (nextIndex == StringSearch.DONE) { + return MATCH_NOT_FOUND; + } else if (nextIndex == index && index != 0) { + stringSearch.setIndex(stringSearch.getIndex() + stringSearch.getMatchLength()); + } else { + count--; + index = nextIndex; + } + } + return index; + } - // Return either the byte position or -1 if not found - return target.charPosToByte(stringSearch.next()); + private static int findIndexReverse(final StringSearch stringSearch, int count) { + assert(count >= 0); + int index = 0; + while (count > 0) { + index = stringSearch.previous(); + if (index == StringSearch.DONE) { + return MATCH_NOT_FOUND; + } + count--; + } + return index + stringSearch.getMatchLength(); } public static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter, @@ -604,63 +875,30 @@ public static UTF8String subStringIndex(final UTF8String string, final UTF8Strin if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) { return UTF8String.EMPTY_UTF8; } + String str = string.toValidString(); + String delim = delimiter.toValidString(); + StringSearch stringSearch = CollationFactory.getStringSearch(str, delim, collationId); + stringSearch.setOverlapping(true); if (count > 0) { - int idx = -1; - while (count > 0) { - idx = find(string, delimiter, idx + 1, collationId); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx == 0) { + // If the count is positive, we search for the count-th delimiter from the left. + int searchIndex = findIndex(stringSearch, count); + if (searchIndex == MATCH_NOT_FOUND) { + return string; + } else if (searchIndex == 0) { return UTF8String.EMPTY_UTF8; + } else { + return UTF8String.fromString(str.substring(0, searchIndex)); } - byte[] bytes = new byte[idx]; - copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); - return UTF8String.fromBytes(bytes); - } else { - count = -count; - - StringSearch stringSearch = CollationFactory - .getStringSearch(string, delimiter, collationId); - - int start = string.numChars() - 1; - int lastMatchLength = 0; - int prevStart = -1; - while (count > 0) { - stringSearch.reset(); - prevStart = -1; - int matchStart = stringSearch.next(); - lastMatchLength = stringSearch.getMatchLength(); - while (matchStart <= start) { - if (matchStart != StringSearch.DONE) { - // Found a match, update the start position - prevStart = matchStart; - matchStart = stringSearch.next(); - } else { - break; - } - } - - if (prevStart == -1) { - // can not find enough delim + // If the count is negative, we search for the count-th delimiter from the right. 
+ int searchIndex = findIndexReverse(stringSearch, -count); + if (searchIndex == MATCH_NOT_FOUND) { return string; - } else { - start = prevStart - 1; - count--; - } - } - - int resultStart = prevStart + lastMatchLength; - if (resultStart == string.numChars()) { - return UTF8String.EMPTY_UTF8; + } else if (searchIndex == str.length()) { + return UTF8String.EMPTY_UTF8; + } else { + return UTF8String.fromString(str.substring(searchIndex)); } - - return string.substring(resultStart, string.numChars()); } } @@ -670,7 +908,7 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string, return UTF8String.EMPTY_UTF8; } - UTF8String lowercaseDelimiter = delimiter.toLowerCase(); + UTF8String lowercaseDelimiter = lowerCaseCodePoints(delimiter); if (count > 0) { // Search left to right (note: the start code point is inclusive). @@ -750,11 +988,11 @@ public static UTF8String lowercaseTranslate(final UTF8String input, } // Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring // that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match. - if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) && - codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) { + if (lowercaseDict.containsKey(COMBINED_ASCII_SMALL_I_COMBINING_DOT) + && codePoint == SpecialCodePointConstants.ASCII_SMALL_I && inputIter.hasNext()) { int nextCodePoint = inputIter.next(); - if (nextCodePoint == CODE_POINT_COMBINING_DOT) { - codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT; + if (nextCodePoint == SpecialCodePointConstants.COMBINING_DOT) { + codePoint = COMBINED_ASCII_SMALL_I_COMBINING_DOT; } else { codePointBuffer = nextCodePoint; } @@ -842,6 +1080,24 @@ public static UTF8String translate(final UTF8String input, return UTF8String.fromString(sb.toString()); } + /** + * Trims the `srcString` string from both ends of the string using the specified `trimString` + * characters, with respect to the UTF8_BINARY trim collation. String trimming is performed by + * first trimming the left side of the string, and then trimming the right side of the string. + * The method returns the trimmed string. If the `trimString` is null, the method returns null. + * + * @param srcString the input string to be trimmed from both ends of the string + * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim + * @return the trimmed string (for UTF8_BINARY collation) + */ + public static UTF8String binaryTrim( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + return binaryTrimRight(srcString.trimLeft(trimString), trimString, collationId); + } + /** * Trims the `srcString` string from both ends of the string using the specified `trimString` * characters, with respect to the UTF8_LCASE collation. 
String trimming is performed by @@ -850,12 +1106,14 @@ public static UTF8String translate(final UTF8String input, * * @param srcString the input string to be trimmed from both ends of the string * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim * @return the trimmed string (for UTF8_LCASE collation) */ public static UTF8String lowercaseTrim( final UTF8String srcString, - final UTF8String trimString) { - return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString); + final UTF8String trimString, + final int collationId) { + return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString, collationId); } /** @@ -883,7 +1141,8 @@ public static UTF8String trim( * the left side, until reaching a character whose lowercased code point is not in the hash set. * Finally, the method returns the substring from that position to the end of `srcString`. * If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned. - * + * Note: as currently only trimming collation supported is RTRIM, trimLeft is not modified + * to support other trim collations, this should be done in case of adding TRIM and LTRIM. * @param srcString the input string to be trimmed from the left end of the string * @param trimString the trim string characters to trim * @return the trimmed string (for UTF8_LCASE collation) @@ -902,20 +1161,29 @@ public static UTF8String lowercaseTrimLeft( while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next())); // Iterate over `srcString` from the left to find the first character that is not in the set. - int searchIndex = 0, codePoint; + int searchIndex = 0, codePoint, codePointBuffer = -1; Iterator srcIter = srcString.codePointIterator(); while (srcIter.hasNext()) { - codePoint = getLowercaseCodePoint(srcIter.next()); + // Get the next code point from either the buffer or the iterator. + if (codePointBuffer != -1) { + codePoint = codePointBuffer; + codePointBuffer = -1; + } + else { + codePoint = getLowercaseCodePoint(srcIter.next()); + } // Special handling for Turkish dotted uppercase letter I. - if (codePoint == CODE_POINT_LOWERCASE_I && srcIter.hasNext() && - trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) { - int nextCodePoint = getLowercaseCodePoint(srcIter.next()); - if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint)) - || nextCodePoint == CODE_POINT_COMBINING_DOT) { + if (codePoint == SpecialCodePointConstants.ASCII_SMALL_I && srcIter.hasNext() && + trimChars.contains(COMBINED_ASCII_SMALL_I_COMBINING_DOT)) { + codePointBuffer = codePoint; + codePoint = getLowercaseCodePoint(srcIter.next()); + if (codePoint == SpecialCodePointConstants.COMBINING_DOT) { searchIndex += 2; - } - else { - if (trimChars.contains(codePoint)) ++searchIndex; + codePointBuffer = -1; + } else if (trimChars.contains(codePointBuffer)) { + ++searchIndex; + codePointBuffer = codePoint; + } else { break; } } else if (trimChars.contains(codePoint)) { @@ -937,7 +1205,9 @@ public static UTF8String lowercaseTrimLeft( * character in `trimString`, until reaching a character that is not found in `trimString`. * Finally, the method returns the substring from that position to the end of `srcString`. * If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned. 
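// A minimal, standalone sketch of the case-insensitive trim-left idea used by
// lowercaseTrimLeft above, written with plain java.lang.String for clarity: lowercase the
// trim characters into a set of code points, then advance past matching lowercased source
// code points. The special handling for the Turkish dotted letter I in the patch is omitted
// here; class and method names are illustrative, not Spark's.
import java.util.HashSet;
import java.util.Set;

class LowercaseTrimLeftSketch {
  static String lowercaseTrimLeft(String src, String trim) {
    Set<Integer> trimChars = new HashSet<>();
    trim.toLowerCase().codePoints().forEach(trimChars::add);
    int i = 0;
    while (i < src.length()) {
      int codePoint = src.codePointAt(i);
      if (!trimChars.contains(Character.toLowerCase(codePoint))) break;
      i += Character.charCount(codePoint);
    }
    return src.substring(i);
  }

  public static void main(String[] args) {
    System.out.println(lowercaseTrimLeft("XXhello", "x")); // prints: hello
  }
}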
- * + * Note: as currently only trimming collation supported is RTRIM, trimLeft is not modified + * to support other trim collations, this should be done in case of adding TRIM and LTRIM + * collation. * @param srcString the input string to be trimmed from the left end of the string * @param trimString the trim string characters to trim * @param collationId the collation ID to use for string trimming @@ -957,7 +1227,7 @@ public static UTF8String trimLeft( CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID); while (trimIter.hasNext()) { int codePoint = trimIter.next(); - trimChars.putIfAbsent(codePoint, String.valueOf((char) codePoint)); + trimChars.putIfAbsent(codePoint, new String(Character.toChars(codePoint))); } // Iterate over srcString from the left and find the first character that is not in trimChars. @@ -985,22 +1255,103 @@ public static UTF8String trimLeft( // Return the substring from the calculated position until the end of the string. return UTF8String.fromString(src.substring(charIndex)); } + /** + * Trims the `srcString` string from the right side using the specified `trimString` characters, + * with respect to the UTF8_BINARY trim collation. For UTF8_BINARY trim collation, the method has + * one special case to cover with respect to trimRight function for regular UTF8_Binary collation. + * Trailing spaces should be ignored in case of trim collation (rtrim for example) and if + * trimString does not contain spaces. In this case, the method trims the string from the right + * and after call of regular trim functions returns back trimmed spaces as those should not get + * removed. + * @param srcString the input string to be trimmed from the right end of the string + * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim + * @return the trimmed string (for UTF_BINARY collation) + */ + public static UTF8String binaryTrimRight( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + // Matching the default UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + // Create a hash set of code points for all characters of `trimString`. + HashSet trimChars = new HashSet<>(); + Iterator trimIter = trimString.codePointIterator(); + while (trimIter.hasNext()) trimChars.add(trimIter.next()); + + // Iterate over `srcString` from the right to find the first character that is not in the set. + int searchIndex = srcString.numChars(), codePoint, codePointBuffer = -1; + + // In cases of trim collation (rtrim for example) trailing spaces should be ignored. + // If trimString contains spaces this behaviour is not important as they would get trimmed + // anyway. However, if it is not the case they should be ignored and then appended after + // trimming other characters. + int lastNonSpaceByteIdx = srcString.numBytes(), lastNonSpaceCharacterIdx = srcString.numChars(); + if (!trimChars.contains(SpecialCodePointConstants.ASCII_SPACE) && + CollationFactory.ignoresSpacesInTrimFunctions( + collationId, /*isLTrim=*/ false, /*isRTrim=*/true)) { + while (lastNonSpaceByteIdx > 0 && + srcString.getByte(lastNonSpaceByteIdx - 1) == ' ') { + --lastNonSpaceByteIdx; + } + // In case of src string contains only spaces there is no need to do any trimming, since it's + // already checked that trim string does not contain any spaces. 
+ if (lastNonSpaceByteIdx == 0) { + return srcString; + } + searchIndex = lastNonSpaceCharacterIdx = + srcString.numChars() - (srcString.numBytes() - lastNonSpaceByteIdx); + } + Iterator srcIter = srcString.reverseCodePointIterator(); + for (int i = lastNonSpaceCharacterIdx; i < srcString.numChars(); i++) { + srcIter.next(); + } + + while (srcIter.hasNext()) { + codePoint = srcIter.next(); + if (trimChars.contains(codePoint)) { + --searchIndex; + } + else { + break; + } + } + + // Return the substring from the start of the string to the calculated position and append + // trailing spaces if they were ignored + if (searchIndex == srcString.numChars()) { + return srcString; + } + if (lastNonSpaceCharacterIdx == srcString.numChars()) { + return srcString.substring(0, searchIndex); + } + return UTF8String.concat( + srcString.substring(0, searchIndex), + srcString.substring(lastNonSpaceCharacterIdx, srcString.numChars())); + } /** * Trims the `srcString` string from the right side using the specified `trimString` characters, * with respect to the UTF8_LCASE collation. For UTF8_LCASE, the method first creates a hash * set of lowercased code points in `trimString`, and then iterates over the `srcString` from * the right side, until reaching a character whose lowercased code point is not in the hash set. + * In case of UTF8_LCASE trim collation and when trimString does not contain spaces, trailing + * spaces should be ignored. However, after trimming function call they should be appended back. * Finally, the method returns the substring from the start of `srcString` until that position. * If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned. * * @param srcString the input string to be trimmed from the right end of the string * @param trimString the trim string characters to trim + * @param collationId the collation ID to use for string trim * @return the trimmed string (for UTF8_LCASE collation) */ public static UTF8String lowercaseTrimRight( final UTF8String srcString, - final UTF8String trimString) { + final UTF8String trimString, + final int collationId) { // Matching the default UTF8String behavior for null `trimString`. if (trimString == null) { return null; @@ -1012,20 +1363,53 @@ public static UTF8String lowercaseTrimRight( while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next())); // Iterate over `srcString` from the right to find the first character that is not in the set. - int searchIndex = srcString.numChars(), codePoint; + int searchIndex = srcString.numChars(), codePoint, codePointBuffer = -1; + + // In cases of trim collation (rtrim for example) trailing spaces should be ignored. + // If trimString contains spaces this behaviour is not important as they would get trimmed + // anyway. However, if it is not the case they should be ignored and then appended after + // trimming other characters. + int lastNonSpaceByteIdx = srcString.numBytes(), lastNonSpaceCharacterIdx = srcString.numChars(); + if (!trimChars.contains(SpecialCodePointConstants.ASCII_SPACE) && + CollationFactory.ignoresSpacesInTrimFunctions( + collationId, /*isLTrim=*/ false, /*isRTrim=*/true)) { + while (lastNonSpaceByteIdx > 0 && + srcString.getByte(lastNonSpaceByteIdx - 1) == ' ') { + --lastNonSpaceByteIdx; + } + // In case of src string contains only spaces there is no need to do any trimming, since it's + // already checked that trim string does not contain any spaces. 
+ if (lastNonSpaceByteIdx == 0) { + return srcString; + } + searchIndex = lastNonSpaceCharacterIdx = + srcString.numChars() - (srcString.numBytes() - lastNonSpaceByteIdx); + } Iterator srcIter = srcString.reverseCodePointIterator(); + for (int i = lastNonSpaceCharacterIdx; i < srcString.numChars(); i++) { + srcIter.next(); + } + while (srcIter.hasNext()) { - codePoint = getLowercaseCodePoint(srcIter.next()); + if (codePointBuffer != -1) { + codePoint = codePointBuffer; + codePointBuffer = -1; + } + else { + codePoint = getLowercaseCodePoint(srcIter.next()); + } // Special handling for Turkish dotted uppercase letter I. - if (codePoint == CODE_POINT_COMBINING_DOT && srcIter.hasNext() && - trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) { - int nextCodePoint = getLowercaseCodePoint(srcIter.next()); - if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint)) - || nextCodePoint == CODE_POINT_LOWERCASE_I) { + if (codePoint == SpecialCodePointConstants.COMBINING_DOT && srcIter.hasNext() && + trimChars.contains(COMBINED_ASCII_SMALL_I_COMBINING_DOT)) { + codePointBuffer = codePoint; + codePoint = getLowercaseCodePoint(srcIter.next()); + if (codePoint == SpecialCodePointConstants.ASCII_SMALL_I) { searchIndex -= 2; - } - else { - if (trimChars.contains(codePoint)) --searchIndex; + codePointBuffer = -1; + } else if (trimChars.contains(codePointBuffer)) { + --searchIndex; + codePointBuffer = codePoint; + } else { break; } } else if (trimChars.contains(codePoint)) { @@ -1036,8 +1420,17 @@ public static UTF8String lowercaseTrimRight( } } - // Return the substring from the start of the string to the calculated position. - return searchIndex == srcString.numChars() ? srcString : srcString.substring(0, searchIndex); + // Return the substring from the start of the string to the calculated position and append + // trailing spaces if they were ignored + if (searchIndex == srcString.numChars()) { + return srcString; + } + if (lastNonSpaceCharacterIdx == srcString.numChars()) { + return srcString.substring(0, searchIndex); + } + return UTF8String.concat( + srcString.substring(0, searchIndex), + srcString.substring(lastNonSpaceCharacterIdx, srcString.numChars())); } /** @@ -1067,14 +1460,33 @@ public static UTF8String trimRight( CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID); while (trimIter.hasNext()) { int codePoint = trimIter.next(); - trimChars.putIfAbsent(codePoint, String.valueOf((char) codePoint)); + trimChars.putIfAbsent(codePoint, new String(Character.toChars(codePoint))); } // Iterate over srcString from the left and find the first character that is not in trimChars. String src = srcString.toValidString(); CharacterIterator target = new StringCharacterIterator(src); Collator collator = CollationFactory.fetchCollation(collationId).collator; - int charIndex = src.length(), longestMatchLen; + int charIndex = src.length(), longestMatchLen, lastNonSpacePosition = src.length(); + + // In cases of trim collation (rtrim for example) trailing spaces should be ignored. + // If trimString contains spaces this behaviour is not important as they would get trimmed + // anyway. However, if it is not the case they should be ignored and then appended after + // trimming other characters. 
+ if (!trimChars.containsKey(SpecialCodePointConstants.ASCII_SPACE) && + CollationFactory.ignoresSpacesInTrimFunctions( + collationId, /*isLTrim=*/ false, /*isRTrim=*/true)) { + while (lastNonSpacePosition > 0 && src.charAt(lastNonSpacePosition - 1) == ' ') { + --lastNonSpacePosition; + } + // In case of src string contains only spaces there is no need to do any trimming, since it's + // already checked that trim string does not contain any spaces. + if (lastNonSpacePosition == 0) { + return UTF8String.fromString(src); + } + charIndex = lastNonSpacePosition; + } + while (charIndex >= 0) { longestMatchLen = 0; for (String trim : trimChars.values()) { @@ -1102,8 +1514,91 @@ public static UTF8String trimRight( else charIndex -= longestMatchLen; } - // Return the substring from the start of the string until that position. - return UTF8String.fromString(src.substring(0, charIndex)); + // Return the substring from the start of the string until that position and append + // trailing spaces if they were ignored + if (charIndex == src.length()) { + return srcString; + } + if (lastNonSpacePosition == srcString.numChars()) { + return UTF8String.fromString(src.substring(0, charIndex)); + } + return UTF8String.fromString( + src.substring(0, charIndex) + + src.substring(lastNonSpacePosition) + ); + } + + public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, + final int limit, final int collationId) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { + return input.split(delim, limit); + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { + return lowercaseSplitSQL(input, delim, limit); + } else { + return icuSplitSQL(input, delim, limit, collationId); + } + } + + public static UTF8String[] lowercaseSplitSQL(final UTF8String string, final UTF8String delimiter, + final int limit) { + if (delimiter.numBytes() == 0) return new UTF8String[] { string }; + if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 }; + + List strings = new ArrayList<>(); + UTF8String lowercaseDelimiter = lowerCaseCodePoints(delimiter); + int startIndex = 0, nextMatch = 0, nextMatchLength; + while (nextMatch != MATCH_NOT_FOUND) { + if (limit > 0 && strings.size() == limit - 1) { + break; + } + nextMatch = lowercaseFind(string, lowercaseDelimiter, startIndex); + if (nextMatch != MATCH_NOT_FOUND) { + nextMatchLength = lowercaseMatchLengthFrom(string, lowercaseDelimiter, nextMatch); + strings.add(string.substring(startIndex, nextMatch)); + startIndex = nextMatch + nextMatchLength; + } + } + if (startIndex <= string.numChars()) { + strings.add(string.substring(startIndex, string.numChars())); + } + if (limit == 0) { + // Remove trailing empty strings + int i = strings.size() - 1; + while (i >= 0 && strings.get(i).numBytes() == 0) { + strings.remove(i); + i--; + } + } + return strings.toArray(new UTF8String[0]); + } + + public static UTF8String[] icuSplitSQL(final UTF8String string, final UTF8String delimiter, + final int limit, final int collationId) { + if (delimiter.numBytes() == 0) return new UTF8String[] { string }; + if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 }; + List strings = new ArrayList<>(); + String target = string.toValidString(), pattern = delimiter.toValidString(); + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + int start = 0, end; + while ((end = stringSearch.next()) != StringSearch.DONE) { + if (limit > 0 && strings.size() == limit - 1) { + 
break; + } + strings.add(UTF8String.fromString(target.substring(start, end))); + start = end + stringSearch.getMatchLength(); + } + if (start <= target.length()) { + strings.add(UTF8String.fromString(target.substring(start))); + } + if (limit == 0) { + // Remove trailing empty strings + int i = strings.size() - 1; + while (i >= 0 && strings.get(i).numBytes() == 0) { + strings.remove(i); + i--; + } + } + return strings.toArray(new UTF8String[0]); } // TODO: Add more collation-aware UTF8String operations here. diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index f13f66e384e0f..4064f830e92d8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -23,12 +23,14 @@ import java.util.function.Function; import java.util.function.BiFunction; import java.util.function.ToLongFunction; +import java.util.stream.Stream; +import com.ibm.icu.text.CollationKey; +import com.ibm.icu.text.Collator; import com.ibm.icu.text.RuleBasedCollator; import com.ibm.icu.text.StringSearch; import com.ibm.icu.util.ULocale; -import com.ibm.icu.text.CollationKey; -import com.ibm.icu.text.Collator; +import com.ibm.icu.util.VersionInfo; import org.apache.spark.SparkException; import org.apache.spark.unsafe.types.UTF8String; @@ -88,6 +90,18 @@ public Optional getVersion() { } } + public record CollationMeta( + String catalog, + String schema, + String collationName, + String language, + String country, + String icuVersion, + String padAttribute, + boolean accentSensitivity, + boolean caseSensitivity, + String spaceTrimming) { } + /** * Entry encapsulating all information about a collation. */ @@ -99,7 +113,8 @@ public static class Collation { /** * Version of the collation. This is the version of the ICU library Collator. - * For non-ICU collations (e.g. UTF8 Binary) the version is set to "1.0". + * For UTF8 Binary the version is set to "1.0". For ICU collations and UTF8_LCASE + * (because it uses ICU mappings) the version is set to the version of the ICU library. * When using ICU Collator this version is exposed through collator.getVersion(). * Whenever the collation is updated, the version should be updated as well or kept * for backwards compatibility. @@ -133,13 +148,31 @@ public static class Collation { public final boolean supportsBinaryOrdering; /** - * Support for Lowercase Equality implies that it is possible to check equality on - * byte by byte level, but only after calling "UTF8String.toLowerCase" on both arguments. + * Support for Lowercase Equality implies that it is possible to check equality on byte by + * byte level, but only after calling "UTF8String.lowerCaseCodePoints" on both arguments. * This allows custom collation support for UTF8_LCASE collation in various Spark * expressions, as this particular collation is not supported by the external ICU library. */ public final boolean supportsLowercaseEquality; + /** + * Support for Space Trimming implies that that based on specifier (for now only right trim) + * leading, trailing or both spaces are removed from the input string before comparison. + */ + public final boolean supportsSpaceTrimming; + + /** + * Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only + * collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type. 
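// A hedged sketch of the split-loop structure shared by lowercaseSplitSQL and icuSplitSQL
// above: find the next delimiter match, cut a piece, advance by the match length, and for
// limit == 0 drop trailing empty strings. Plain String.indexOf stands in for the
// collation-aware find; the class is illustrative, not Spark code.
import java.util.ArrayList;
import java.util.List;

class CollationSplitSketch {
  static List<String> split(String s, String delim, int limit) {
    List<String> parts = new ArrayList<>();
    int start = 0, next;
    while ((next = s.indexOf(delim, start)) != -1) {
      if (limit > 0 && parts.size() == limit - 1) break;
      parts.add(s.substring(start, next));
      start = next + delim.length(); // advance by the matched delimiter length
    }
    parts.add(s.substring(start));
    if (limit == 0) {
      // Remove trailing empty strings, as both split implementations do.
      int i = parts.size() - 1;
      while (i >= 0 && parts.get(i).isEmpty()) parts.remove(i--);
    }
    return parts;
  }

  public static void main(String[] args) {
    System.out.println(split("a,b,,", ",", 0)); // [a, b]
    System.out.println(split("a,b,,", ",", 2)); // [a, b,,]
  }
}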
+ */ + public final boolean isUtf8BinaryType; + + /** + * Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only + * collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type. + */ + public final boolean isUtf8LcaseType; + public Collation( String collationName, String provider, @@ -147,31 +180,27 @@ public Collation( Comparator comparator, String version, ToLongFunction hashFunction, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { + BiFunction equalsFunction, + boolean isUtf8BinaryType, + boolean isUtf8LcaseType, + boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; this.collator = collator; this.comparator = comparator; this.version = version; this.hashFunction = hashFunction; - this.supportsBinaryEquality = supportsBinaryEquality; - this.supportsBinaryOrdering = supportsBinaryOrdering; - this.supportsLowercaseEquality = supportsLowercaseEquality; - - // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality - assert(!supportsBinaryOrdering || supportsBinaryEquality); + this.isUtf8BinaryType = isUtf8BinaryType; + this.isUtf8LcaseType = isUtf8LcaseType; + this.equalsFunction = equalsFunction; + this.supportsSpaceTrimming = supportsSpaceTrimming; + this.supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality assert(!supportsBinaryEquality || !supportsLowercaseEquality); assert(SUPPORTED_PROVIDERS.contains(provider)); - - if (supportsBinaryEquality) { - this.equalsFunction = UTF8String::equals; - } else { - this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0; - } } /** @@ -186,7 +215,8 @@ public Collation( * bit 29: 0 for UTF8_BINARY, 1 for ICU collations. * bit 28-24: Reserved. * bit 23-22: Reserved for version. - * bit 21-18: Reserved for space trimming. + * bit 21-19 Zeros, reserved for future trimmings. + * bit 18 0 = none, 1 = right trim. * bit 17-0: Depend on collation family. * --- * INDETERMINATE collation ID binary layout: @@ -201,7 +231,8 @@ public Collation( * UTF8_BINARY collation ID binary layout: * bit 31-24: Zeroes. * bit 23-22: Zeroes, reserved for version. - * bit 21-18: Zeroes, reserved for space trimming. + * bit 21-19 Zeros, reserved for future trimmings. + * bit 18 0 = none, 1 = right trim. * bit 17-3: Zeroes. * bit 2: 0, reserved for accent sensitivity. * bit 1: 0, reserved for uppercase and case-insensitive. @@ -212,7 +243,8 @@ public Collation( * bit 29: 1 * bit 28-24: Zeroes. * bit 23-22: Zeroes, reserved for version. - * bit 21-18: Zeroes, reserved for space trimming. + * bit 21-18: Reserved for space trimming. + * 0000 = none, 0001 = right trim. Bits 21-19 remain reserved and fixed to 0. * bit 17: 0 = case-sensitive, 1 = case-insensitive. * bit 16: 0 = accent-sensitive, 1 = accent-insensitive. * bit 15-14: Zeroes, reserved for punctuation sensitivity. @@ -220,14 +252,20 @@ public Collation( * bit 11-0: Locale ID as specified in `ICULocaleToId` mapping. 
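// A minimal sketch (standalone, not the Spark Collation class) of how the constructor above
// derives the capability flags: binary equality/ordering and lowercase equality only hold
// when the base type matches and no space trimming is configured.
class CollationFlagsSketch {
  static void describe(boolean isUtf8BinaryType, boolean isUtf8LcaseType,
      boolean supportsSpaceTrimming) {
    boolean supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType;
    boolean supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType;
    boolean supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType;
    System.out.printf("binaryEq=%b binaryOrd=%b lcaseEq=%b%n",
        supportsBinaryEquality, supportsBinaryOrdering, supportsLowercaseEquality);
  }

  public static void main(String[] args) {
    describe(true, false, false); // UTF8_BINARY: all binary fast paths available
    describe(true, false, true);  // UTF8_BINARY_RTRIM: space trimming disables them
    describe(false, true, false); // UTF8_LCASE: lowercase equality fast path only
  }
}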
* --- * Some illustrative examples of collation name to ID mapping: - * - UTF8_BINARY -> 0 - * - UTF8_LCASE -> 1 - * - UNICODE -> 0x20000000 - * - UNICODE_AI -> 0x20010000 - * - UNICODE_CI -> 0x20020000 - * - UNICODE_CI_AI -> 0x20030000 - * - af -> 0x20000001 - * - af_CI_AI -> 0x20030001 + * - UTF8_BINARY -> 0 + * - UTF8_BINARY_RTRIM -> 0x00040000 + * - UTF8_LCASE -> 1 + * - UTF8_LCASE_RTRIM -> 0x00040001 + * - UNICODE -> 0x20000000 + * - UNICODE_AI -> 0x20010000 + * - UNICODE_CI -> 0x20020000 + * - UNICODE_RTRIM -> 0x20040000 + * - UNICODE_CI_AI -> 0x20030000 + * - UNICODE_CI_RTRIM -> 0x20060000 + * - UNICODE_AI_RTRIM -> 0x20050000 + * - UNICODE_CI_AI_RTRIM-> 0x20070000 + * - af -> 0x20000001 + * - af_CI_AI -> 0x20030001 */ private abstract static class CollationSpec { @@ -246,6 +284,14 @@ protected enum ImplementationProvider { UTF8_BINARY, ICU } + /** + * Bit 18 in collation ID having value 0 for none and 1 for right trimming. + * Bits 21, 20, 19 remained reserved (and fixed to 0) for future use. + */ + protected enum SpaceTrimming { + NONE, RTRIM + } + /** * Offset in binary collation ID layout. */ @@ -266,6 +312,17 @@ protected enum ImplementationProvider { */ protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1; + + /** + * Offset in binary collation ID layout. + */ + protected static final int SPACE_TRIMMING_OFFSET = 18; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + protected static final int SPACE_TRIMMING_MASK = 0b1; + private static final int INDETERMINATE_COLLATION_ID = -1; /** @@ -290,6 +347,45 @@ private static DefinitionOrigin getDefinitionOrigin(int collationId) { DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)]; } + /** + * Utility function to retrieve `SpaceTrimming` enum instance from collation ID. + */ + protected static SpaceTrimming getSpaceTrimming(int collationId) { + return SpaceTrimming.values()[SpecifierUtils.getSpecValue(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK)]; + } + + protected static UTF8String applyTrimmingPolicy(UTF8String s, int collationId) { + return applyTrimmingPolicy(s, getSpaceTrimming(collationId)); + } + + /** + * Returns if leading/trailing spaces should be ignored in trim string expressions. This is + * needed because space trimming collation directly changes behaviour of trim functions. + */ + protected static boolean ignoresSpacesInTrimFunctions( + int collationId, + boolean isLTrim, + boolean isRTrim) { + if (isRTrim && getSpaceTrimming(collationId) == SpaceTrimming.RTRIM) { + return true; + } + + // In case of adding new trimming collations in the future (LTRIM and TRIM) here logic + // should be added. + return false; + } + + /** + * Utility function to trim spaces when collation uses space trimming. + */ + protected static UTF8String applyTrimmingPolicy(UTF8String s, SpaceTrimming spaceTrimming) { + if(spaceTrimming == SpaceTrimming.RTRIM){ + return s.trimRight(); + } + return s; // No trimming. + } + /** * Main entry point for retrieving `Collation` instance from collation ID. 
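// A hedged sketch of the bit-18 space-trimming flag from the layout documented above. The
// constant mirrors SPACE_TRIMMING_OFFSET in the patch; the class itself is illustrative and
// simply reproduces the example ID mappings listed in the comment.
class SpaceTrimmingBitSketch {
  static final int SPACE_TRIMMING_OFFSET = 18; // bit 18: 0 = none, 1 = right trim

  static int withRtrim(int collationId) {
    return collationId | (1 << SPACE_TRIMMING_OFFSET);
  }

  public static void main(String[] args) {
    System.out.printf("0x%08X%n", withRtrim(0));          // 0x00040000 = UTF8_BINARY_RTRIM
    System.out.printf("0x%08X%n", withRtrim(1));          // 0x00040001 = UTF8_LCASE_RTRIM
    System.out.printf("0x%08X%n", withRtrim(0x20020000)); // 0x20060000 = UNICODE_CI_RTRIM
  }
}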
*/ @@ -342,6 +438,25 @@ private static int collationNameToId(String collationName) throws SparkException } protected abstract Collation buildCollation(); + + protected abstract CollationMeta buildCollationMeta(); + + protected abstract String normalizedCollationName(); + + static List listCollations() { + return Stream.concat( + CollationSpecUTF8.listCollations().stream(), + CollationSpecICU.listCollations().stream()).toList(); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + CollationMeta collationSpecUTF8 = + CollationSpecUTF8.loadCollationMeta(collationIdentifier); + if (collationSpecUTF8 == null) { + return CollationSpecICU.loadCollationMeta(collationIdentifier); + } + return collationSpecUTF8; + } } private static class CollationSpecUTF8 extends CollationSpec { @@ -364,68 +479,227 @@ private enum CaseSensitivity { */ private static final int CASE_SENSITIVITY_MASK = 0b1; + private static final String UTF8_BINARY_COLLATION_NAME = "UTF8_BINARY"; + private static final String UTF8_LCASE_COLLATION_NAME = "UTF8_LCASE"; + private static final int UTF8_BINARY_COLLATION_ID = - new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId; + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED, SpaceTrimming.NONE).collationId; private static final int UTF8_LCASE_COLLATION_ID = - new CollationSpecUTF8(CaseSensitivity.LCASE).collationId; + new CollationSpecUTF8(CaseSensitivity.LCASE, SpaceTrimming.NONE).collationId; protected static Collation UTF8_BINARY_COLLATION = - new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED, SpaceTrimming.NONE).buildCollation(); protected static Collation UTF8_LCASE_COLLATION = - new CollationSpecUTF8(CaseSensitivity.LCASE).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.LCASE, SpaceTrimming.NONE).buildCollation(); + private final CaseSensitivity caseSensitivity; + private final SpaceTrimming spaceTrimming; private final int collationId; - private CollationSpecUTF8(CaseSensitivity caseSensitivity) { - this.collationId = + private CollationSpecUTF8( + CaseSensitivity caseSensitivity, + SpaceTrimming spaceTrimming) { + this.caseSensitivity = caseSensitivity; + this.spaceTrimming = spaceTrimming; + + int collationId = SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, caseSensitivity); + this.collationId = + SpecifierUtils.setSpecValue(collationId, SPACE_TRIMMING_OFFSET, spaceTrimming); } private static int collationNameToId(String originalName, String collationName) throws SparkException { - if (UTF8_BINARY_COLLATION.collationName.equals(collationName)) { - return UTF8_BINARY_COLLATION_ID; - } else if (UTF8_LCASE_COLLATION.collationName.equals(collationName)) { - return UTF8_LCASE_COLLATION_ID; + + int baseId; + String collationNamePrefix; + + if (collationName.startsWith(UTF8_BINARY_COLLATION.collationName)) { + baseId = UTF8_BINARY_COLLATION_ID; + collationNamePrefix = UTF8_BINARY_COLLATION.collationName; + } else if (collationName.startsWith(UTF8_LCASE_COLLATION.collationName)) { + baseId = UTF8_LCASE_COLLATION_ID; + collationNamePrefix = UTF8_LCASE_COLLATION.collationName; } else { // Throw exception with original (before case conversion) collation name. 
throw collationInvalidNameException(originalName); } + + String remainingSpecifiers = collationName.substring(collationNamePrefix.length()); + if(remainingSpecifiers.isEmpty()) { + return baseId; + } + if(!remainingSpecifiers.startsWith("_")){ + throw collationInvalidNameException(originalName); + } + + SpaceTrimming spaceTrimming = SpaceTrimming.NONE; + String remainingSpec = remainingSpecifiers.substring(1); + if (remainingSpec.equals("RTRIM")) { + spaceTrimming = SpaceTrimming.RTRIM; + } else { + throw collationInvalidNameException(originalName); + } + + return SpecifierUtils.setSpecValue(baseId, SPACE_TRIMMING_OFFSET, spaceTrimming); } private static CollationSpecUTF8 fromCollationId(int collationId) { // Extract case sensitivity from collation ID. int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); - // Verify only case sensitivity bits were set settable in UTF8_BINARY family of collations. - assert (SpecifierUtils.removeSpec(collationId, - CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0); - return new CollationSpecUTF8(CaseSensitivity.values()[caseConversionOrdinal]); + // Extract space trimming from collation ID. + int spaceTrimmingOrdinal = getSpaceTrimming(collationId).ordinal(); + assert(isValidCollationId(collationId)); + return new CollationSpecUTF8( + CaseSensitivity.values()[caseConversionOrdinal], + SpaceTrimming.values()[spaceTrimmingOrdinal]); + } + + private static boolean isValidCollationId(int collationId) { + collationId = SpecifierUtils.removeSpec( + collationId, + SPACE_TRIMMING_OFFSET, + SPACE_TRIMMING_MASK); + collationId = SpecifierUtils.removeSpec( + collationId, + CASE_SENSITIVITY_OFFSET, + CASE_SENSITIVITY_MASK); + return collationId == 0; } @Override protected Collation buildCollation() { - if (collationId == UTF8_BINARY_COLLATION_ID) { + if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { + Comparator comparator; + ToLongFunction hashFunction; + BiFunction equalsFunction; + boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = UTF8String::binaryCompare; + hashFunction = s -> (long) s.hashCode(); + equalsFunction = UTF8String::equals; + } else { + comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); + equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( + applyTrimmingPolicy(s2, spaceTrimming)); + } + return new Collation( - "UTF8_BINARY", + normalizedCollationName(), PROVIDER_SPARK, null, - UTF8String::binaryCompare, - "1.0", - s -> (long) s.hashCode(), - /* supportsBinaryEquality = */ true, - /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false); + comparator, + CollationSpecICU.ICU_VERSION, + hashFunction, + equalsFunction, + /* isUtf8BinaryType = */ true, + /* isUtf8LcaseType = */ false, + spaceTrimming != SpaceTrimming.NONE); } else { + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = CollationAwareUTF8String::compareLowerCase; + hashFunction = s -> + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + } else { + comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) 
CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + } + return new Collation( - "UTF8_LCASE", + normalizedCollationName(), PROVIDER_SPARK, null, - CollationAwareUTF8String::compareLowerCase, - "1.0", - s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(), - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true); + comparator, + CollationSpecICU.ICU_VERSION, + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ true, + spaceTrimming != SpaceTrimming.NONE); + } + } + + @Override + protected CollationMeta buildCollationMeta() { + if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { + return new CollationMeta( + CATALOG, + SCHEMA, + normalizedCollationName(), + /* language = */ null, + /* country = */ null, + /* icuVersion = */ null, + COLLATION_PAD_ATTRIBUTE, + /* accentSensitivity = */ true, + /* caseSensitivity = */ true, + spaceTrimming.toString()); + } else { + return new CollationMeta( + CATALOG, + SCHEMA, + normalizedCollationName(), + /* language = */ null, + /* country = */ null, + /* icuVersion = */ null, + COLLATION_PAD_ATTRIBUTE, + /* accentSensitivity = */ true, + /* caseSensitivity = */ false, + spaceTrimming.toString()); + } + } + + /** + * Compute normalized collation name. Components of collation name are given in order: + * - Base collation name (UTF8_BINARY or UTF8_LCASE) + * - Optional space trimming when non-default preceded by underscore + * Examples: UTF8_BINARY, UTF8_BINARY_LCASE_LTRIM, UTF8_BINARY_TRIM. + */ + @Override + protected String normalizedCollationName() { + StringBuilder builder = new StringBuilder(); + if(caseSensitivity == CaseSensitivity.UNSPECIFIED){ + builder.append(UTF8_BINARY_COLLATION_NAME); + } else{ + builder.append(UTF8_LCASE_COLLATION_NAME); + } + if (spaceTrimming != SpaceTrimming.NONE) { + builder.append('_'); + builder.append(spaceTrimming.toString()); + } + return builder.toString(); + } + + static List listCollations() { + CollationIdentifier UTF8_BINARY_COLLATION_IDENT = new CollationIdentifier( + PROVIDER_SPARK, + UTF8_BINARY_COLLATION_NAME, + CollationSpecICU.ICU_VERSION + ); + CollationIdentifier UTF8_LCASE_COLLATION_IDENT = new CollationIdentifier( + PROVIDER_SPARK, + UTF8_LCASE_COLLATION_NAME, + CollationSpecICU.ICU_VERSION + ); + return Arrays.asList(UTF8_BINARY_COLLATION_IDENT, UTF8_LCASE_COLLATION_IDENT); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + try { + int collationId = CollationSpecUTF8.collationNameToId( + collationIdentifier.name, collationIdentifier.name.toUpperCase()); + return CollationSpecUTF8.fromCollationId(collationId).buildCollationMeta(); + } catch (SparkException ignored) { + // ignore + return null; } } } @@ -489,9 +763,11 @@ private enum AccentSensitivity { private static final Map ICULocaleToId = new HashMap<>(); /** - * ICU library Collator version passed to `Collation` instance. + * ICU library version. 
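// A minimal sketch of the normalized-name construction used by normalizedCollationName above
// for the UTF8 family: the base name (UTF8_BINARY or UTF8_LCASE) plus an optional "_RTRIM"
// suffix when space trimming is enabled. Standalone and illustrative, not the Spark class.
class NormalizedNameSketch {
  enum SpaceTrimming { NONE, RTRIM }

  static String normalizedName(boolean lcase, SpaceTrimming spaceTrimming) {
    StringBuilder builder = new StringBuilder(lcase ? "UTF8_LCASE" : "UTF8_BINARY");
    if (spaceTrimming != SpaceTrimming.NONE) {
      builder.append('_').append(spaceTrimming);
    }
    return builder.toString();
  }

  public static void main(String[] args) {
    System.out.println(normalizedName(false, SpaceTrimming.NONE));  // UTF8_BINARY
    System.out.println(normalizedName(true, SpaceTrimming.RTRIM));  // UTF8_LCASE_RTRIM
  }
}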
*/ - private static final String ICU_COLLATOR_VERSION = "153.120.0.0"; + private static final String ICU_VERSION = String.format("%d.%d", + VersionInfo.ICU_VERSION.getMajor(), + VersionInfo.ICU_VERSION.getMinor()); static { ICULocaleMap.put("UNICODE", ULocale.ROOT); @@ -541,21 +817,33 @@ private enum AccentSensitivity { } } - private static final int UNICODE_COLLATION_ID = - new CollationSpecICU("UNICODE", CaseSensitivity.CS, AccentSensitivity.AS).collationId; - private static final int UNICODE_CI_COLLATION_ID = - new CollationSpecICU("UNICODE", CaseSensitivity.CI, AccentSensitivity.AS).collationId; + private static final int UNICODE_COLLATION_ID = new CollationSpecICU( + "UNICODE", + CaseSensitivity.CS, + AccentSensitivity.AS, + SpaceTrimming.NONE).collationId; + + private static final int UNICODE_CI_COLLATION_ID = new CollationSpecICU( + "UNICODE", + CaseSensitivity.CI, + AccentSensitivity.AS, + SpaceTrimming.NONE).collationId; private final CaseSensitivity caseSensitivity; private final AccentSensitivity accentSensitivity; + private final SpaceTrimming spaceTrimming; private final String locale; private final int collationId; - private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, - AccentSensitivity accentSensitivity) { + private CollationSpecICU( + String locale, + CaseSensitivity caseSensitivity, + AccentSensitivity accentSensitivity, + SpaceTrimming spaceTrimming) { this.locale = locale; this.caseSensitivity = caseSensitivity; this.accentSensitivity = accentSensitivity; + this.spaceTrimming = spaceTrimming; // Construct collation ID from locale, case-sensitivity and accent-sensitivity specifiers. int collationId = ICULocaleToId.get(locale); // Mandatory ICU implementation provider. @@ -565,6 +853,8 @@ private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, caseSensitivity); collationId = SpecifierUtils.setSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, SPACE_TRIMMING_OFFSET, + spaceTrimming); this.collationId = collationId; } @@ -582,58 +872,86 @@ private static int collationNameToId( } if (lastPos == -1) { throw collationInvalidNameException(originalName); - } else { - String locale = collationName.substring(0, lastPos); - int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); - - // Try all combinations of AS/AI and CS/CI. 
- CaseSensitivity caseSensitivity; - AccentSensitivity accentSensitivity; - if (collationName.equals(locale) || - collationName.equals(locale + "_AS") || - collationName.equals(locale + "_CS") || - collationName.equals(locale + "_AS_CS") || - collationName.equals(locale + "_CS_AS") - ) { - caseSensitivity = CaseSensitivity.CS; - accentSensitivity = AccentSensitivity.AS; - } else if (collationName.equals(locale + "_CI") || - collationName.equals(locale + "_AS_CI") || - collationName.equals(locale + "_CI_AS")) { - caseSensitivity = CaseSensitivity.CI; - accentSensitivity = AccentSensitivity.AS; - } else if (collationName.equals(locale + "_AI") || - collationName.equals(locale + "_CS_AI") || - collationName.equals(locale + "_AI_CS")) { - caseSensitivity = CaseSensitivity.CS; - accentSensitivity = AccentSensitivity.AI; - } else if (collationName.equals(locale + "_AI_CI") || - collationName.equals(locale + "_CI_AI")) { - caseSensitivity = CaseSensitivity.CI; - accentSensitivity = AccentSensitivity.AI; - } else { - throw collationInvalidNameException(originalName); - } + } + String locale = collationName.substring(0, lastPos); + int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); + collationId = SpecifierUtils.setSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); - // Build collation ID from computed specifiers. - collationId = SpecifierUtils.setSpecValue(collationId, - IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); - collationId = SpecifierUtils.setSpecValue(collationId, - CASE_SENSITIVITY_OFFSET, caseSensitivity); - collationId = SpecifierUtils.setSpecValue(collationId, - ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + // No other specifiers present. + if(collationName.equals(locale)){ return collationId; } + if(collationName.charAt(locale.length()) != '_'){ + throw collationInvalidNameException(originalName); + } + // Extract remaining specifiers and trim "_" separator. + String remainingSpecifiers = collationName.substring(lastPos + 1); + + // Initialize default specifier flags. + // Case sensitive, accent sensitive, no space trimming. + boolean isCaseSpecifierSet = false; + boolean isAccentSpecifierSet = false; + boolean isSpaceTrimmingSpecifierSet = false; + CaseSensitivity caseSensitivity = CaseSensitivity.CS; + AccentSensitivity accentSensitivity = AccentSensitivity.AS; + SpaceTrimming spaceTrimming = SpaceTrimming.NONE; + + String[] specifiers = remainingSpecifiers.split("_"); + + // Iterate through specifiers and set corresponding flags + for (String specifier : specifiers) { + switch (specifier) { + case "CI": + case "CS": + if (isCaseSpecifierSet) { + throw collationInvalidNameException(originalName); + } + caseSensitivity = CaseSensitivity.valueOf(specifier); + isCaseSpecifierSet = true; + break; + case "AI": + case "AS": + if (isAccentSpecifierSet) { + throw collationInvalidNameException(originalName); + } + accentSensitivity = AccentSensitivity.valueOf(specifier); + isAccentSpecifierSet = true; + break; + case "RTRIM": + if (isSpaceTrimmingSpecifierSet) { + throw collationInvalidNameException(originalName); + } + spaceTrimming = SpaceTrimming.valueOf(specifier); + isSpaceTrimmingSpecifierSet = true; + break; + default: + throw collationInvalidNameException(originalName); + } + } + + // Build collation ID from computed specifiers. 
+ collationId = SpecifierUtils.setSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + SPACE_TRIMMING_OFFSET, spaceTrimming); + return collationId; } private static CollationSpecICU fromCollationId(int collationId) { // Parse specifiers from collation ID. + int spaceTrimmingOrdinal = SpecifierUtils.getSpecValue(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK); int caseSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); int accentSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); collationId = SpecifierUtils.removeSpec(collationId, IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK); collationId = SpecifierUtils.removeSpec(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); collationId = SpecifierUtils.removeSpec(collationId, @@ -644,8 +962,9 @@ private static CollationSpecICU fromCollationId(int collationId) { assert(localeId >= 0 && localeId < ICULocaleNames.length); CaseSensitivity caseSensitivity = CaseSensitivity.values()[caseSensitivityOrdinal]; AccentSensitivity accentSensitivity = AccentSensitivity.values()[accentSensitivityOrdinal]; + SpaceTrimming spaceTrimming = SpaceTrimming.values()[spaceTrimmingOrdinal]; String locale = ICULocaleNames[localeId]; - return new CollationSpecICU(locale, caseSensitivity, accentSensitivity); + return new CollationSpecICU(locale, caseSensitivity, accentSensitivity, spaceTrimming); } @Override @@ -672,16 +991,51 @@ protected Collation buildCollation() { Collator collator = Collator.getInstance(resultLocale); // Freeze ICU collator to ensure thread safety. collator.freeze(); + + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + hashFunction = s -> (long) collator.getCollationKey( + s.toValidString()).hashCode(); + comparator = (s1, s2) -> + collator.compare(s1.toValidString(), s2.toValidString()); + } else { + comparator = (s1, s2) -> collator.compare( + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + hashFunction = s -> (long) collator.getCollationKey( + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + } + return new Collation( - collationName(), + normalizedCollationName(), PROVIDER_ICU, collator, - (s1, s2) -> collator.compare(s1.toString(), s2.toString()), - ICU_COLLATOR_VERSION, - s -> (long) collator.getCollationKey(s.toString()).hashCode(), - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false); + comparator, + ICU_VERSION, + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ false, + spaceTrimming != SpaceTrimming.NONE); + } + + @Override + protected CollationMeta buildCollationMeta() { + String language = ICULocaleMap.get(locale).getDisplayLanguage(); + String country = ICULocaleMap.get(locale).getDisplayCountry(); + return new CollationMeta( + CATALOG, + SCHEMA, + normalizedCollationName(), + language.isEmpty() ? null : language, + country.isEmpty() ? 
null : country, + VersionInfo.ICU_VERSION.toString(), + COLLATION_PAD_ATTRIBUTE, + accentSensitivity == AccentSensitivity.AS, + caseSensitivity == CaseSensitivity.CS, + spaceTrimming.toString()); } /** @@ -689,9 +1043,11 @@ protected Collation buildCollation() { * - Locale name * - Optional case sensitivity when non-default preceded by underscore * - Optional accent sensitivity when non-default preceded by underscore - * Examples: en, en_USA_CI_AI, sr_Cyrl_SRB_AI. + * - Optional space trimming when non-default preceded by underscore + * Examples: en, en_USA_CI_LTRIM, en_USA_CI_AI, en_USA_CI_AI_TRIM, sr_Cyrl_SRB_AI. */ - private String collationName() { + @Override + protected String normalizedCollationName() { StringBuilder builder = new StringBuilder(); builder.append(locale); if (caseSensitivity != CaseSensitivity.CS) { @@ -702,8 +1058,40 @@ private String collationName() { builder.append('_'); builder.append(accentSensitivity.toString()); } + if(spaceTrimming != SpaceTrimming.NONE) { + builder.append('_'); + builder.append(spaceTrimming.toString()); + } return builder.toString(); } + + private static List allCollationNames() { + List collationNames = new ArrayList<>(); + List caseAccentSpecifiers = Arrays.asList("", "_AI", "_CI", "_CI_AI"); + for (String locale : ICULocaleToId.keySet()) { + for (String caseAccent : caseAccentSpecifiers) { + String collationName = locale + caseAccent; + collationNames.add(collationName); + } + } + return collationNames.stream().sorted().toList(); + } + + static List listCollations() { + return allCollationNames().stream().map(name -> + new CollationIdentifier(PROVIDER_ICU, name, VersionInfo.ICU_VERSION.toString())).toList(); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + try { + int collationId = CollationSpecICU.collationNameToId( + collationIdentifier.name, collationIdentifier.name.toUpperCase()); + return CollationSpecICU.fromCollationId(collationId).buildCollationMeta(); + } catch (SparkException ignored) { + // ignore + return null; + } + } } /** @@ -730,9 +1118,12 @@ public CollationIdentifier identifier() { } } + public static final String CATALOG = "SYSTEM"; + public static final String SCHEMA = "BUILTIN"; public static final String PROVIDER_SPARK = "spark"; public static final String PROVIDER_ICU = "icu"; public static final List SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); + public static final String COLLATION_PAD_ATTRIBUTE = "NO_PAD"; public static final int UTF8_BINARY_COLLATION_ID = Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID; @@ -749,12 +1140,15 @@ public CollationIdentifier identifier() { * Returns a StringSearch object for the given pattern and target strings, under collation * rules corresponding to the given collationId. The external ICU library StringSearch object can * be used to find occurrences of the pattern in the target string, while respecting collation. + * When given invalid UTF8Strings, the method will first convert them to valid strings, and then + * instantiate the StringSearch object. However, original UTF8Strings will remain unchanged. 
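// A hedged, standalone sketch of how the ICU StringSearch returned by getStringSearch is
// typically walked (compare with indexOf above): wrap the target in a CharacterIterator, set
// the desired collator strength, and iterate matches with next()/getMatchLength(). Requires
// the icu4j library on the classpath; the SECONDARY strength here is only an example setting.
import java.text.StringCharacterIterator;
import com.ibm.icu.text.Collator;
import com.ibm.icu.text.RuleBasedCollator;
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;

class StringSearchSketch {
  public static void main(String[] args) {
    RuleBasedCollator collator = (RuleBasedCollator) Collator.getInstance(ULocale.ROOT);
    collator.setStrength(Collator.SECONDARY); // ignore case differences
    StringSearch search = new StringSearch(
        "strasse", new StringCharacterIterator("Strasse und STRASSE"), collator);
    search.setOverlapping(true);
    for (int i = search.next(); i != StringSearch.DONE; i = search.next()) {
      System.out.println("match at " + i + ", length " + search.getMatchLength());
    }
  }
}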
*/ public static StringSearch getStringSearch( final UTF8String targetUTF8String, final UTF8String patternUTF8String, final int collationId) { - return getStringSearch(targetUTF8String.toString(), patternUTF8String.toString(), collationId); + return getStringSearch(targetUTF8String.toValidString(), patternUTF8String.toValidString(), + collationId); } /** @@ -763,9 +1157,9 @@ public static StringSearch getStringSearch( * be used to find occurrences of the pattern in the target string, while respecting collation. */ public static StringSearch getStringSearch( - final String targetString, - final String patternString, - final int collationId) { + final String targetString, + final String patternString, + final int collationId) { CharacterIterator target = new StringCharacterIterator(targetString); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(patternString, target, (RuleBasedCollator) collator); @@ -775,11 +1169,13 @@ public static StringSearch getStringSearch( * Returns a collation-unaware StringSearch object for the given pattern and target strings. * While this object does not respect collation, it can be used to find occurrences of the pattern * in the target string for UTF8_BINARY or UTF8_LCASE (if arguments are lowercased). + * When given invalid UTF8Strings, the method will first convert them to valid strings, and then + * instantiate the StringSearch object. However, original UTF8Strings will remain unchanged. */ public static StringSearch getStringSearch( - final UTF8String targetUTF8String, - final UTF8String patternUTF8String) { - return new StringSearch(patternUTF8String.toString(), targetUTF8String.toString()); + final UTF8String targetUTF8String, + final UTF8String patternUTF8String) { + return new StringSearch(patternUTF8String.toValidString(), targetUTF8String.toValidString()); } /** @@ -789,6 +1185,16 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } + public static boolean isCaseInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CI; + } + + public static boolean isAccentInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == + Collation.CollationSpecICU.AccentSensitivity.AI; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -813,26 +1219,52 @@ public static String[] getICULocaleNames() { return Collation.CollationSpecICU.ICULocaleNames; } + /** + * Applies trimming policy depending up on trim collation type. + */ + public static UTF8String applyTrimmingPolicy(UTF8String input, int collationId) { + return Collation.CollationSpec.applyTrimmingPolicy(input, collationId); + } + + /** + * Returns if leading/trailing spaces should be ignored in trim string expressions. This is needed + * because space trimming collation directly changes behaviour of trim functions. 
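// A hedged sketch of the behaviour described above for right-trimming under an RTRIM
// collation: when the trim characters contain no space, trailing spaces are excluded from
// trimming and re-appended afterwards. Plain String logic with illustrative names, not the
// UTF8String implementation.
class RtrimCollationTrimSketch {
  static String trimRightIgnoringTrailingSpaces(String src, char trimChar) {
    int lastNonSpace = src.length();
    while (lastNonSpace > 0 && src.charAt(lastNonSpace - 1) == ' ') lastNonSpace--;
    if (lastNonSpace == 0) return src; // source is all spaces: nothing to trim
    int end = lastNonSpace;
    while (end > 0 && src.charAt(end - 1) == trimChar) end--;
    // Re-append the trailing spaces that were ignored while trimming.
    return src.substring(0, end) + src.substring(lastNonSpace);
  }

  public static void main(String[] args) {
    System.out.println("[" + trimRightIgnoringTrailingSpaces("abcxx  ", 'x') + "]"); // [abc  ]
  }
}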
+ */ + public static boolean ignoresSpacesInTrimFunctions( + int collationId, + boolean isLTrim, + boolean isRTrim) { + return Collation.CollationSpec.ignoresSpacesInTrimFunctions(collationId, isLTrim, isRTrim); + } + public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); + } + if (collation.isUtf8BinaryType) { return input; - } else if (collation.supportsLowercaseEquality) { - return input.toLowerCase(); + } else if (collation.isUtf8LcaseType) { + return CollationAwareUTF8String.lowerCaseCodePoints(input); } else { - CollationKey collationKey = collation.collator.getCollationKey(input.toString()); + CollationKey collationKey = collation.collator.getCollationKey( + input.toValidString()); return UTF8String.fromBytes(collationKey.toByteArray()); } } public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); + } + if (collation.isUtf8BinaryType) { return input.getBytes(); - } else if (collation.supportsLowercaseEquality) { - return input.toLowerCase().getBytes(); + } else if (collation.isUtf8LcaseType) { + return CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes(); } else { - return collation.collator.getCollationKey(input.toString()).toByteArray(); + return collation.collator.getCollationKey( + input.toValidString()).toByteArray(); } } @@ -848,19 +1280,26 @@ public static String getClosestSuggestionsOnInvalidName( Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION.collationName, Collation.CollationSpecUTF8.UTF8_LCASE_COLLATION.collationName }; - validModifiers = new String[0]; + validModifiers = new String[]{"_RTRIM"}; } else { validRootNames = getICULocaleNames(); - validModifiers = new String[]{"_CI", "_AI", "_CS", "_AS"}; + validModifiers = new String[]{"_CI", "_AI", "_CS", "_AS", "_RTRIM"}; } // Split modifiers and locale name. - final int MODIFIER_LENGTH = 3; + boolean foundModifier = true; String localeName = collationName.toUpperCase(); List modifiers = new ArrayList<>(); - while (Arrays.stream(validModifiers).anyMatch(localeName::endsWith)) { - modifiers.add(localeName.substring(localeName.length() - MODIFIER_LENGTH)); - localeName = localeName.substring(0, localeName.length() - MODIFIER_LENGTH); + while (foundModifier) { + foundModifier = false; + for (String modifier : validModifiers) { + if (localeName.endsWith(modifier)) { + modifiers.add(modifier); + localeName = localeName.substring(0, localeName.length() - modifier.length()); + foundModifier = true; + break; + } + } } // Suggest version with unique modifiers. 
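Note on the suggestion logic above: once modifiers of different lengths are allowed ("_CI"/"_AI"/"_CS"/"_AS" vs. "_RTRIM"), the old fixed MODIFIER_LENGTH slicing no longer works, so the loop strips whichever known suffix matches. A minimal standalone sketch of that suffix-stripping approach (hypothetical class and method names, not part of this patch):

import java.util.ArrayList;
import java.util.List;

class ModifierStrippingSketch {
  // Modifiers have different lengths, so they are matched by scanning the list of known
  // suffixes rather than chopping a fixed number of characters off the end.
  private static final String[] VALID_MODIFIERS = {"_CI", "_AI", "_CS", "_AS", "_RTRIM"};

  static List<String> splitModifiers(String collationName) {
    String localeName = collationName.toUpperCase();
    List<String> parts = new ArrayList<>();
    boolean foundModifier = true;
    while (foundModifier) {
      foundModifier = false;
      for (String modifier : VALID_MODIFIERS) {
        if (localeName.endsWith(modifier)) {
          parts.add(modifier);
          localeName = localeName.substring(0, localeName.length() - modifier.length());
          foundModifier = true;
          break;
        }
      }
    }
    parts.add(localeName); // whatever remains is the locale root, e.g. "EN_USA"
    return parts;
  }

  public static void main(String[] args) {
    // Prints [_RTRIM, _CI, EN_USA]: modifiers are collected from right to left.
    System.out.println(splitModifiers("en_USA_CI_RTRIM"));
  }
}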
@@ -918,4 +1357,12 @@ public static String getClosestSuggestionsOnInvalidName( return String.join(", ", suggestions); } + + public static List listCollations() { + return Collation.CollationSpec.listCollations(); + } + + public static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + return Collation.CollationSpec.loadCollationMeta(collationIdentifier); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 453423ddbc33d..135250e482b16 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -20,8 +20,6 @@ import org.apache.spark.unsafe.types.UTF8String; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import java.util.regex.Pattern; @@ -37,88 +35,67 @@ public final class CollationSupport { */ public static class StringSplitSQL { - public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { + public static UTF8String[] exec(final UTF8String s, UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + d = CollationFactory.applyTrimmingPolicy(d, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(s, d); } else { return execICU(s, d, collationId); } } public static String genCode(final String s, final String d, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", s, d); } else { - return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); + return String.format(expr + "(%s, %s, %d)", s, d, collationId); } } public static UTF8String[] execBinary(final UTF8String string, final UTF8String delimiter) { return string.splitSQL(delimiter, -1); } public static UTF8String[] execLowercase(final UTF8String string, final UTF8String delimiter) { - if (delimiter.numBytes() == 0) return new UTF8String[] { string }; - if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 }; - Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()), - CollationSupport.lowercaseRegexFlags); - String[] splits = pattern.split(string.toString(), -1); - UTF8String[] res = new UTF8String[splits.length]; - for (int i = 0; i < res.length; i++) { - res[i] = UTF8String.fromString(splits[i]); - } - return res; + return CollationAwareUTF8String.lowercaseSplitSQL(string, delimiter, -1); } public static UTF8String[] execICU(final UTF8String string, final UTF8String delimiter, final int collationId) { - if (delimiter.numBytes() == 0) return new UTF8String[] { string }; - if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 }; - List strings = new ArrayList<>(); - String target = string.toString(), pattern = delimiter.toString(); - StringSearch stringSearch = 
CollationFactory.getStringSearch(target, pattern, collationId); - int start = 0, end; - while ((end = stringSearch.next()) != StringSearch.DONE) { - strings.add(UTF8String.fromString(target.substring(start, end))); - start = end + stringSearch.getMatchLength(); - } - if (start <= target.length()) { - strings.add(UTF8String.fromString(target.substring(start))); - } - return strings.toArray(new UTF8String[0]); + return CollationAwareUTF8String.icuSplitSQL(string, delimiter, -1, collationId); } } public static class Contains { - public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { + public static boolean exec(UTF8String l, UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + l = CollationFactory.applyTrimmingPolicy(l, collationId); + r = CollationFactory.applyTrimmingPolicy(r, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); } } public static String genCode(final String l, final String r, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", l, r); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "(%s, %s, %d)", l, r, collationId); } } public static boolean execBinary(final UTF8String l, final UTF8String r) { return l.contains(r); } public static boolean execLowercase(final UTF8String l, final UTF8String r) { - return CollationAwareUTF8String.lowercaseIndexOf(l, r, 0) >= 0; + return CollationAwareUTF8String.lowercaseContains(l, r); } public static boolean execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -130,33 +107,35 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class StartsWith { - public static boolean exec(final UTF8String l, final UTF8String r, + public static boolean exec(UTF8String l, UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + l = CollationFactory.applyTrimmingPolicy(l, collationId); + r = CollationFactory.applyTrimmingPolicy(r, collationId); + } + + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); } } public static String genCode(final String l, final String r, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, 
%s)", l, r); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "(%s, %s, %d)", l, r, collationId); } } public static boolean execBinary(final UTF8String l, final UTF8String r) { return l.startsWith(r); } public static boolean execLowercase(final UTF8String l, final UTF8String r) { - return CollationAwareUTF8String.lowercaseMatchFrom(l, r.toLowerCase(), 0); + return CollationAwareUTF8String.lowercaseStartsWith(l, r); } public static boolean execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -168,32 +147,33 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class EndsWith { - public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { + public static boolean exec(UTF8String l, UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + l = CollationFactory.applyTrimmingPolicy(l, collationId); + r = CollationFactory.applyTrimmingPolicy(r, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); } } public static String genCode(final String l, final String r, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", l, r); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "(%s, %s, %d)", l, r, collationId); } } public static boolean execBinary(final UTF8String l, final UTF8String r) { return l.endsWith(r); } public static boolean execLowercase(final UTF8String l, final UTF8String r) { - return CollationAwareUTF8String.lowercaseMatchUntil(l, r.toLowerCase(), l.numChars()); + return CollationAwareUTF8String.lowercaseEndsWith(l, r); } public static boolean execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -208,9 +188,10 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // Space trimming does not affect the output of this expression. + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -219,10 +200,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -245,9 +226,10 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // Space trimming does not affect the output of this expression. + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -256,10 +238,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Lower.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -282,9 +264,10 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // Space trimming does not affect the output of this expression. + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -294,10 +277,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -307,7 +290,7 @@ public static UTF8String execBinary(final UTF8String v) { return v.toLowerCase().toTitleCase(); } public static UTF8String execBinaryICU(final UTF8String v) { - return CollationAwareUTF8String.toLowerCase(v).toTitleCaseICU(); + return CollationAwareUTF8String.toTitleCaseICU(v); } public static UTF8String execLowercase(final UTF8String v) { return CollationAwareUTF8String.toTitleCase(v); @@ -319,17 +302,16 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // FindInSet does space trimming collation as comparison is space trimming collation aware + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); } } public static String genCode(final String word, final String set, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -345,12 +327,15 @@ public static int execCollationAware(final UTF8String word, final UTF8String set } public static class StringInstr { - public static int exec(final UTF8String string, final UTF8String substring, + public static int exec(final UTF8String string, UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + substring = CollationFactory.applyTrimmingPolicy(substring, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring); } else { return execICU(string, substring, collationId); @@ -358,14 +343,11 @@ public static int exec(final UTF8String string, final UTF8String substring, } public static String genCode(final String string, final String substring, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { - return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); + return String.format(expr + "(%s, %s, %d)", string, substring, collationId); } } public static int execBinary(final UTF8String string, final UTF8String substring) { @@ -384,9 +366,10 @@ public static class 
StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // Space trimming does not affect the output of this expression. + if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(src, search, replace); } else { return execICU(src, search, replace, collationId); @@ -396,9 +379,9 @@ public static String genCode(final String src, final String search, final String final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); @@ -419,12 +402,15 @@ public static UTF8String execICU(final UTF8String src, final UTF8String search, } public static class StringLocate { - public static int exec(final UTF8String string, final UTF8String substring, final int start, + public static int exec(final UTF8String string, UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + substring = CollationFactory.applyTrimmingPolicy(substring, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring, start); } else { return execICU(string, substring, start, collationId); @@ -432,14 +418,11 @@ public static int exec(final UTF8String string, final UTF8String substring, fina } public static String genCode(final String string, final String substring, final int start, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { - return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); + return String.format(expr + "(%s, %s, %d, %d)", string, substring, start, collationId); } } public static int execBinary(final UTF8String string, final UTF8String substring, @@ -457,27 +440,27 @@ public static int execICU(final UTF8String string, final UTF8String substring, f } public static class SubstringIndex { - public static UTF8String exec(final UTF8String string, final UTF8String delimiter, + public static UTF8String exec(final UTF8String string, UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.supportsSpaceTrimming) { + delimiter = CollationFactory.applyTrimmingPolicy(delimiter, collationId); + } + if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, delimiter, count); } else { return execICU(string, delimiter, count, collationId); } } public static String genCode(final String string, final String delimiter, - final int count, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + final String count, final int collationId) { String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s, %s, %d)", string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s, %d)", string, delimiter, count); + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { + return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); } else { - return String.format(expr + "ICU(%s, %s, %d, %d)", string, delimiter, count, collationId); + return String.format(expr + "(%s, %s, %s, %d)", string, delimiter, count, collationId); } } public static UTF8String execBinary(final UTF8String string, final UTF8String delimiter, @@ -490,8 +473,7 @@ public static UTF8String execLowercase(final UTF8String string, final UTF8String } public static UTF8String execICU(final UTF8String string, final UTF8String delimiter, final int count, final int collationId) { - return CollationAwareUTF8String.subStringIndex(string, delimiter, count, - collationId); + return CollationAwareUTF8String.subStringIndex(string, delimiter, count, collationId); } } @@ -499,25 +481,15 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // Space trimming does not affect the output of this expression. 
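+ // (Space trimming collations only change comparison semantics, while translate produces a new string value, so the source and dictionary are used as-is.)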
+ if (collation.isUtf8BinaryType) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(source, dict); } else { return execICU(source, dict, collationId); } } - public static String genCode(final String source, final String dict, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s, %s)", source, dict); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", source, dict); - } else { - return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId); - } - } public static UTF8String execBinary(final UTF8String source, Map dict) { return source.translate(dict); } @@ -539,10 +511,15 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType && !collation.supportsSpaceTrimming) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { - return execLowercase(srcString, trimString); + } + + if (collation.isUtf8BinaryType) { + // special handling needed for utf8_binary_rtrim collation. + return execBinaryTrim(srcString, trimString, collationId); + } else if (collation.isUtf8LcaseType) { + return execLowercase(srcString, trimString, collationId); } else { return execICU(srcString, trimString, collationId); } @@ -554,14 +531,11 @@ public static String genCode( final String srcString, final String trimString, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { - return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); + return String.format(expr + "(%s, %s, %d)", srcString, trimString, collationId); } } public static UTF8String execBinary( @@ -575,8 +549,9 @@ public static UTF8String execBinary( } public static UTF8String execLowercase( final UTF8String srcString, - final UTF8String trimString) { - return CollationAwareUTF8String.lowercaseTrim(srcString, trimString); + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.lowercaseTrim(srcString, trimString, collationId); } public static UTF8String execICU( final UTF8String srcString, @@ -584,6 +559,12 @@ public static UTF8String execICU( final int collationId) { return CollationAwareUTF8String.trim(srcString, trimString, collationId); } + public static UTF8String execBinaryTrim( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.binaryTrim(srcString, trimString, collationId); + } } public static class StringTrimLeft { @@ -591,13 +572,15 @@ public static UTF8String exec(final UTF8String srcString) { return execBinary(srcString); } public static UTF8String exec( - final UTF8String srcString, - final UTF8String trimString, - final int 
collationId) { + final UTF8String srcString, + UTF8String trimString, + final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + // Space trimming does not affect the output of this expression as currently only supported + // space trimming is RTRIM. + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -610,14 +593,11 @@ public static String genCode( final String srcString, final String trimString, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { - return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); + return String.format(expr + "(%s, %s, %d)", srcString, trimString, collationId); } } public static UTF8String execBinary(final UTF8String srcString) { @@ -650,10 +630,15 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType && !collation.supportsSpaceTrimming) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { - return execLowercase(srcString, trimString); + } + + if (collation.isUtf8BinaryType) { + // special handling needed for utf8_binary_rtrim collation. 
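+ // (Plain execBinary is collation-unaware and would ignore the space trimming semantics of UTF8_BINARY_RTRIM, so CollationAwareUTF8String.binaryTrimRight handles that case.)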
+ return execBinaryTrim(srcString, trimString, collationId); + } else if (collation.isUtf8LcaseType) { + return execLowercase(srcString, trimString, collationId); } else { return execICU(srcString, trimString, collationId); } @@ -665,14 +650,11 @@ public static String genCode( final String srcString, final String trimString, final int collationId) { - CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.supportsBinaryEquality) { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { - return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); + return String.format(expr + "(%s, %s, %d)", srcString, trimString, collationId); } } public static UTF8String execBinary(final UTF8String srcString) { @@ -685,8 +667,9 @@ public static UTF8String execBinary( } public static UTF8String execLowercase( final UTF8String srcString, - final UTF8String trimString) { - return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString); + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString, collationId); } public static UTF8String execICU( final UTF8String srcString, @@ -694,6 +677,12 @@ public static UTF8String execICU( final int collationId) { return CollationAwareUTF8String.trimRight(srcString, trimString, collationId); } + public static UTF8String execBinaryTrim( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + return CollationAwareUTF8String.binaryTrimRight(srcString, trimString, collationId); + } } // TODO: Add more collation-aware string expressions. @@ -705,10 +694,10 @@ public static UTF8String execICU( public static boolean supportsLowercaseRegex(final int collationId) { // for regex, only Unicode case-insensitive matching is possible, // so UTF8_LCASE is treated as UNICODE_CI in this context - return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + return CollationFactory.fetchCollation(collationId).isUtf8LcaseType; } - private static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; + static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; public static int collationAwareRegexFlags(final int collationId) { return supportsLowercaseRegex(collationId) ? lowercaseRegexFlags : 0; } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/SpecialCodePointConstants.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/SpecialCodePointConstants.java new file mode 100644 index 0000000000000..db615d747910b --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/SpecialCodePointConstants.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * 'SpecialCodePointConstants' is introduced in order to keep the codepoints used in + * 'CollationAwareUTF8String' in one place. + */ +public class SpecialCodePointConstants { + + public static final int COMBINING_DOT = 0x0307; + public static final int ASCII_SMALL_I = 0x0069; + public static final int ASCII_SPACE = 0x0020; + public static final int GREEK_CAPITAL_SIGMA = 0x03A3; + public static final int GREEK_SMALL_SIGMA = 0x03C3; + public static final int GREEK_FINAL_SIGMA = 0x03C2; + public static final int CAPITAL_I_WITH_DOT_ABOVE = 0x0130; +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java index 481ea89090b2a..3a697345ce1f5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java @@ -96,4 +96,33 @@ public void appendBytes(Object base, long offset, int length) { public UTF8String build() { return UTF8String.fromBytes(buffer, 0, totalSize()); } + + public void appendCodePoint(int codePoint) { + if (codePoint <= 0x7F) { + grow(1); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) codePoint; + ++cursor; + } else if (codePoint <= 0x7FF) { + grow(2); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) (0xC0 | (codePoint >> 6)); + buffer[cursor + 1 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | (codePoint & 0x3F)); + cursor += 2; + } else if (codePoint <= 0xFFFF) { + grow(3); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) (0xE0 | (codePoint >> 12)); + buffer[cursor + 1 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | ((codePoint >> 6) & 0x3F)); + buffer[cursor + 2 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | (codePoint & 0x3F)); + cursor += 3; + } else if (codePoint <= 0x10FFFF) { + grow(4); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) (0xF0 | (codePoint >> 18)); + buffer[cursor + 1 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | ((codePoint >> 12) & 0x3F)); + buffer[cursor + 2 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | ((codePoint >> 6) & 0x3F)); + buffer[cursor + 3 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | (codePoint & 0x3F)); + cursor += 4; + } else { + throw new IllegalArgumentException("Invalid Unicode codePoint: " + codePoint); + } + } + } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e6bddb12da56b..caf8461b0b5d6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -83,6 +83,20 @@ private enum UTF8StringValidity { */ private volatile int numBytesValid = -1; + /** + * The ASCII-ness of the UTF8Strings can be cached to avoid repeated checks, because that + * operation requires full string scan. Full ASCII strings contain only ASCII characters. 
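+ * Since the underlying bytes of a UTF8String do not change, recomputing the flag is idempotent, so the lazily cached value only relies on the volatile field for visibility.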
+ */ + private enum IsFullAscii { + UNKNOWN, FULL_ASCII, NOT_ASCII + } + + /** + * Internal flag to indicate whether the string is full ASCII or not. Initially, the ASCII-ness + * is UNKNOWN, and will be set to either FULL_ASCII or NOT_ASCII after the first check. + */ + private volatile IsFullAscii isFullAscii = IsFullAscii.UNKNOWN; + public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } @@ -127,6 +141,7 @@ private enum UTF8StringValidity { private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); public static final UTF8String ZERO_UTF8 = UTF8String.fromString("0"); + public static final UTF8String SPACE_UTF8 = UTF8String.fromString(" "); /** @@ -788,12 +803,19 @@ public UTF8String toLowerCase() { } public boolean isFullAscii() { + if (isFullAscii == IsFullAscii.UNKNOWN) { + isFullAscii = getIsFullAscii(); + } + return isFullAscii == IsFullAscii.FULL_ASCII; + } + + private IsFullAscii getIsFullAscii() { for (var i = 0; i < numBytes; i++) { if (getByte(i) < 0) { - return false; + return IsFullAscii.NOT_ASCII; } } - return true; + return IsFullAscii.FULL_ASCII; } private UTF8String toLowerCaseSlow() { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index b082ab21944f7..a696da8cf45b8 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -40,165 +40,224 @@ public class CollationSupportSuite { {"UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"}; /** - * Collation-aware UTF8String comparison. + * Collation-aware UTF8String comparison and equality check. */ - private void assertStringCompare(String s1, String s2, String collationName, int expected) + private void assertCompare(String s1, String s2, String collationName, int expected) throws SparkException { UTF8String l = UTF8String.fromString(s1); UTF8String r = UTF8String.fromString(s2); + // Test the comparator, which is the most general way to compare strings with collations. int compare = CollationFactory.fetchCollation(collationName).comparator.compare(l, r); assertEquals(Integer.signum(expected), Integer.signum(compare)); + // Test the equals function, which may be faster than the comparator for equality checks. 
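+ // For every collation, the equals function must agree with comparator.compare(...) == 0.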
+ boolean equals = CollationFactory.fetchCollation(collationName).equalsFunction.apply(l ,r); + assertEquals(expected == 0, equals); } @Test public void testCompare() throws SparkException { for (String collationName: testSupportedCollations) { - // Edge cases - assertStringCompare("", "", collationName, 0); - assertStringCompare("a", "", collationName, 1); - assertStringCompare("", "a", collationName, -1); - // Basic tests - assertStringCompare("a", "a", collationName, 0); - assertStringCompare("a", "b", collationName, -1); - assertStringCompare("b", "a", collationName, 1); - assertStringCompare("A", "A", collationName, 0); - assertStringCompare("A", "B", collationName, -1); - assertStringCompare("B", "A", collationName, 1); - assertStringCompare("aa", "a", collationName, 1); - assertStringCompare("b", "bb", collationName, -1); - assertStringCompare("abc", "a", collationName, 1); - assertStringCompare("abc", "b", collationName, -1); - assertStringCompare("abc", "ab", collationName, 1); - assertStringCompare("abc", "abc", collationName, 0); - // ASCII strings - assertStringCompare("aaaa", "aaa", collationName, 1); - assertStringCompare("hello", "world", collationName, -1); - assertStringCompare("Spark", "Spark", collationName, 0); - // Non-ASCII strings - assertStringCompare("ü", "ü", collationName, 0); - assertStringCompare("ü", "", collationName, 1); - assertStringCompare("", "ü", collationName, -1); - assertStringCompare("äü", "äü", collationName, 0); - assertStringCompare("äxx", "äx", collationName, 1); - assertStringCompare("a", "ä", collationName, -1); + // Empty strings. + assertCompare("", "", collationName, 0); + assertCompare("a", "", collationName, 1); + assertCompare("", "a", collationName, -1); + // Basic tests. + assertCompare("a", "a", collationName, 0); + assertCompare("a", "b", collationName, -1); + assertCompare("b", "a", collationName, 1); + assertCompare("A", "A", collationName, 0); + assertCompare("A", "B", collationName, -1); + assertCompare("B", "A", collationName, 1); + assertCompare("aa", "a", collationName, 1); + assertCompare("b", "bb", collationName, -1); + assertCompare("abc", "a", collationName, 1); + assertCompare("abc", "b", collationName, -1); + assertCompare("abc", "ab", collationName, 1); + assertCompare("abc", "abc", collationName, 0); + assertCompare("aaaa", "aaa", collationName, 1); + assertCompare("hello", "world", collationName, -1); + assertCompare("Spark", "Spark", collationName, 0); + assertCompare("ü", "ü", collationName, 0); + assertCompare("ü", "", collationName, 1); + assertCompare("", "ü", collationName, -1); + assertCompare("äü", "äü", collationName, 0); + assertCompare("äxx", "äx", collationName, 1); + assertCompare("a", "ä", collationName, -1); } - // Non-ASCII strings - assertStringCompare("äü", "bü", "UTF8_BINARY", 1); - assertStringCompare("bxx", "bü", "UTF8_BINARY", -1); - assertStringCompare("äü", "bü", "UTF8_LCASE", 1); - assertStringCompare("bxx", "bü", "UTF8_LCASE", -1); - assertStringCompare("äü", "bü", "UNICODE", -1); - assertStringCompare("bxx", "bü", "UNICODE", 1); - assertStringCompare("äü", "bü", "UNICODE_CI", -1); - assertStringCompare("bxx", "bü", "UNICODE_CI", 1); - // Case variation - assertStringCompare("AbCd", "aBcD", "UTF8_BINARY", -1); - assertStringCompare("ABCD", "abcd", "UTF8_LCASE", 0); - assertStringCompare("AbcD", "aBCd", "UNICODE", 1); - assertStringCompare("abcd", "ABCD", "UNICODE_CI", 0); - // Accent variation - assertStringCompare("aBćD", "ABĆD", "UTF8_BINARY", 1); - assertStringCompare("AbCδ", "ABCΔ", 
"UTF8_LCASE", 0); - assertStringCompare("äBCd", "ÄBCD", "UNICODE", -1); - assertStringCompare("Ab́cD", "AB́CD", "UNICODE_CI", 0); - // Case-variable character length - assertStringCompare("i\u0307", "İ", "UTF8_BINARY", -1); - assertStringCompare("İ", "i\u0307", "UTF8_BINARY", 1); - assertStringCompare("i\u0307", "İ", "UTF8_LCASE", 0); - assertStringCompare("İ", "i\u0307", "UTF8_LCASE", 0); - assertStringCompare("i\u0307", "İ", "UNICODE", -1); - assertStringCompare("İ", "i\u0307", "UNICODE", 1); - assertStringCompare("i\u0307", "İ", "UNICODE_CI", 0); - assertStringCompare("İ", "i\u0307", "UNICODE_CI", 0); - assertStringCompare("i\u0307İ", "i\u0307İ", "UTF8_LCASE", 0); - assertStringCompare("i\u0307İ", "İi\u0307", "UTF8_LCASE", 0); - assertStringCompare("İi\u0307", "i\u0307İ", "UTF8_LCASE", 0); - assertStringCompare("İi\u0307", "İi\u0307", "UTF8_LCASE", 0); - assertStringCompare("i\u0307İ", "i\u0307İ", "UNICODE_CI", 0); - assertStringCompare("i\u0307İ", "İi\u0307", "UNICODE_CI", 0); - assertStringCompare("İi\u0307", "i\u0307İ", "UNICODE_CI", 0); - assertStringCompare("İi\u0307", "İi\u0307", "UNICODE_CI", 0); - // Conditional case mapping - assertStringCompare("ς", "σ", "UTF8_BINARY", -1); - assertStringCompare("ς", "Σ", "UTF8_BINARY", 1); - assertStringCompare("σ", "Σ", "UTF8_BINARY", 1); - assertStringCompare("ς", "σ", "UTF8_LCASE", 0); - assertStringCompare("ς", "Σ", "UTF8_LCASE", 0); - assertStringCompare("σ", "Σ", "UTF8_LCASE", 0); - assertStringCompare("ς", "σ", "UNICODE", 1); - assertStringCompare("ς", "Σ", "UNICODE", 1); - assertStringCompare("σ", "Σ", "UNICODE", -1); - assertStringCompare("ς", "σ", "UNICODE_CI", 0); - assertStringCompare("ς", "Σ", "UNICODE_CI", 0); - assertStringCompare("σ", "Σ", "UNICODE_CI", 0); + // Advanced tests. + assertCompare("äü", "bü", "UTF8_BINARY", 1); + assertCompare("bxx", "bü", "UTF8_BINARY", -1); + assertCompare("äü", "bü", "UTF8_LCASE", 1); + assertCompare("bxx", "bü", "UTF8_LCASE", -1); + assertCompare("äü", "bü", "UNICODE", -1); + assertCompare("bxx", "bü", "UNICODE", 1); + assertCompare("äü", "bü", "UNICODE_CI", -1); + assertCompare("bxx", "bü", "UNICODE_CI", 1); + assertCompare("cČć", "ČćC", "SR_CI_AI", 0); + // Case variation. + assertCompare("AbCd", "aBcD", "UTF8_BINARY", -1); + assertCompare("ABCD", "abcd", "UTF8_LCASE", 0); + assertCompare("AbcD", "aBCd", "UNICODE", 1); + assertCompare("abcd", "ABCD", "UNICODE_CI", 0); + // Accent variation. + assertCompare("aBćD", "ABĆD", "UTF8_BINARY", 1); + assertCompare("AbCδ", "ABCΔ", "UTF8_LCASE", 0); + assertCompare("äBCd", "ÄBCD", "UNICODE", -1); + assertCompare("Ab́cD", "AB́CD", "UNICODE_CI", 0); + assertCompare("ÈÉÊË", "EeEe", "AF_CI_AI", 0); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertCompare("i\u0307", "İ", "UTF8_BINARY", -1); + assertCompare("İ", "i\u0307", "UTF8_BINARY", 1); + assertCompare("i\u0307", "İ", "UTF8_LCASE", 0); + assertCompare("İ", "i\u0307", "UTF8_LCASE", 0); + assertCompare("i\u0307", "İ", "UNICODE", -1); + assertCompare("İ", "i\u0307", "UNICODE", 1); + assertCompare("i\u0307", "İ", "UNICODE_CI", 0); + assertCompare("İ", "i\u0307", "UNICODE_CI", 0); + assertCompare("i\u0307İ", "i\u0307İ", "UTF8_LCASE", 0); + assertCompare("i\u0307İ", "İi\u0307", "UTF8_LCASE", 0); + assertCompare("İi\u0307", "i\u0307İ", "UTF8_LCASE", 0); + assertCompare("İi\u0307", "İi\u0307", "UTF8_LCASE", 0); + assertCompare("i\u0307İ", "i\u0307İ", "UNICODE_CI", 0); + assertCompare("i\u0307İ", "İi\u0307", "UNICODE_CI", 0); + assertCompare("İi\u0307", "i\u0307İ", "UNICODE_CI", 0); + assertCompare("İi\u0307", "İi\u0307", "UNICODE_CI", 0); + // Conditional case mapping (e.g. Greek sigmas). + assertCompare("ς", "σ", "UTF8_BINARY", -1); + assertCompare("ς", "Σ", "UTF8_BINARY", 1); + assertCompare("σ", "Σ", "UTF8_BINARY", 1); + assertCompare("ς", "σ", "UTF8_LCASE", 0); + assertCompare("ς", "Σ", "UTF8_LCASE", 0); + assertCompare("σ", "Σ", "UTF8_LCASE", 0); + assertCompare("ς", "σ", "UNICODE", 1); + assertCompare("ς", "Σ", "UNICODE", 1); + assertCompare("σ", "Σ", "UNICODE", -1); + assertCompare("ς", "σ", "UNICODE_CI", 0); + assertCompare("ς", "Σ", "UNICODE_CI", 0); + assertCompare("σ", "Σ", "UNICODE_CI", 0); + // Surrogate pairs. + assertCompare("a🙃b🙃c", "aaaaa", "UTF8_BINARY", 1); + assertCompare("a🙃b🙃c", "aaaaa", "UTF8_LCASE", 1); + assertCompare("a🙃b🙃c", "aaaaa", "UNICODE", -1); // != UTF8_BINARY + assertCompare("a🙃b🙃c", "aaaaa", "UNICODE_CI", -1); // != UTF8_LCASE + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UTF8_BINARY", 0); + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UTF8_LCASE", 0); + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UNICODE", 0); + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UNICODE_CI", 0); + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UTF8_BINARY", -1); + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UTF8_LCASE", -1); + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UNICODE", -1); + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UNICODE_CI", -1); // Maximum code point. int maxCodePoint = Character.MAX_CODE_POINT; String maxCodePointStr = new String(Character.toChars(maxCodePoint)); for (int i = 0; i < maxCodePoint && Character.isValidCodePoint(i); ++i) { - assertStringCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_BINARY", -1); - assertStringCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_LCASE", -1); + assertCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_BINARY", -1); + assertCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_LCASE", -1); } // Minimum code point. int minCodePoint = Character.MIN_CODE_POINT; String minCodePointStr = new String(Character.toChars(minCodePoint)); for (int i = minCodePoint + 1; i <= maxCodePoint && Character.isValidCodePoint(i); ++i) { - assertStringCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_BINARY", 1); - assertStringCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_LCASE", 1); + assertCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_BINARY", 1); + assertCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_LCASE", 1); } } - private void assertLowerCaseCodePoints(UTF8String target, UTF8String expected, - Boolean useCodePoints) { + /** + * Collation-aware UTF8String lowercase conversion. 
+ */ + + private void assertLowerCaseCodePoints(String string, String expected, Boolean useCodePoints) { + UTF8String str = UTF8String.fromString(string); if (useCodePoints) { - assertEquals(expected, CollationAwareUTF8String.lowerCaseCodePoints(target)); + UTF8String result = CollationAwareUTF8String.lowerCaseCodePoints(str); + assertEquals(UTF8String.fromString(expected), result); } else { - assertEquals(expected, target.toLowerCase()); + UTF8String result = str.toLowerCase(); + assertEquals(UTF8String.fromString(expected), result); } } @Test public void testLowerCaseCodePoints() { - // Edge cases - assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), false); - assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), true); - // Basic tests - assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), false); - assertLowerCaseCodePoints(UTF8String.fromString("AbCd"), UTF8String.fromString("abcd"), false); - assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), true); - assertLowerCaseCodePoints(UTF8String.fromString("aBcD"), UTF8String.fromString("abcd"), true); - // Accent variation - assertLowerCaseCodePoints(UTF8String.fromString("AbĆd"), UTF8String.fromString("abćd"), false); - assertLowerCaseCodePoints(UTF8String.fromString("aBcΔ"), UTF8String.fromString("abcδ"), true); - // Case-variable character length - assertLowerCaseCodePoints( - UTF8String.fromString("İoDiNe"), UTF8String.fromString("i̇odine"), false); - assertLowerCaseCodePoints( - UTF8String.fromString("Abi̇o12"), UTF8String.fromString("abi̇o12"), false); - assertLowerCaseCodePoints( - UTF8String.fromString("İodInE"), UTF8String.fromString("i̇odine"), true); - assertLowerCaseCodePoints( - UTF8String.fromString("aBi̇o12"), UTF8String.fromString("abi̇o12"), true); - // Conditional case mapping - assertLowerCaseCodePoints( - UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινος"), false); - assertLowerCaseCodePoints( - UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινοσ"), true); - // Surrogate pairs are treated as invalid UTF8 sequences - assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[] - {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}), - UTF8String.fromString("\uFFFD\uFFFD"), false); - assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[] - {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}), - UTF8String.fromString("\uFFFD\uFFFD\uFFFD\uFFFD\uFFFD\uFFFD"), true); // != Java toLowerCase + // Empty strings. + assertLowerCaseCodePoints("", "", false); + assertLowerCaseCodePoints("", "", true); + // Basic tests. + assertLowerCaseCodePoints("xyz", "xyz", false); + assertLowerCaseCodePoints("xyz", "xyz", true); + assertLowerCaseCodePoints("abcd", "abcd", false); + assertLowerCaseCodePoints("abcd", "abcd", true); + // Advanced tests. + assertLowerCaseCodePoints("你好", "你好", false); + assertLowerCaseCodePoints("你好", "你好", true); + assertLowerCaseCodePoints("Γειά", "γειά", false); + assertLowerCaseCodePoints("Γειά", "γειά", true); + assertLowerCaseCodePoints("Здраво", "здраво", false); + assertLowerCaseCodePoints("Здраво", "здраво", true); + // Case variation. + assertLowerCaseCodePoints("xYz", "xyz", false); + assertLowerCaseCodePoints("xYz", "xyz", true); + assertLowerCaseCodePoints("AbCd", "abcd", false); + assertLowerCaseCodePoints("aBcD", "abcd", true); + // Accent variation. 
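+ // Lowercasing never strips accents; accent-insensitivity is a collation property (AI), not part of case mapping.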
+ assertLowerCaseCodePoints("äbć", "äbć", false); + assertLowerCaseCodePoints("äbć", "äbć", true); + assertLowerCaseCodePoints("AbĆd", "abćd", false); + assertLowerCaseCodePoints("aBcΔ", "abcδ", true); + // One-to-many case mapping (e.g. Turkish dotted I). + assertLowerCaseCodePoints("i\u0307", "i\u0307", false); + assertLowerCaseCodePoints("i\u0307", "i\u0307", true); + assertLowerCaseCodePoints("I\u0307", "i\u0307", false); + assertLowerCaseCodePoints("I\u0307", "i\u0307", true); + assertLowerCaseCodePoints("İ", "i\u0307", false); + assertLowerCaseCodePoints("İ", "i\u0307", true); + assertLowerCaseCodePoints("İİİ", "i\u0307i\u0307i\u0307", false); + assertLowerCaseCodePoints("İİİ", "i\u0307i\u0307i\u0307", true); + assertLowerCaseCodePoints("İiIi\u0307", "i\u0307iii\u0307", false); + assertLowerCaseCodePoints("İiIi\u0307", "i\u0307iii\u0307", true); + assertLowerCaseCodePoints("İoDiNe", "i\u0307odine", false); + assertLowerCaseCodePoints("İodInE", "i\u0307odine", true); + assertLowerCaseCodePoints("Abi\u0307o12", "abi\u0307o12", false); + assertLowerCaseCodePoints("aBi\u0307o12", "abi\u0307o12", true); + // Conditional case mapping (e.g. Greek sigmas). + assertLowerCaseCodePoints("ς", "ς", false); + assertLowerCaseCodePoints("ς", "σ", true); + assertLowerCaseCodePoints("σ", "σ", false); + assertLowerCaseCodePoints("σ", "σ", true); + assertLowerCaseCodePoints("Σ", "σ", false); + assertLowerCaseCodePoints("Σ", "σ", true); + assertLowerCaseCodePoints("ςΑΛΑΤΑ", "ςαλατα", false); + assertLowerCaseCodePoints("ςΑΛΑΤΑ", "σαλατα", true); + assertLowerCaseCodePoints("σΑΛΑΤΑ", "σαλατα", false); + assertLowerCaseCodePoints("σΑΛΑΤΑ", "σαλατα", true); + assertLowerCaseCodePoints("ΣΑΛΑΤΑ", "σαλατα", false); + assertLowerCaseCodePoints("ΣΑΛΑΤΑ", "σαλατα", true); + assertLowerCaseCodePoints("ΘΑΛΑΣΣΙΝΟς", "θαλασσινος", false); + assertLowerCaseCodePoints("ΘΑΛΑΣΣΙΝΟς", "θαλασσινοσ", true); + assertLowerCaseCodePoints("ΘΑΛΑΣΣΙΝΟσ", "θαλασσινοσ", false); + assertLowerCaseCodePoints("ΘΑΛΑΣΣΙΝΟσ", "θαλασσινοσ", true); + assertLowerCaseCodePoints("ΘΑΛΑΣΣΙΝΟΣ", "θαλασσινος", false); + assertLowerCaseCodePoints("ΘΑΛΑΣΣΙΝΟΣ", "θαλασσινοσ", true); + // Surrogate pairs. + assertLowerCaseCodePoints("a🙃b🙃c", "a🙃b🙃c", false); + assertLowerCaseCodePoints("a🙃b🙃c", "a🙃b🙃c", true); + assertLowerCaseCodePoints("😀😆😃😄😄😆", "😀😆😃😄😄😆", false); + assertLowerCaseCodePoints("😀😆😃😄😄😆", "😀😆😃😄😄😆", true); + assertLowerCaseCodePoints("𐐅", "𐐭", false); + assertLowerCaseCodePoints("𐐅", "𐐭", true); + assertLowerCaseCodePoints("𝔸", "𝔸", false); + assertLowerCaseCodePoints("𝔸", "𝔸", true); } /** - * Collation-aware string expressions. + * Verify the behaviour of the `Contains` collation support class. 
*/ - private void assertContains(String pattern, String target, String collationName, boolean expected) - throws SparkException { + private void assertContains(String pattern, String target, String collationName, + boolean expected) throws SparkException { UTF8String l = UTF8String.fromString(pattern); UTF8String r = UTF8String.fromString(target); int collationId = CollationFactory.collationNameToId(collationName); @@ -207,20 +266,42 @@ private void assertContains(String pattern, String target, String collationName, @Test public void testContains() throws SparkException { - // Edge cases - assertContains("", "", "UTF8_BINARY", true); - assertContains("c", "", "UTF8_BINARY", true); - assertContains("", "c", "UTF8_BINARY", false); - assertContains("", "", "UNICODE", true); - assertContains("c", "", "UNICODE", true); - assertContains("", "c", "UNICODE", false); - assertContains("", "", "UTF8_LCASE", true); - assertContains("c", "", "UTF8_LCASE", true); - assertContains("", "c", "UTF8_LCASE", false); - assertContains("", "", "UNICODE_CI", true); - assertContains("c", "", "UNICODE_CI", true); - assertContains("", "c", "UNICODE_CI", false); - // Basic tests + for (String collationName: testSupportedCollations) { + // Empty strings. + assertContains("", "", collationName, true); + assertContains("a", "", collationName, true); + assertContains("", "x", collationName, false); + // Basic tests. + assertContains("a", "a", collationName, true); + assertContains("_a_", "_a_", collationName, true); + assertContains("_a_", "a", collationName, true); + assertContains("%a%", "%a%", collationName, true); + assertContains("%a%", "a", collationName, true); + assertContains("*a*", "*a*", collationName, true); + assertContains("*a*", "a", collationName, true); + assertContains("?a?", "?a?", collationName, true); + assertContains("?a?", "a", collationName, true); + assertContains("/a/", "/a/", collationName, true); + assertContains("/a/", "a", collationName, true); + assertContains("abcde", "xyz", collationName, false); + assertContains("abcde", "bcd", collationName, true); + assertContains("abcde", "abc", collationName, true); + assertContains("abcde", "cde", collationName, true); + assertContains("abcde", "abcde", collationName, true); + assertContains("你好", "x", collationName, false); + assertContains("你好", "你", collationName, true); + assertContains("你好", "好", collationName, true); + assertContains("你好", "你好", collationName, true); + assertContains("Γειά", "x", collationName, false); + assertContains("Γειά", "ειά", collationName, true); + assertContains("Γειά", "Γει", collationName, true); + assertContains("Γειά", "Γειά", collationName, true); + assertContains("Здраво", "x", collationName, false); + assertContains("Здраво", "драво", collationName, true); + assertContains("Здраво", "Здрав", collationName, true); + assertContains("Здраво", "Здраво", collationName, true); + } + // Advanced tests. 
assertContains("abcde", "bcd", "UTF8_BINARY", true); assertContains("abcde", "bde", "UTF8_BINARY", false); assertContains("abcde", "fgh", "UTF8_BINARY", false); @@ -233,25 +314,6 @@ public void testContains() throws SparkException { assertContains("abcde", "c", "UNICODE_CI", true); assertContains("abcde", "bCD", "UNICODE_CI", true); assertContains("abcde", "123", "UNICODE_CI", false); - // Case variation - assertContains("aBcDe", "bcd", "UTF8_BINARY", false); - assertContains("aBcDe", "BcD", "UTF8_BINARY", true); - assertContains("aBcDe", "abcde", "UNICODE", false); - assertContains("aBcDe", "aBcDe", "UNICODE", true); - assertContains("aBcDe", "bcd", "UTF8_LCASE", true); - assertContains("aBcDe", "BCD", "UTF8_LCASE", true); - assertContains("aBcDe", "abcde", "UNICODE_CI", true); - assertContains("aBcDe", "AbCdE", "UNICODE_CI", true); - // Accent variation - assertContains("aBcDe", "bćd", "UTF8_BINARY", false); - assertContains("aBcDe", "BćD", "UTF8_BINARY", false); - assertContains("aBcDe", "abćde", "UNICODE", false); - assertContains("aBcDe", "aBćDe", "UNICODE", false); - assertContains("aBcDe", "bćd", "UTF8_LCASE", false); - assertContains("aBcDe", "BĆD", "UTF8_LCASE", false); - assertContains("aBcDe", "abćde", "UNICODE_CI", false); - assertContains("aBcDe", "AbĆdE", "UNICODE_CI", false); - // Variable byte length characters assertContains("ab世De", "b世D", "UTF8_BINARY", true); assertContains("ab世De", "B世d", "UTF8_BINARY", false); assertContains("äbćδe", "bćδ", "UTF8_BINARY", true); @@ -268,45 +330,181 @@ public void testContains() throws SparkException { assertContains("ab世De", "AB世dE", "UNICODE_CI", true); assertContains("äbćδe", "ÄbćδE", "UNICODE_CI", true); assertContains("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); - // Characters with the same binary lowercase representation assertContains("The Kelvin.", "Kelvin", "UTF8_LCASE", true); assertContains("The Kelvin.", "Kelvin", "UTF8_LCASE", true); assertContains("The KKelvin.", "KKelvin", "UTF8_LCASE", true); assertContains("2 Kelvin.", "2 Kelvin", "UTF8_LCASE", true); assertContains("2 Kelvin.", "2 Kelvin", "UTF8_LCASE", true); assertContains("The KKelvin.", "KKelvin,", "UTF8_LCASE", false); - // Case-variable character length - assertContains("i̇", "i", "UNICODE_CI", false); - assertContains("i̇", "\u0307", "UNICODE_CI", false); - assertContains("i̇", "İ", "UNICODE_CI", true); + assertContains("abčćd", "ABCCD", "SR_CI_AI", true); + // Case variation. + assertContains("aBcDe", "bcd", "UTF8_BINARY", false); + assertContains("aBcDe", "BcD", "UTF8_BINARY", true); + assertContains("aBcDe", "abcde", "UNICODE", false); + assertContains("aBcDe", "aBcDe", "UNICODE", true); + assertContains("aBcDe", "bcd", "UTF8_LCASE", true); + assertContains("aBcDe", "BCD", "UTF8_LCASE", true); + assertContains("aBcDe", "abcde", "UNICODE_CI", true); + assertContains("aBcDe", "AbCdE", "UNICODE_CI", true); + // Accent variation. + assertContains("aBcDe", "bćd", "UTF8_BINARY", false); + assertContains("aBcDe", "BćD", "UTF8_BINARY", false); + assertContains("aBcDe", "abćde", "UNICODE", false); + assertContains("aBcDe", "aBćDe", "UNICODE", false); + assertContains("aBcDe", "bćd", "UTF8_LCASE", false); + assertContains("aBcDe", "BĆD", "UTF8_LCASE", false); + assertContains("aBcDe", "abćde", "UNICODE_CI", false); + assertContains("aBcDe", "AbĆdE", "UNICODE_CI", false); + assertContains("abEEE", "Bèêë", "AF_CI_AI", true); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertContains("i\u0307", "i", "UNICODE_CI", false); + assertContains("i\u0307", "\u0307", "UNICODE_CI", false); + assertContains("i\u0307", "İ", "UNICODE_CI", true); assertContains("İ", "i", "UNICODE_CI", false); assertContains("adi̇os", "io", "UNICODE_CI", false); assertContains("adi̇os", "Io", "UNICODE_CI", false); - assertContains("adi̇os", "i̇o", "UNICODE_CI", true); + assertContains("adi̇os", "i\u0307o", "UNICODE_CI", true); assertContains("adi̇os", "İo", "UNICODE_CI", true); assertContains("adİos", "io", "UNICODE_CI", false); assertContains("adİos", "Io", "UNICODE_CI", false); - assertContains("adİos", "i̇o", "UNICODE_CI", true); + assertContains("adİos", "i\u0307o", "UNICODE_CI", true); assertContains("adİos", "İo", "UNICODE_CI", true); - assertContains("i̇", "i", "UTF8_LCASE", true); // != UNICODE_CI + assertContains("i\u0307", "i", "UTF8_LCASE", true); // != UNICODE_CI assertContains("İ", "\u0307", "UTF8_LCASE", false); assertContains("İ", "i", "UTF8_LCASE", false); - assertContains("i̇", "\u0307", "UTF8_LCASE", true); // != UNICODE_CI - assertContains("i̇", "İ", "UTF8_LCASE", true); + assertContains("i\u0307", "\u0307", "UTF8_LCASE", true); // != UNICODE_CI + assertContains("i\u0307", "İ", "UTF8_LCASE", true); assertContains("İ", "i", "UTF8_LCASE", false); assertContains("adi̇os", "io", "UTF8_LCASE", false); assertContains("adi̇os", "Io", "UTF8_LCASE", false); - assertContains("adi̇os", "i̇o", "UTF8_LCASE", true); + assertContains("adi̇os", "i\u0307o", "UTF8_LCASE", true); assertContains("adi̇os", "İo", "UTF8_LCASE", true); assertContains("adİos", "io", "UTF8_LCASE", false); assertContains("adİos", "Io", "UTF8_LCASE", false); - assertContains("adİos", "i̇o", "UTF8_LCASE", true); + assertContains("adİos", "i\u0307o", "UTF8_LCASE", true); assertContains("adİos", "İo", "UTF8_LCASE", true); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertContains("σ", "σ", "UTF8_BINARY", true); + assertContains("σ", "ς", "UTF8_BINARY", false); + assertContains("σ", "Σ", "UTF8_BINARY", false); + assertContains("ς", "σ", "UTF8_BINARY", false); + assertContains("ς", "ς", "UTF8_BINARY", true); + assertContains("ς", "Σ", "UTF8_BINARY", false); + assertContains("Σ", "σ", "UTF8_BINARY", false); + assertContains("Σ", "ς", "UTF8_BINARY", false); + assertContains("Σ", "Σ", "UTF8_BINARY", true); + assertContains("σ", "σ", "UTF8_LCASE", true); + assertContains("σ", "ς", "UTF8_LCASE", true); + assertContains("σ", "Σ", "UTF8_LCASE", true); + assertContains("ς", "σ", "UTF8_LCASE", true); + assertContains("ς", "ς", "UTF8_LCASE", true); + assertContains("ς", "Σ", "UTF8_LCASE", true); + assertContains("Σ", "σ", "UTF8_LCASE", true); + assertContains("Σ", "ς", "UTF8_LCASE", true); + assertContains("Σ", "Σ", "UTF8_LCASE", true); + assertContains("σ", "σ", "UNICODE", true); + assertContains("σ", "ς", "UNICODE", false); + assertContains("σ", "Σ", "UNICODE", false); + assertContains("ς", "σ", "UNICODE", false); + assertContains("ς", "ς", "UNICODE", true); + assertContains("ς", "Σ", "UNICODE", false); + assertContains("Σ", "σ", "UNICODE", false); + assertContains("Σ", "ς", "UNICODE", false); + assertContains("Σ", "Σ", "UNICODE", true); + assertContains("σ", "σ", "UNICODE_CI", true); + assertContains("σ", "ς", "UNICODE_CI", true); + assertContains("σ", "Σ", "UNICODE_CI", true); + assertContains("ς", "σ", "UNICODE_CI", true); + assertContains("ς", "ς", "UNICODE_CI", true); + assertContains("ς", "Σ", "UNICODE_CI", true); + assertContains("Σ", "σ", "UNICODE_CI", true); + assertContains("Σ", "ς", "UNICODE_CI", true); + assertContains("Σ", "Σ", "UNICODE_CI", true); + assertContains("ΣΑΛΑΤΑ", "Σ", "UTF8_BINARY", true); + assertContains("ΣΑΛΑΤΑ", "σ", "UTF8_BINARY", false); + assertContains("ΣΑΛΑΤΑ", "ς", "UTF8_BINARY", false); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UTF8_BINARY", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UTF8_BINARY", false); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UTF8_BINARY", false); + assertContains("ΣΑΛΑΤΑ", "Σ", "UTF8_LCASE", true); + assertContains("ΣΑΛΑΤΑ", "σ", "UTF8_LCASE", true); + assertContains("ΣΑΛΑΤΑ", "ς", "UTF8_LCASE", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UTF8_LCASE", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UTF8_LCASE", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UTF8_LCASE", true); + assertContains("ΣΑΛΑΤΑ", "Σ", "UNICODE", true); + assertContains("ΣΑΛΑΤΑ", "σ", "UNICODE", false); + assertContains("ΣΑΛΑΤΑ", "ς", "UNICODE", false); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UNICODE", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UNICODE", false); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UNICODE", false); + assertContains("ΣΑΛΑΤΑ", "Σ", "UNICODE_CI", true); + assertContains("ΣΑΛΑΤΑ", "σ", "UNICODE_CI", true); + assertContains("ΣΑΛΑΤΑ", "ς", "UNICODE_CI", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UNICODE_CI", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UNICODE_CI", true); + assertContains("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UNICODE_CI", true); + // Surrogate pairs. 
+ assertContains("a🙃b🙃c", "x", "UTF8_BINARY", false); + assertContains("a🙃b🙃c", "x", "UTF8_LCASE", false); + assertContains("a🙃b🙃c", "x", "UNICODE", false); + assertContains("a🙃b🙃c", "x", "UNICODE_CI", false); + assertContains("a🙃b🙃c", "b", "UTF8_BINARY", true); + assertContains("a🙃b🙃c", "b", "UTF8_LCASE", true); + assertContains("a🙃b🙃c", "b", "UNICODE", true); + assertContains("a🙃b🙃c", "b", "UNICODE_CI", true); + assertContains("a🙃b🙃c", "a🙃b", "UTF8_BINARY", true); + assertContains("a🙃b🙃c", "a🙃b", "UTF8_LCASE", true); + assertContains("a🙃b🙃c", "a🙃b", "UNICODE", true); + assertContains("a🙃b🙃c", "a🙃b", "UNICODE_CI", true); + assertContains("a🙃b🙃c", "b🙃c", "UTF8_BINARY", true); + assertContains("a🙃b🙃c", "b🙃c", "UTF8_LCASE", true); + assertContains("a🙃b🙃c", "b🙃c", "UNICODE", true); + assertContains("a🙃b🙃c", "b🙃c", "UNICODE_CI", true); + assertContains("a🙃b🙃c", "a🙃b🙃c", "UTF8_BINARY", true); + assertContains("a🙃b🙃c", "a🙃b🙃c", "UTF8_LCASE", true); + assertContains("a🙃b🙃c", "a🙃b🙃c", "UNICODE", true); + assertContains("a🙃b🙃c", "a🙃b🙃c", "UNICODE_CI", true); + assertContains("😀😆😃😄", "😄😆", "UTF8_BINARY", false); + assertContains("😀😆😃😄", "😄😆", "UTF8_LCASE", false); + assertContains("😀😆😃😄", "😄😆", "UNICODE", false); + assertContains("😀😆😃😄", "😄😆", "UNICODE_CI", false); + assertContains("😀😆😃😄", "😆😃", "UTF8_BINARY", true); + assertContains("😀😆😃😄", "😆😃", "UTF8_LCASE", true); + assertContains("😀😆😃😄", "😆😃", "UNICODE", true); + assertContains("😀😆😃😄", "😆😃", "UNICODE_CI", true); + assertContains("😀😆😃😄", "😀😆", "UTF8_BINARY", true); + assertContains("😀😆😃😄", "😀😆", "UTF8_LCASE", true); + assertContains("😀😆😃😄", "😀😆", "UNICODE", true); + assertContains("😀😆😃😄", "😀😆", "UNICODE_CI", true); + assertContains("😀😆😃😄", "😃😄", "UTF8_BINARY", true); + assertContains("😀😆😃😄", "😃😄", "UTF8_LCASE", true); + assertContains("😀😆😃😄", "😃😄", "UNICODE", true); + assertContains("😀😆😃😄", "😃😄", "UNICODE_CI", true); + assertContains("😀😆😃😄", "😀😆😃😄", "UTF8_BINARY", true); + assertContains("😀😆😃😄", "😀😆😃😄", "UTF8_LCASE", true); + assertContains("😀😆😃😄", "😀😆😃😄", "UNICODE", true); + assertContains("😀😆😃😄", "😀😆😃😄", "UNICODE_CI", true); + assertContains("𐐅", "𐐅", "UTF8_BINARY", true); + assertContains("𐐅", "𐐅", "UTF8_LCASE", true); + assertContains("𐐅", "𐐅", "UNICODE", true); + assertContains("𐐅", "𐐅", "UNICODE_CI", true); + assertContains("𐐅", "𐐭", "UTF8_BINARY", false); + assertContains("𐐅", "𐐭", "UTF8_LCASE", true); + assertContains("𐐅", "𐐭", "UNICODE", false); + assertContains("𐐅", "𐐭", "UNICODE_CI", true); + assertContains("𝔸", "𝔸", "UTF8_BINARY", true); + assertContains("𝔸", "𝔸", "UTF8_LCASE", true); + assertContains("𝔸", "𝔸", "UNICODE", true); + assertContains("𝔸", "𝔸", "UNICODE_CI", true); } - private void assertStartsWith( - String pattern, String prefix, String collationName, boolean expected) - throws SparkException { + /** + * Verify the behaviour of the `StartsWith` collation support class. 
+ */ + + private void assertStartsWith(String pattern, String prefix, String collationName, + boolean expected) throws SparkException { UTF8String l = UTF8String.fromString(pattern); UTF8String r = UTF8String.fromString(prefix); int collationId = CollationFactory.collationNameToId(collationName); @@ -315,20 +513,42 @@ private void assertStartsWith( @Test public void testStartsWith() throws SparkException { - // Edge cases - assertStartsWith("", "", "UTF8_BINARY", true); - assertStartsWith("c", "", "UTF8_BINARY", true); - assertStartsWith("", "c", "UTF8_BINARY", false); - assertStartsWith("", "", "UNICODE", true); - assertStartsWith("c", "", "UNICODE", true); - assertStartsWith("", "c", "UNICODE", false); - assertStartsWith("", "", "UTF8_LCASE", true); - assertStartsWith("c", "", "UTF8_LCASE", true); - assertStartsWith("", "c", "UTF8_LCASE", false); - assertStartsWith("", "", "UNICODE_CI", true); - assertStartsWith("c", "", "UNICODE_CI", true); - assertStartsWith("", "c", "UNICODE_CI", false); - // Basic tests + for (String collationName: testSupportedCollations) { + // Empty strings. + assertStartsWith("", "", collationName, true); + assertStartsWith("a", "", collationName, true); + assertStartsWith("", "x", collationName, false); + // Basic tests. + assertStartsWith("a", "a", collationName, true); + assertStartsWith("_a_", "_a", collationName, true); + assertStartsWith("_a_", "a", collationName, false); + assertStartsWith("%a%", "%a", collationName, true); + assertStartsWith("%a%", "a", collationName, false); + assertStartsWith("*a*", "*a", collationName, true); + assertStartsWith("*a*", "a", collationName, false); + assertStartsWith("?a?", "?a", collationName, true); + assertStartsWith("?a?", "a", collationName, false); + assertStartsWith("/a/", "/a", collationName, true); + assertStartsWith("/a/", "a", collationName, false); + assertStartsWith("abcde", "xyz", collationName, false); + assertStartsWith("abcde", "bcd", collationName, false); + assertStartsWith("abcde", "abc", collationName, true); + assertStartsWith("abcde", "cde", collationName, false); + assertStartsWith("abcde", "abcde", collationName, true); + assertStartsWith("你好", "x", collationName, false); + assertStartsWith("你好", "你", collationName, true); + assertStartsWith("你好", "好", collationName, false); + assertStartsWith("你好", "你好", collationName, true); + assertStartsWith("Γειά", "x", collationName, false); + assertStartsWith("Γειά", "ειά", collationName, false); + assertStartsWith("Γειά", "Γει", collationName, true); + assertStartsWith("Γειά", "Γειά", collationName, true); + assertStartsWith("Здраво", "x", collationName, false); + assertStartsWith("Здраво", "драво", collationName, false); + assertStartsWith("Здраво", "Здрав", collationName, true); + assertStartsWith("Здраво", "Здраво", collationName, true); + } + // Advanced tests. 
assertStartsWith("abcde", "abc", "UTF8_BINARY", true); assertStartsWith("abcde", "abd", "UTF8_BINARY", false); assertStartsWith("abcde", "fgh", "UTF8_BINARY", false); @@ -342,25 +562,6 @@ public void testStartsWith() throws SparkException { assertStartsWith("abcde", "aBC", "UNICODE_CI", true); assertStartsWith("abcde", "bcd", "UNICODE_CI", false); assertStartsWith("abcde", "123", "UNICODE_CI", false); - // Case variation - assertStartsWith("aBcDe", "abc", "UTF8_BINARY", false); - assertStartsWith("aBcDe", "aBc", "UTF8_BINARY", true); - assertStartsWith("aBcDe", "abcde", "UNICODE", false); - assertStartsWith("aBcDe", "aBcDe", "UNICODE", true); - assertStartsWith("aBcDe", "abc", "UTF8_LCASE", true); - assertStartsWith("aBcDe", "ABC", "UTF8_LCASE", true); - assertStartsWith("aBcDe", "abcde", "UNICODE_CI", true); - assertStartsWith("aBcDe", "AbCdE", "UNICODE_CI", true); - // Accent variation - assertStartsWith("aBcDe", "abć", "UTF8_BINARY", false); - assertStartsWith("aBcDe", "aBć", "UTF8_BINARY", false); - assertStartsWith("aBcDe", "abćde", "UNICODE", false); - assertStartsWith("aBcDe", "aBćDe", "UNICODE", false); - assertStartsWith("aBcDe", "abć", "UTF8_LCASE", false); - assertStartsWith("aBcDe", "ABĆ", "UTF8_LCASE", false); - assertStartsWith("aBcDe", "abćde", "UNICODE_CI", false); - assertStartsWith("aBcDe", "AbĆdE", "UNICODE_CI", false); - // Variable byte length characters assertStartsWith("ab世De", "ab世", "UTF8_BINARY", true); assertStartsWith("ab世De", "aB世", "UTF8_BINARY", false); assertStartsWith("äbćδe", "äbć", "UTF8_BINARY", true); @@ -377,16 +578,38 @@ public void testStartsWith() throws SparkException { assertStartsWith("ab世De", "AB世dE", "UNICODE_CI", true); assertStartsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true); assertStartsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); - // Characters with the same binary lowercase representation assertStartsWith("Kelvin.", "Kelvin", "UTF8_LCASE", true); assertStartsWith("Kelvin.", "Kelvin", "UTF8_LCASE", true); assertStartsWith("KKelvin.", "KKelvin", "UTF8_LCASE", true); assertStartsWith("2 Kelvin.", "2 Kelvin", "UTF8_LCASE", true); assertStartsWith("2 Kelvin.", "2 Kelvin", "UTF8_LCASE", true); assertStartsWith("KKelvin.", "KKelvin,", "UTF8_LCASE", false); - // Case-variable character length - assertStartsWith("i̇", "i", "UNICODE_CI", false); - assertStartsWith("i̇", "İ", "UNICODE_CI", true); + assertStartsWith("Ћао", "Ца", "sr_Cyrl_CI_AI", false); + assertStartsWith("Ћао", "ћа", "sr_Cyrl_CI_AI", true); + assertStartsWith("Ćao", "Ca", "SR_CI", false); + assertStartsWith("Ćao", "Ca", "SR_CI_AI", true); + assertStartsWith("Ćao", "Ća", "SR", true); + // Case variation. + assertStartsWith("aBcDe", "abc", "UTF8_BINARY", false); + assertStartsWith("aBcDe", "aBc", "UTF8_BINARY", true); + assertStartsWith("aBcDe", "abcde", "UNICODE", false); + assertStartsWith("aBcDe", "aBcDe", "UNICODE", true); + assertStartsWith("aBcDe", "abc", "UTF8_LCASE", true); + assertStartsWith("aBcDe", "ABC", "UTF8_LCASE", true); + assertStartsWith("aBcDe", "abcde", "UNICODE_CI", true); + assertStartsWith("aBcDe", "AbCdE", "UNICODE_CI", true); + // Accent variation. 
+ assertStartsWith("aBcDe", "abć", "UTF8_BINARY", false); + assertStartsWith("aBcDe", "aBć", "UTF8_BINARY", false); + assertStartsWith("aBcDe", "abćde", "UNICODE", false); + assertStartsWith("aBcDe", "aBćDe", "UNICODE", false); + assertStartsWith("aBcDe", "abć", "UTF8_LCASE", false); + assertStartsWith("aBcDe", "ABĆ", "UTF8_LCASE", false); + assertStartsWith("aBcDe", "abćde", "UNICODE_CI", false); + assertStartsWith("aBcDe", "AbĆdE", "UNICODE_CI", false); + // One-to-many case mapping (e.g. Turkish dotted I). + assertStartsWith("i\u0307", "i", "UNICODE_CI", false); + assertStartsWith("i\u0307", "İ", "UNICODE_CI", true); assertStartsWith("İ", "i", "UNICODE_CI", false); assertStartsWith("İİİ", "i̇i̇", "UNICODE_CI", true); assertStartsWith("İİİ", "i̇i", "UNICODE_CI", false); @@ -394,14 +617,14 @@ public void testStartsWith() throws SparkException { assertStartsWith("i̇İi̇i̇", "İi̇İi", "UNICODE_CI", false); assertStartsWith("i̇onic", "io", "UNICODE_CI", false); assertStartsWith("i̇onic", "Io", "UNICODE_CI", false); - assertStartsWith("i̇onic", "i̇o", "UNICODE_CI", true); + assertStartsWith("i̇onic", "i\u0307o", "UNICODE_CI", true); assertStartsWith("i̇onic", "İo", "UNICODE_CI", true); assertStartsWith("İonic", "io", "UNICODE_CI", false); assertStartsWith("İonic", "Io", "UNICODE_CI", false); - assertStartsWith("İonic", "i̇o", "UNICODE_CI", true); + assertStartsWith("İonic", "i\u0307o", "UNICODE_CI", true); assertStartsWith("İonic", "İo", "UNICODE_CI", true); - assertStartsWith("i̇", "i", "UTF8_LCASE", true); // != UNICODE_CI - assertStartsWith("i̇", "İ", "UTF8_LCASE", true); + assertStartsWith("i\u0307", "i", "UTF8_LCASE", true); // != UNICODE_CI + assertStartsWith("i\u0307", "İ", "UTF8_LCASE", true); assertStartsWith("İ", "i", "UTF8_LCASE", false); assertStartsWith("İİİ", "i̇i̇", "UTF8_LCASE", true); assertStartsWith("İİİ", "i̇i", "UTF8_LCASE", false); @@ -409,16 +632,136 @@ public void testStartsWith() throws SparkException { assertStartsWith("i̇İi̇i̇", "İi̇İi", "UTF8_LCASE", true); // != UNICODE_CI assertStartsWith("i̇onic", "io", "UTF8_LCASE", false); assertStartsWith("i̇onic", "Io", "UTF8_LCASE", false); - assertStartsWith("i̇onic", "i̇o", "UTF8_LCASE", true); + assertStartsWith("i̇onic", "i\u0307o", "UTF8_LCASE", true); assertStartsWith("i̇onic", "İo", "UTF8_LCASE", true); assertStartsWith("İonic", "io", "UTF8_LCASE", false); assertStartsWith("İonic", "Io", "UTF8_LCASE", false); - assertStartsWith("İonic", "i̇o", "UTF8_LCASE", true); + assertStartsWith("İonic", "i\u0307o", "UTF8_LCASE", true); assertStartsWith("İonic", "İo", "UTF8_LCASE", true); + assertStartsWith("oİ", "oİ", "UTF8_LCASE", true); + assertStartsWith("oİ", "oi̇", "UTF8_LCASE", true); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertStartsWith("σ", "σ", "UTF8_BINARY", true); + assertStartsWith("σ", "ς", "UTF8_BINARY", false); + assertStartsWith("σ", "Σ", "UTF8_BINARY", false); + assertStartsWith("ς", "σ", "UTF8_BINARY", false); + assertStartsWith("ς", "ς", "UTF8_BINARY", true); + assertStartsWith("ς", "Σ", "UTF8_BINARY", false); + assertStartsWith("Σ", "σ", "UTF8_BINARY", false); + assertStartsWith("Σ", "ς", "UTF8_BINARY", false); + assertStartsWith("Σ", "Σ", "UTF8_BINARY", true); + assertStartsWith("σ", "σ", "UTF8_LCASE", true); + assertStartsWith("σ", "ς", "UTF8_LCASE", true); + assertStartsWith("σ", "Σ", "UTF8_LCASE", true); + assertStartsWith("ς", "σ", "UTF8_LCASE", true); + assertStartsWith("ς", "ς", "UTF8_LCASE", true); + assertStartsWith("ς", "Σ", "UTF8_LCASE", true); + assertStartsWith("Σ", "σ", "UTF8_LCASE", true); + assertStartsWith("Σ", "ς", "UTF8_LCASE", true); + assertStartsWith("Σ", "Σ", "UTF8_LCASE", true); + assertStartsWith("σ", "σ", "UNICODE", true); + assertStartsWith("σ", "ς", "UNICODE", false); + assertStartsWith("σ", "Σ", "UNICODE", false); + assertStartsWith("ς", "σ", "UNICODE", false); + assertStartsWith("ς", "ς", "UNICODE", true); + assertStartsWith("ς", "Σ", "UNICODE", false); + assertStartsWith("Σ", "σ", "UNICODE", false); + assertStartsWith("Σ", "ς", "UNICODE", false); + assertStartsWith("Σ", "Σ", "UNICODE", true); + assertStartsWith("σ", "σ", "UNICODE_CI", true); + assertStartsWith("σ", "ς", "UNICODE_CI", true); + assertStartsWith("σ", "Σ", "UNICODE_CI", true); + assertStartsWith("ς", "σ", "UNICODE_CI", true); + assertStartsWith("ς", "ς", "UNICODE_CI", true); + assertStartsWith("ς", "Σ", "UNICODE_CI", true); + assertStartsWith("Σ", "σ", "UNICODE_CI", true); + assertStartsWith("Σ", "ς", "UNICODE_CI", true); + assertStartsWith("Σ", "Σ", "UNICODE_CI", true); + assertStartsWith("ΣΑΛΑΤΑ", "Σ", "UTF8_BINARY", true); + assertStartsWith("ΣΑΛΑΤΑ", "σ", "UTF8_BINARY", false); + assertStartsWith("ΣΑΛΑΤΑ", "ς", "UTF8_BINARY", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UTF8_BINARY", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UTF8_BINARY", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UTF8_BINARY", false); + assertStartsWith("ΣΑΛΑΤΑ", "Σ", "UTF8_LCASE", true); + assertStartsWith("ΣΑΛΑΤΑ", "σ", "UTF8_LCASE", true); + assertStartsWith("ΣΑΛΑΤΑ", "ς", "UTF8_LCASE", true); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UTF8_LCASE", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UTF8_LCASE", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UTF8_LCASE", false); + assertStartsWith("ΣΑΛΑΤΑ", "Σ", "UNICODE", true); + assertStartsWith("ΣΑΛΑΤΑ", "σ", "UNICODE", false); + assertStartsWith("ΣΑΛΑΤΑ", "ς", "UNICODE", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UNICODE", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UNICODE", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UNICODE", false); + assertStartsWith("ΣΑΛΑΤΑ", "Σ", "UNICODE_CI", true); + assertStartsWith("ΣΑΛΑΤΑ", "σ", "UNICODE_CI", true); + assertStartsWith("ΣΑΛΑΤΑ", "ς", "UNICODE_CI", true); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UNICODE_CI", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UNICODE_CI", false); + assertStartsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UNICODE_CI", false); + // Surrogate pairs. 
+ assertStartsWith("a🙃b🙃c", "x", "UTF8_BINARY", false); + assertStartsWith("a🙃b🙃c", "x", "UTF8_LCASE", false); + assertStartsWith("a🙃b🙃c", "x", "UNICODE", false); + assertStartsWith("a🙃b🙃c", "x", "UNICODE_CI", false); + assertStartsWith("a🙃b🙃c", "b", "UTF8_BINARY", false); + assertStartsWith("a🙃b🙃c", "b", "UTF8_LCASE", false); + assertStartsWith("a🙃b🙃c", "b", "UNICODE", false); + assertStartsWith("a🙃b🙃c", "b", "UNICODE_CI", false); + assertStartsWith("a🙃b🙃c", "a🙃b", "UTF8_BINARY", true); + assertStartsWith("a🙃b🙃c", "a🙃b", "UTF8_LCASE", true); + assertStartsWith("a🙃b🙃c", "a🙃b", "UNICODE", true); + assertStartsWith("a🙃b🙃c", "a🙃b", "UNICODE_CI", true); + assertStartsWith("a🙃b🙃c", "b🙃c", "UTF8_BINARY", false); + assertStartsWith("a🙃b🙃c", "b🙃c", "UTF8_LCASE", false); + assertStartsWith("a🙃b🙃c", "b🙃c", "UNICODE", false); + assertStartsWith("a🙃b🙃c", "b🙃c", "UNICODE_CI", false); + assertStartsWith("a🙃b🙃c", "a🙃b🙃c", "UTF8_BINARY", true); + assertStartsWith("a🙃b🙃c", "a🙃b🙃c", "UTF8_LCASE", true); + assertStartsWith("a🙃b🙃c", "a🙃b🙃c", "UNICODE", true); + assertStartsWith("a🙃b🙃c", "a🙃b🙃c", "UNICODE_CI", true); + assertStartsWith("😀😆😃😄", "😄😆", "UTF8_BINARY", false); + assertStartsWith("😀😆😃😄", "😄😆", "UTF8_LCASE", false); + assertStartsWith("😀😆😃😄", "😄😆", "UNICODE", false); + assertStartsWith("😀😆😃😄", "😄😆", "UNICODE_CI", false); + assertStartsWith("😀😆😃😄", "😆😃", "UTF8_BINARY", false); + assertStartsWith("😀😆😃😄", "😆😃", "UTF8_LCASE", false); + assertStartsWith("😀😆😃😄", "😆😃", "UNICODE", false); + assertStartsWith("😀😆😃😄", "😆😃", "UNICODE_CI", false); + assertStartsWith("😀😆😃😄", "😀😆", "UTF8_BINARY", true); + assertStartsWith("😀😆😃😄", "😀😆", "UTF8_LCASE", true); + assertStartsWith("😀😆😃😄", "😀😆", "UNICODE", true); + assertStartsWith("😀😆😃😄", "😀😆", "UNICODE_CI", true); + assertStartsWith("😀😆😃😄", "😃😄", "UTF8_BINARY", false); + assertStartsWith("😀😆😃😄", "😃😄", "UTF8_LCASE", false); + assertStartsWith("😀😆😃😄", "😃😄", "UNICODE", false); + assertStartsWith("😀😆😃😄", "😃😄", "UNICODE_CI", false); + assertStartsWith("😀😆😃😄", "😀😆😃😄", "UTF8_BINARY", true); + assertStartsWith("😀😆😃😄", "😀😆😃😄", "UTF8_LCASE", true); + assertStartsWith("😀😆😃😄", "😀😆😃😄", "UNICODE", true); + assertStartsWith("😀😆😃😄", "😀😆😃😄", "UNICODE_CI", true); + assertStartsWith("𐐅", "𐐅", "UTF8_BINARY", true); + assertStartsWith("𐐅", "𐐅", "UTF8_LCASE", true); + assertStartsWith("𐐅", "𐐅", "UNICODE", true); + assertStartsWith("𐐅", "𐐅", "UNICODE_CI", true); + assertStartsWith("𐐅", "𐐭", "UTF8_BINARY", false); + assertStartsWith("𐐅", "𐐭", "UTF8_LCASE", true); + assertStartsWith("𐐅", "𐐭", "UNICODE", false); + assertStartsWith("𐐅", "𐐭", "UNICODE_CI", true); + assertStartsWith("𝔸", "𝔸", "UTF8_BINARY", true); + assertStartsWith("𝔸", "𝔸", "UTF8_LCASE", true); + assertStartsWith("𝔸", "𝔸", "UNICODE", true); + assertStartsWith("𝔸", "𝔸", "UNICODE_CI", true); } - private void assertEndsWith(String pattern, String suffix, String collationName, boolean expected) - throws SparkException { + /** + * Verify the behaviour of the `EndsWith` collation support class. 
+ */ + + private void assertEndsWith(String pattern, String suffix, String collationName, + boolean expected) throws SparkException { UTF8String l = UTF8String.fromString(pattern); UTF8String r = UTF8String.fromString(suffix); int collationId = CollationFactory.collationNameToId(collationName); @@ -427,20 +770,42 @@ private void assertEndsWith(String pattern, String suffix, String collationName, @Test public void testEndsWith() throws SparkException { - // Edge cases - assertEndsWith("", "", "UTF8_BINARY", true); - assertEndsWith("c", "", "UTF8_BINARY", true); - assertEndsWith("", "c", "UTF8_BINARY", false); - assertEndsWith("", "", "UNICODE", true); - assertEndsWith("c", "", "UNICODE", true); - assertEndsWith("", "c", "UNICODE", false); - assertEndsWith("", "", "UTF8_LCASE", true); - assertEndsWith("c", "", "UTF8_LCASE", true); - assertEndsWith("", "c", "UTF8_LCASE", false); - assertEndsWith("", "", "UNICODE_CI", true); - assertEndsWith("c", "", "UNICODE_CI", true); - assertEndsWith("", "c", "UNICODE_CI", false); - // Basic tests + for (String collationName: testSupportedCollations) { + // Empty strings. + assertEndsWith("", "", collationName, true); + assertEndsWith("a", "", collationName, true); + assertEndsWith("", "x", collationName, false); + // Basic tests. + assertEndsWith("a", "a", collationName, true); + assertEndsWith("_a_", "a_", collationName, true); + assertEndsWith("_a_", "a", collationName, false); + assertEndsWith("%a%", "a%", collationName, true); + assertEndsWith("%a%", "a", collationName, false); + assertEndsWith("*a*", "a*", collationName, true); + assertEndsWith("*a*", "a", collationName, false); + assertEndsWith("?a?", "a?", collationName, true); + assertEndsWith("?a?", "a", collationName, false); + assertEndsWith("/a/", "a/", collationName, true); + assertEndsWith("/a/", "a", collationName, false); + assertEndsWith("abcde", "xyz", collationName, false); + assertEndsWith("abcde", "bcd", collationName, false); + assertEndsWith("abcde", "abc", collationName, false); + assertEndsWith("abcde", "cde", collationName, true); + assertEndsWith("abcde", "abcde", collationName, true); + assertEndsWith("你好", "x", collationName, false); + assertEndsWith("你好", "你", collationName, false); + assertEndsWith("你好", "好", collationName, true); + assertEndsWith("你好", "你好", collationName, true); + assertEndsWith("Γειά", "x", collationName, false); + assertEndsWith("Γειά", "ειά", collationName, true); + assertEndsWith("Γειά", "Γει", collationName, false); + assertEndsWith("Γειά", "Γειά", collationName, true); + assertEndsWith("Здраво", "x", collationName, false); + assertEndsWith("Здраво", "драво", collationName, true); + assertEndsWith("Здраво", "Здрав", collationName, false); + assertEndsWith("Здраво", "Здраво", collationName, true); + } + // Advanced tests. 
assertEndsWith("abcde", "cde", "UTF8_BINARY", true); assertEndsWith("abcde", "bde", "UTF8_BINARY", false); assertEndsWith("abcde", "fgh", "UTF8_BINARY", false); @@ -454,25 +819,6 @@ public void testEndsWith() throws SparkException { assertEndsWith("abcde", "CDe", "UNICODE_CI", true); assertEndsWith("abcde", "bcd", "UNICODE_CI", false); assertEndsWith("abcde", "123", "UNICODE_CI", false); - // Case variation - assertEndsWith("aBcDe", "cde", "UTF8_BINARY", false); - assertEndsWith("aBcDe", "cDe", "UTF8_BINARY", true); - assertEndsWith("aBcDe", "abcde", "UNICODE", false); - assertEndsWith("aBcDe", "aBcDe", "UNICODE", true); - assertEndsWith("aBcDe", "cde", "UTF8_LCASE", true); - assertEndsWith("aBcDe", "CDE", "UTF8_LCASE", true); - assertEndsWith("aBcDe", "abcde", "UNICODE_CI", true); - assertEndsWith("aBcDe", "AbCdE", "UNICODE_CI", true); - // Accent variation - assertEndsWith("aBcDe", "ćde", "UTF8_BINARY", false); - assertEndsWith("aBcDe", "ćDe", "UTF8_BINARY", false); - assertEndsWith("aBcDe", "abćde", "UNICODE", false); - assertEndsWith("aBcDe", "aBćDe", "UNICODE", false); - assertEndsWith("aBcDe", "ćde", "UTF8_LCASE", false); - assertEndsWith("aBcDe", "ĆDE", "UTF8_LCASE", false); - assertEndsWith("aBcDe", "abćde", "UNICODE_CI", false); - assertEndsWith("aBcDe", "AbĆdE", "UNICODE_CI", false); - // Variable byte length characters assertEndsWith("ab世De", "世De", "UTF8_BINARY", true); assertEndsWith("ab世De", "世dE", "UTF8_BINARY", false); assertEndsWith("äbćδe", "ćδe", "UTF8_BINARY", true); @@ -489,53 +835,196 @@ public void testEndsWith() throws SparkException { assertEndsWith("ab世De", "AB世dE", "UNICODE_CI", true); assertEndsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true); assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); - // Characters with the same binary lowercase representation assertEndsWith("The Kelvin", "Kelvin", "UTF8_LCASE", true); assertEndsWith("The Kelvin", "Kelvin", "UTF8_LCASE", true); assertEndsWith("The KKelvin", "KKelvin", "UTF8_LCASE", true); assertEndsWith("The 2 Kelvin", "2 Kelvin", "UTF8_LCASE", true); assertEndsWith("The 2 Kelvin", "2 Kelvin", "UTF8_LCASE", true); assertEndsWith("The KKelvin", "KKelvin,", "UTF8_LCASE", false); - // Case-variable character length - assertEndsWith("i̇", "\u0307", "UNICODE_CI", false); - assertEndsWith("i̇", "İ", "UNICODE_CI", true); + assertEndsWith("Ћевапчићи", "цици", "sr_Cyrl_CI_AI", false); + assertEndsWith("Ћевапчићи", "чИЋи", "sr_Cyrl_CI_AI", true); + assertEndsWith("Ćevapčići", "cici", "SR_CI", false); + assertEndsWith("Ćevapčići", "cici", "SR_CI_AI", true); + assertEndsWith("Ćevapčići", "čići", "SR", true); + // Case variation. + assertEndsWith("aBcDe", "cde", "UTF8_BINARY", false); + assertEndsWith("aBcDe", "cDe", "UTF8_BINARY", true); + assertEndsWith("aBcDe", "abcde", "UNICODE", false); + assertEndsWith("aBcDe", "aBcDe", "UNICODE", true); + assertEndsWith("aBcDe", "cde", "UTF8_LCASE", true); + assertEndsWith("aBcDe", "CDE", "UTF8_LCASE", true); + assertEndsWith("aBcDe", "abcde", "UNICODE_CI", true); + assertEndsWith("aBcDe", "AbCdE", "UNICODE_CI", true); + // Accent variation. 
+ assertEndsWith("aBcDe", "ćde", "UTF8_BINARY", false); + assertEndsWith("aBcDe", "ćDe", "UTF8_BINARY", false); + assertEndsWith("aBcDe", "abćde", "UNICODE", false); + assertEndsWith("aBcDe", "aBćDe", "UNICODE", false); + assertEndsWith("aBcDe", "ćde", "UTF8_LCASE", false); + assertEndsWith("aBcDe", "ĆDE", "UTF8_LCASE", false); + assertEndsWith("aBcDe", "abćde", "UNICODE_CI", false); + assertEndsWith("aBcDe", "AbĆdE", "UNICODE_CI", false); + // One-to-many case mapping (e.g. Turkish dotted I). + assertEndsWith("i\u0307", "\u0307", "UNICODE_CI", false); + assertEndsWith("i\u0307", "İ", "UNICODE_CI", true); assertEndsWith("İ", "i", "UNICODE_CI", false); assertEndsWith("İİİ", "i̇i̇", "UNICODE_CI", true); assertEndsWith("İİİ", "ii̇", "UNICODE_CI", false); assertEndsWith("İi̇İ", "İi̇", "UNICODE_CI", true); assertEndsWith("i̇İi̇i̇", "\u0307İi̇İ", "UNICODE_CI", false); - assertEndsWith("the i̇o", "io", "UNICODE_CI", false); - assertEndsWith("the i̇o", "Io", "UNICODE_CI", false); - assertEndsWith("the i̇o", "i̇o", "UNICODE_CI", true); - assertEndsWith("the i̇o", "İo", "UNICODE_CI", true); + assertEndsWith("the i\u0307o", "io", "UNICODE_CI", false); + assertEndsWith("the i\u0307o", "Io", "UNICODE_CI", false); + assertEndsWith("the i\u0307o", "i\u0307o", "UNICODE_CI", true); + assertEndsWith("the i\u0307o", "İo", "UNICODE_CI", true); assertEndsWith("the İo", "io", "UNICODE_CI", false); assertEndsWith("the İo", "Io", "UNICODE_CI", false); - assertEndsWith("the İo", "i̇o", "UNICODE_CI", true); + assertEndsWith("the İo", "i\u0307o", "UNICODE_CI", true); assertEndsWith("the İo", "İo", "UNICODE_CI", true); - assertEndsWith("i̇", "\u0307", "UTF8_LCASE", true); // != UNICODE_CI - assertEndsWith("i̇", "İ", "UTF8_LCASE", true); + assertEndsWith("i\u0307", "\u0307", "UTF8_LCASE", true); // != UNICODE_CI + assertEndsWith("i\u0307", "İ", "UTF8_LCASE", true); assertEndsWith("İ", "\u0307", "UTF8_LCASE", false); assertEndsWith("İİİ", "i̇i̇", "UTF8_LCASE", true); assertEndsWith("İİİ", "ii̇", "UTF8_LCASE", false); assertEndsWith("İi̇İ", "İi̇", "UTF8_LCASE", true); assertEndsWith("i̇İi̇i̇", "\u0307İi̇İ", "UTF8_LCASE", true); // != UNICODE_CI assertEndsWith("i̇İi̇i̇", "\u0307İİ", "UTF8_LCASE", false); - assertEndsWith("the i̇o", "io", "UTF8_LCASE", false); - assertEndsWith("the i̇o", "Io", "UTF8_LCASE", false); - assertEndsWith("the i̇o", "i̇o", "UTF8_LCASE", true); - assertEndsWith("the i̇o", "İo", "UTF8_LCASE", true); + assertEndsWith("the i\u0307o", "io", "UTF8_LCASE", false); + assertEndsWith("the i\u0307o", "Io", "UTF8_LCASE", false); + assertEndsWith("the i\u0307o", "i\u0307o", "UTF8_LCASE", true); + assertEndsWith("the i\u0307o", "İo", "UTF8_LCASE", true); assertEndsWith("the İo", "io", "UTF8_LCASE", false); assertEndsWith("the İo", "Io", "UTF8_LCASE", false); - assertEndsWith("the İo", "i̇o", "UTF8_LCASE", true); + assertEndsWith("the İo", "i\u0307o", "UTF8_LCASE", true); assertEndsWith("the İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "i̇o", "UTF8_LCASE", true); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertEndsWith("σ", "σ", "UTF8_BINARY", true); + assertEndsWith("σ", "ς", "UTF8_BINARY", false); + assertEndsWith("σ", "Σ", "UTF8_BINARY", false); + assertEndsWith("ς", "σ", "UTF8_BINARY", false); + assertEndsWith("ς", "ς", "UTF8_BINARY", true); + assertEndsWith("ς", "Σ", "UTF8_BINARY", false); + assertEndsWith("Σ", "σ", "UTF8_BINARY", false); + assertEndsWith("Σ", "ς", "UTF8_BINARY", false); + assertEndsWith("Σ", "Σ", "UTF8_BINARY", true); + assertEndsWith("σ", "σ", "UTF8_LCASE", true); + assertEndsWith("σ", "ς", "UTF8_LCASE", true); + assertEndsWith("σ", "Σ", "UTF8_LCASE", true); + assertEndsWith("ς", "σ", "UTF8_LCASE", true); + assertEndsWith("ς", "ς", "UTF8_LCASE", true); + assertEndsWith("ς", "Σ", "UTF8_LCASE", true); + assertEndsWith("Σ", "σ", "UTF8_LCASE", true); + assertEndsWith("Σ", "ς", "UTF8_LCASE", true); + assertEndsWith("Σ", "Σ", "UTF8_LCASE", true); + assertEndsWith("σ", "σ", "UNICODE", true); + assertEndsWith("σ", "ς", "UNICODE", false); + assertEndsWith("σ", "Σ", "UNICODE", false); + assertEndsWith("ς", "σ", "UNICODE", false); + assertEndsWith("ς", "ς", "UNICODE", true); + assertEndsWith("ς", "Σ", "UNICODE", false); + assertEndsWith("Σ", "σ", "UNICODE", false); + assertEndsWith("Σ", "ς", "UNICODE", false); + assertEndsWith("Σ", "Σ", "UNICODE", true); + assertEndsWith("σ", "σ", "UNICODE_CI", true); + assertEndsWith("σ", "ς", "UNICODE_CI", true); + assertEndsWith("σ", "Σ", "UNICODE_CI", true); + assertEndsWith("ς", "σ", "UNICODE_CI", true); + assertEndsWith("ς", "ς", "UNICODE_CI", true); + assertEndsWith("ς", "Σ", "UNICODE_CI", true); + assertEndsWith("Σ", "σ", "UNICODE_CI", true); + assertEndsWith("Σ", "ς", "UNICODE_CI", true); + assertEndsWith("Σ", "Σ", "UNICODE_CI", true); + assertEndsWith("ΣΑΛΑΤΑ", "Σ", "UTF8_BINARY", false); + assertEndsWith("ΣΑΛΑΤΑ", "σ", "UTF8_BINARY", false); + assertEndsWith("ΣΑΛΑΤΑ", "ς", "UTF8_BINARY", false); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UTF8_BINARY", true); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UTF8_BINARY", false); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UTF8_BINARY", false); + assertEndsWith("ΣΑΛΑΤΑ", "Σ", "UTF8_LCASE", false); + assertEndsWith("ΣΑΛΑΤΑ", "σ", "UTF8_LCASE", false); + assertEndsWith("ΣΑΛΑΤΑ", "ς", "UTF8_LCASE", false); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UTF8_LCASE", true); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UTF8_LCASE", true); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UTF8_LCASE", true); + assertEndsWith("ΣΑΛΑΤΑ", "Σ", "UNICODE", false); + assertEndsWith("ΣΑΛΑΤΑ", "σ", "UNICODE", false); + assertEndsWith("ΣΑΛΑΤΑ", "ς", "UNICODE", false); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UNICODE", true); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UNICODE", false); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UNICODE", false); + assertEndsWith("ΣΑΛΑΤΑ", "Σ", "UNICODE_CI", false); + assertEndsWith("ΣΑΛΑΤΑ", "σ", "UNICODE_CI", false); + assertEndsWith("ΣΑΛΑΤΑ", "ς", "UNICODE_CI", false); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "Σ", "UNICODE_CI", true); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "σ", "UNICODE_CI", true); + assertEndsWith("ΘΑΛΑΣΣΙΝΟΣ", "ς", "UNICODE_CI", true); + // Surrogate pairs. 
+ assertEndsWith("a🙃b🙃c", "x", "UTF8_BINARY", false); + assertEndsWith("a🙃b🙃c", "x", "UTF8_LCASE", false); + assertEndsWith("a🙃b🙃c", "x", "UNICODE", false); + assertEndsWith("a🙃b🙃c", "x", "UNICODE_CI", false); + assertEndsWith("a🙃b🙃c", "b", "UTF8_BINARY", false); + assertEndsWith("a🙃b🙃c", "b", "UTF8_LCASE", false); + assertEndsWith("a🙃b🙃c", "b", "UNICODE", false); + assertEndsWith("a🙃b🙃c", "b", "UNICODE_CI", false); + assertEndsWith("a🙃b🙃c", "a🙃b", "UTF8_BINARY", false); + assertEndsWith("a🙃b🙃c", "a🙃b", "UTF8_LCASE", false); + assertEndsWith("a🙃b🙃c", "a🙃b", "UNICODE", false); + assertEndsWith("a🙃b🙃c", "a🙃b", "UNICODE_CI", false); + assertEndsWith("a🙃b🙃c", "b🙃c", "UTF8_BINARY", true); + assertEndsWith("a🙃b🙃c", "b🙃c", "UTF8_LCASE", true); + assertEndsWith("a🙃b🙃c", "b🙃c", "UNICODE", true); + assertEndsWith("a🙃b🙃c", "b🙃c", "UNICODE_CI", true); + assertEndsWith("a🙃b🙃c", "a🙃b🙃c", "UTF8_BINARY", true); + assertEndsWith("a🙃b🙃c", "a🙃b🙃c", "UTF8_LCASE", true); + assertEndsWith("a🙃b🙃c", "a🙃b🙃c", "UNICODE", true); + assertEndsWith("a🙃b🙃c", "a🙃b🙃c", "UNICODE_CI", true); + assertEndsWith("😀😆😃😄", "😄😆", "UTF8_BINARY", false); + assertEndsWith("😀😆😃😄", "😄😆", "UTF8_LCASE", false); + assertEndsWith("😀😆😃😄", "😄😆", "UNICODE", false); + assertEndsWith("😀😆😃😄", "😄😆", "UNICODE_CI", false); + assertEndsWith("😀😆😃😄", "😆😃", "UTF8_BINARY", false); + assertEndsWith("😀😆😃😄", "😆😃", "UTF8_LCASE", false); + assertEndsWith("😀😆😃😄", "😆😃", "UNICODE", false); + assertEndsWith("😀😆😃😄", "😆😃", "UNICODE_CI", false); + assertEndsWith("😀😆😃😄", "😀😆", "UTF8_BINARY", false); + assertEndsWith("😀😆😃😄", "😀😆", "UTF8_LCASE", false); + assertEndsWith("😀😆😃😄", "😀😆", "UNICODE", false); + assertEndsWith("😀😆😃😄", "😀😆", "UNICODE_CI", false); + assertEndsWith("😀😆😃😄", "😃😄", "UTF8_BINARY", true); + assertEndsWith("😀😆😃😄", "😃😄", "UTF8_LCASE", true); + assertEndsWith("😀😆😃😄", "😃😄", "UNICODE", true); + assertEndsWith("😀😆😃😄", "😃😄", "UNICODE_CI", true); + assertEndsWith("😀😆😃😄", "😀😆😃😄", "UTF8_BINARY", true); + assertEndsWith("😀😆😃😄", "😀😆😃😄", "UTF8_LCASE", true); + assertEndsWith("😀😆😃😄", "😀😆😃😄", "UNICODE", true); + assertEndsWith("😀😆😃😄", "😀😆😃😄", "UNICODE_CI", true); + assertEndsWith("𐐅", "𐐅", "UTF8_BINARY", true); + assertEndsWith("𐐅", "𐐅", "UTF8_LCASE", true); + assertEndsWith("𐐅", "𐐅", "UNICODE", true); + assertEndsWith("𐐅", "𐐅", "UNICODE_CI", true); + assertEndsWith("𐐅", "𐐭", "UTF8_BINARY", false); + assertEndsWith("𐐅", "𐐭", "UTF8_LCASE", true); + assertEndsWith("𐐅", "𐐭", "UNICODE", false); + assertEndsWith("𐐅", "𐐭", "UNICODE_CI", true); + assertEndsWith("𝔸", "𝔸", "UTF8_BINARY", true); + assertEndsWith("𝔸", "𝔸", "UTF8_LCASE", true); + assertEndsWith("𝔸", "𝔸", "UNICODE", true); + assertEndsWith("𝔸", "𝔸", "UNICODE_CI", true); } + /** + * Verify the behaviour of the `StringSplitSQL` collation support class. 
+ */ + private void assertStringSplitSQL(String str, String delimiter, String collationName, UTF8String[] expected) throws SparkException { UTF8String s = UTF8String.fromString(str); UTF8String d = UTF8String.fromString(delimiter); int collationId = CollationFactory.collationNameToId(collationName); - assertArrayEquals(expected, CollationSupport.StringSplitSQL.exec(s, d, collationId)); + UTF8String[] result = CollationSupport.StringSplitSQL.exec(s, d, collationId); + assertArrayEquals(expected, result); } @Test @@ -553,7 +1042,21 @@ public void testStringSplitSQL() throws SparkException { var array_A_B = new UTF8String[] { UTF8String.fromString("A"), UTF8String.fromString("B") }; var array_a_e = new UTF8String[] { UTF8String.fromString("ä"), UTF8String.fromString("e") }; var array_Aa_bB = new UTF8String[] { UTF8String.fromString("Aa"), UTF8String.fromString("bB") }; - // Edge cases + var array_Turkish_uppercase_dotted_I = new UTF8String[] { UTF8String.fromString("İ") }; + var array_Turkish_lowercase_dotted_i = new UTF8String[] { UTF8String.fromString("i\u0307") }; + var array_i = new UTF8String[] { UTF8String.fromString("i"), UTF8String.fromString("") }; + var array_dot = new UTF8String[] { UTF8String.fromString(""), UTF8String.fromString("\u0307") }; + var array_AiB = new UTF8String[] { UTF8String.fromString("Ai\u0307B") }; + var array_AIB = new UTF8String[] { UTF8String.fromString("AİB") }; + var array_small_nonfinal_sigma = new UTF8String[] { UTF8String.fromString("σ") }; + var array_small_final_sigma = new UTF8String[] { UTF8String.fromString("ς") }; + var array_capital_sigma = new UTF8String[] { UTF8String.fromString("Σ") }; + var array_a_b_c = new UTF8String[] { UTF8String.fromString("a"), UTF8String.fromString("b"), + UTF8String.fromString("c") }; + var array_emojis = new UTF8String[] { UTF8String.fromString("😀"), UTF8String.fromString("😄") }; + var array_AOB = new UTF8String[] { UTF8String.fromString("A𐐅B") }; + var array_AoB = new UTF8String[] { UTF8String.fromString("A𐐭B") }; + // Empty strings. assertStringSplitSQL("", "", "UTF8_BINARY", empty_match); assertStringSplitSQL("abc", "", "UTF8_BINARY", array_abc); assertStringSplitSQL("", "abc", "UTF8_BINARY", empty_match); @@ -566,7 +1069,7 @@ public void testStringSplitSQL() throws SparkException { assertStringSplitSQL("", "", "UNICODE_CI", empty_match); assertStringSplitSQL("abc", "", "UNICODE_CI", array_abc); assertStringSplitSQL("", "abc", "UNICODE_CI", empty_match); - // Basic tests + // Basic tests. 
assertStringSplitSQL("1a2", "a", "UTF8_BINARY", array_1_2); assertStringSplitSQL("1a2", "A", "UTF8_BINARY", array_1a2); assertStringSplitSQL("1a2", "b", "UTF8_BINARY", array_1a2); @@ -580,25 +1083,7 @@ public void testStringSplitSQL() throws SparkException { assertStringSplitSQL("1a2", "A", "UNICODE_CI", array_1_2); assertStringSplitSQL("1a2", "1A2", "UNICODE_CI", full_match); assertStringSplitSQL("1a2", "123", "UNICODE_CI", array_1a2); - // Case variation - assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB); - assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB); - assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB); - assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B); - assertStringSplitSQL("AaXbB", "axb", "UTF8_LCASE", array_A_B); - assertStringSplitSQL("AaXbB", "AXB", "UTF8_LCASE", array_A_B); - assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B); - assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B); - // Accent variation - assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe); - assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe); - assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe); - assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe); - assertStringSplitSQL("aBcDe", "bćd", "UTF8_LCASE", array_aBcDe); - assertStringSplitSQL("aBcDe", "BĆD", "UTF8_LCASE", array_aBcDe); - assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe); - assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe); - // Variable byte length characters + // Advanced tests. assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY", array_a_e); assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY", array_special); assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY", array_a_e); @@ -615,10 +1100,123 @@ public void testStringSplitSQL() throws SparkException { assertStringSplitSQL("äb世De", "AB世dE", "UNICODE_CI", array_special); assertStringSplitSQL("äbćδe", "ÄbćδE", "UNICODE_CI", full_match); assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE_CI", array_abcde); + // Case variation. + assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB); + assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB); + assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB); + assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B); + assertStringSplitSQL("AaXbB", "axb", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("AaXbB", "AXB", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B); + assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B); + // Accent variation. + assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe); + assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe); + assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe); + assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe); + assertStringSplitSQL("aBcDe", "bćd", "UTF8_LCASE", array_aBcDe); + assertStringSplitSQL("aBcDe", "BĆD", "UTF8_LCASE", array_aBcDe); + assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe); + assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertStringSplitSQL("İ", "i", "UTF8_BINARY", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "i", "UTF8_LCASE", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "i", "UNICODE", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "i", "UNICODE_CI", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "\u0307", "UTF8_BINARY", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "\u0307", "UTF8_LCASE", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "\u0307", "UNICODE", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("İ", "\u0307", "UNICODE_CI", array_Turkish_uppercase_dotted_I); + assertStringSplitSQL("i\u0307", "i", "UTF8_BINARY", array_dot); + assertStringSplitSQL("i\u0307", "i", "UTF8_LCASE", array_dot); + assertStringSplitSQL("i\u0307", "i", "UNICODE", array_Turkish_lowercase_dotted_i); + assertStringSplitSQL("i\u0307", "i", "UNICODE_CI", array_Turkish_lowercase_dotted_i); + assertStringSplitSQL("i\u0307", "\u0307", "UTF8_BINARY", array_i); + assertStringSplitSQL("i\u0307", "\u0307", "UTF8_LCASE", array_i); + assertStringSplitSQL("i\u0307", "\u0307", "UNICODE", array_Turkish_lowercase_dotted_i); + assertStringSplitSQL("i\u0307", "\u0307", "UNICODE_CI", array_Turkish_lowercase_dotted_i); + assertStringSplitSQL("AİB", "İ", "UTF8_BINARY", array_A_B); + assertStringSplitSQL("AİB", "İ", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("AİB", "İ", "UNICODE", array_A_B); + assertStringSplitSQL("AİB", "İ", "UNICODE_CI", array_A_B); + assertStringSplitSQL("AİB", "i\u0307", "UTF8_BINARY", array_AIB); + assertStringSplitSQL("AİB", "i\u0307", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("AİB", "i\u0307", "UNICODE", array_AIB); + assertStringSplitSQL("AİB", "i\u0307", "UNICODE_CI", array_A_B); + assertStringSplitSQL("Ai\u0307B", "İ", "UTF8_BINARY", array_AiB); + assertStringSplitSQL("Ai\u0307B", "İ", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("Ai\u0307B", "İ", "UNICODE", array_AiB); + assertStringSplitSQL("Ai\u0307B", "İ", "UNICODE_CI", array_A_B); + assertStringSplitSQL("Ai\u0307B", "i\u0307", "UTF8_BINARY", array_A_B); + assertStringSplitSQL("Ai\u0307B", "i\u0307", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("Ai\u0307B", "i\u0307", "UNICODE", array_A_B); + assertStringSplitSQL("Ai\u0307B", "i\u0307", "UNICODE_CI", array_A_B); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertStringSplitSQL("σ", "σ", "UTF8_BINARY", full_match); + assertStringSplitSQL("σ", "σ", "UTF8_LCASE", full_match); + assertStringSplitSQL("σ", "σ", "UNICODE", full_match); + assertStringSplitSQL("σ", "σ", "UNICODE_CI", full_match); + assertStringSplitSQL("σ", "ς", "UTF8_BINARY", array_small_nonfinal_sigma); + assertStringSplitSQL("σ", "ς", "UTF8_LCASE", full_match); + assertStringSplitSQL("σ", "ς", "UNICODE", array_small_nonfinal_sigma); + assertStringSplitSQL("σ", "ς", "UNICODE_CI", full_match); + assertStringSplitSQL("σ", "Σ", "UTF8_BINARY", array_small_nonfinal_sigma); + assertStringSplitSQL("σ", "Σ", "UTF8_LCASE", full_match); + assertStringSplitSQL("σ", "Σ", "UNICODE", array_small_nonfinal_sigma); + assertStringSplitSQL("σ", "Σ", "UNICODE_CI", full_match); + assertStringSplitSQL("ς", "σ", "UTF8_BINARY", array_small_final_sigma); + assertStringSplitSQL("ς", "σ", "UTF8_LCASE", full_match); + assertStringSplitSQL("ς", "σ", "UNICODE", array_small_final_sigma); + assertStringSplitSQL("ς", "σ", "UNICODE_CI", full_match); + assertStringSplitSQL("ς", "ς", "UTF8_BINARY", full_match); + assertStringSplitSQL("ς", "ς", "UTF8_LCASE", full_match); + assertStringSplitSQL("ς", "ς", "UNICODE", full_match); + assertStringSplitSQL("ς", "ς", "UNICODE_CI", full_match); + assertStringSplitSQL("ς", "Σ", "UTF8_BINARY", array_small_final_sigma); + assertStringSplitSQL("ς", "Σ", "UTF8_LCASE", full_match); + assertStringSplitSQL("ς", "Σ", "UNICODE", array_small_final_sigma); + assertStringSplitSQL("ς", "Σ", "UNICODE_CI", full_match); + assertStringSplitSQL("Σ", "σ", "UTF8_BINARY", array_capital_sigma); + assertStringSplitSQL("Σ", "σ", "UTF8_LCASE", full_match); + assertStringSplitSQL("Σ", "σ", "UNICODE", array_capital_sigma); + assertStringSplitSQL("Σ", "σ", "UNICODE_CI", full_match); + assertStringSplitSQL("Σ", "ς", "UTF8_BINARY", array_capital_sigma); + assertStringSplitSQL("Σ", "ς", "UTF8_LCASE", full_match); + assertStringSplitSQL("Σ", "ς", "UNICODE", array_capital_sigma); + assertStringSplitSQL("Σ", "ς", "UNICODE_CI", full_match); + assertStringSplitSQL("Σ", "Σ", "UTF8_BINARY", full_match); + assertStringSplitSQL("Σ", "Σ", "UTF8_LCASE", full_match); + assertStringSplitSQL("Σ", "Σ", "UNICODE", full_match); + assertStringSplitSQL("Σ", "Σ", "UNICODE_CI", full_match); + // Surrogate pairs. 
+ assertStringSplitSQL("a🙃b🙃c", "🙃", "UTF8_BINARY", array_a_b_c); + assertStringSplitSQL("a🙃b🙃c", "🙃", "UTF8_LCASE", array_a_b_c); + assertStringSplitSQL("a🙃b🙃c", "🙃", "UNICODE", array_a_b_c); + assertStringSplitSQL("a🙃b🙃c", "🙃", "UNICODE_CI", array_a_b_c); + assertStringSplitSQL("😀😆😃😄", "😆😃", "UTF8_BINARY", array_emojis); + assertStringSplitSQL("😀😆😃😄", "😆😃", "UTF8_LCASE", array_emojis); + assertStringSplitSQL("😀😆😃😄", "😆😃", "UNICODE", array_emojis); + assertStringSplitSQL("😀😆😃😄", "😆😃", "UNICODE_CI", array_emojis); + assertStringSplitSQL("A𐐅B", "𐐅", "UTF8_BINARY", array_A_B); + assertStringSplitSQL("A𐐅B", "𐐅", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("A𐐅B", "𐐅", "UNICODE", array_A_B); + assertStringSplitSQL("A𐐅B", "𐐅", "UNICODE_CI", array_A_B); + assertStringSplitSQL("A𐐅B", "𐐭", "UTF8_BINARY", array_AOB); + assertStringSplitSQL("A𐐅B", "𐐭", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("A𐐅B", "𐐭", "UNICODE", array_AOB); + assertStringSplitSQL("A𐐅B", "𐐭", "UNICODE_CI", array_A_B); + assertStringSplitSQL("A𐐭B", "𐐅", "UTF8_BINARY", array_AoB); + assertStringSplitSQL("A𐐭B", "𐐅", "UTF8_LCASE", array_A_B); + assertStringSplitSQL("A𐐭B", "𐐅", "UNICODE", array_AoB); + assertStringSplitSQL("A𐐭B", "𐐅", "UNICODE_CI", array_A_B); } + /** + * Verify the behaviour of the `Upper` collation support class. + */ + private void assertUpper(String target, String collationName, String expected) - throws SparkException { + throws SparkException { UTF8String target_utf8 = UTF8String.fromString(target); UTF8String expected_utf8 = UTF8String.fromString(expected); int collationId = CollationFactory.collationNameToId(collationName); @@ -631,52 +1229,57 @@ private void assertUpper(String target, String collationName, String expected) @Test public void testUpper() throws SparkException { - // Edge cases - assertUpper("", "UTF8_BINARY", ""); - assertUpper("", "UTF8_LCASE", ""); - assertUpper("", "UNICODE", ""); - assertUpper("", "UNICODE_CI", ""); - // Basic tests - assertUpper("abcde", "UTF8_BINARY", "ABCDE"); - assertUpper("abcde", "UTF8_LCASE", "ABCDE"); - assertUpper("abcde", "UNICODE", "ABCDE"); - assertUpper("abcde", "UNICODE_CI", "ABCDE"); - // Uppercase present - assertUpper("AbCdE", "UTF8_BINARY", "ABCDE"); - assertUpper("aBcDe", "UTF8_BINARY", "ABCDE"); - assertUpper("AbCdE", "UTF8_LCASE", "ABCDE"); - assertUpper("aBcDe", "UTF8_LCASE", "ABCDE"); - assertUpper("AbCdE", "UNICODE", "ABCDE"); - assertUpper("aBcDe", "UNICODE", "ABCDE"); - assertUpper("AbCdE", "UNICODE_CI", "ABCDE"); - assertUpper("aBcDe", "UNICODE_CI", "ABCDE"); - // Accent letters - assertUpper("aBćDe","UTF8_BINARY", "ABĆDE"); - assertUpper("aBćDe","UTF8_LCASE", "ABĆDE"); - assertUpper("aBćDe","UNICODE", "ABĆDE"); - assertUpper("aBćDe","UNICODE_CI", "ABĆDE"); - // Variable byte length characters - assertUpper("ab世De", "UTF8_BINARY", "AB世DE"); - assertUpper("äbćδe", "UTF8_BINARY", "ÄBĆΔE"); - assertUpper("ab世De", "UTF8_LCASE", "AB世DE"); - assertUpper("äbćδe", "UTF8_LCASE", "ÄBĆΔE"); - assertUpper("ab世De", "UNICODE", "AB世DE"); - assertUpper("äbćδe", "UNICODE", "ÄBĆΔE"); - assertUpper("ab世De", "UNICODE_CI", "AB世DE"); - assertUpper("äbćδe", "UNICODE_CI", "ÄBĆΔE"); - // Case-variable character length - assertUpper("i\u0307o", "UTF8_BINARY","I\u0307O"); - assertUpper("i\u0307o", "UTF8_LCASE","I\u0307O"); - assertUpper("i\u0307o", "UNICODE","I\u0307O"); - assertUpper("i\u0307o", "UNICODE_CI","I\u0307O"); - assertUpper("ß fi ffi ff st ῗ", "UTF8_BINARY","SS FI FFI FF ST \u0399\u0308\u0342"); - assertUpper("ß fi ffi ff st ῗ", "UTF8_LCASE","SS FI 
FFI FF ST \u0399\u0308\u0342"); - assertUpper("ß fi ffi ff st ῗ", "UNICODE","SS FI FFI FF ST \u0399\u0308\u0342"); - assertUpper("ß fi ffi ff st ῗ", "UNICODE","SS FI FFI FF ST \u0399\u0308\u0342"); + for (String collationName: testSupportedCollations) { + // Empty strings. + assertUpper("", collationName, ""); + // Basic tests. + assertUpper("abcde", collationName, "ABCDE"); + assertUpper("AbCdE", collationName, "ABCDE"); + assertUpper("aBcDe", collationName, "ABCDE"); + assertUpper("ABCDE", collationName, "ABCDE"); + // Advanced tests. + assertUpper("aBćDe", collationName, "ABĆDE"); + assertUpper("ab世De", collationName, "AB世DE"); + assertUpper("äbćδe", collationName, "ÄBĆΔE"); + assertUpper("AbĆdE", collationName, "ABĆDE"); + assertUpper("aB世De", collationName, "AB世DE"); + assertUpper("ÄBĆΔE", collationName, "ÄBĆΔE"); + // One-to-many case mapping (e.g. Turkish dotted I). + assertUpper("İ", collationName, "İ"); + assertUpper("i\u0307", collationName,"I\u0307"); + assertUpper("İonic", collationName, "İONIC"); + assertUpper("i\u0307onic", collationName,"I\u0307ONIC"); + assertUpper("FIDELİO", collationName, "FIDELİO"); + // Conditional case mapping (e.g. Greek sigmas). + assertUpper("σ", collationName, "Σ"); + assertUpper("σ", collationName, "Σ"); + assertUpper("ς", collationName, "Σ"); + assertUpper("Σ", collationName, "Σ"); + assertUpper("ΣΑΛΑΤΑ", collationName, "ΣΑΛΑΤΑ"); + assertUpper("σαλατα", collationName, "ΣΑΛΑΤΑ"); + assertUpper("ςαλατα", collationName, "ΣΑΛΑΤΑ"); + assertUpper("ΘΑΛΑΣΣΙΝΟΣ", collationName, "ΘΑΛΑΣΣΙΝΟΣ"); + assertUpper("θαλασσινοσ", collationName, "ΘΑΛΑΣΣΙΝΟΣ"); + assertUpper("θαλασσινος", collationName, "ΘΑΛΑΣΣΙΝΟΣ"); + // Surrogate pairs. + assertUpper("a🙃B🙃c", collationName, "A🙃B🙃C"); + assertUpper("😄 😆", collationName, "😄 😆"); + assertUpper("😀😆😃😄", collationName, "😀😆😃😄"); + assertUpper("𝔸", collationName, "𝔸"); + assertUpper("𐐅", collationName, "𐐅"); + assertUpper("𐐭", collationName, "𐐅"); + assertUpper("𐐭𝔸", collationName, "𐐅𝔸"); + // Ligatures. + assertUpper("ß fi ffi ff st ῗ", collationName,"SS FI FFI FF ST \u0399\u0308\u0342"); + } } + /** + * Verify the behaviour of the `Lower` collation support class. 
+ */
+ private void assertLower(String target, String collationName, String expected)
- throws SparkException {
+ throws SparkException {
UTF8String target_utf8 = UTF8String.fromString(target);
UTF8String expected_utf8 = UTF8String.fromString(expected);
int collationId = CollationFactory.collationNameToId(collationName);
@@ -689,48 +1292,56 @@ private void assertLower(String target, String collationName, String expected)
@Test
public void testLower() throws SparkException {
- // Edge cases
- assertLower("", "UTF8_BINARY", "");
- assertLower("", "UTF8_LCASE", "");
- assertLower("", "UNICODE", "");
- assertLower("", "UNICODE_CI", "");
- // Basic tests
- assertLower("ABCDE", "UTF8_BINARY", "abcde");
- assertLower("ABCDE", "UTF8_LCASE", "abcde");
- assertLower("ABCDE", "UNICODE", "abcde");
- assertLower("ABCDE", "UNICODE_CI", "abcde");
- // Uppercase present
- assertLower("AbCdE", "UTF8_BINARY", "abcde");
- assertLower("aBcDe", "UTF8_BINARY", "abcde");
- assertLower("AbCdE", "UTF8_LCASE", "abcde");
- assertLower("aBcDe", "UTF8_LCASE", "abcde");
- assertLower("AbCdE", "UNICODE", "abcde");
- assertLower("aBcDe", "UNICODE", "abcde");
- assertLower("AbCdE", "UNICODE_CI", "abcde");
- assertLower("aBcDe", "UNICODE_CI", "abcde");
- // Accent letters
- assertLower("AbĆdE","UTF8_BINARY", "abćde");
- assertLower("AbĆdE","UTF8_LCASE", "abćde");
- assertLower("AbĆdE","UNICODE", "abćde");
- assertLower("AbĆdE","UNICODE_CI", "abćde");
- // Variable byte length characters
- assertLower("aB世De", "UTF8_BINARY", "ab世de");
- assertLower("ÄBĆΔE", "UTF8_BINARY", "äbćδe");
- assertLower("aB世De", "UTF8_LCASE", "ab世de");
- assertLower("ÄBĆΔE", "UTF8_LCASE", "äbćδe");
- assertLower("aB世De", "UNICODE", "ab世de");
- assertLower("ÄBĆΔE", "UNICODE", "äbćδe");
- assertLower("aB世De", "UNICODE_CI", "ab世de");
- assertLower("ÄBĆΔE", "UNICODE_CI", "äbćδe");
- // Case-variable character length
- assertLower("İo", "UTF8_BINARY","i\u0307o");
- assertLower("İo", "UTF8_LCASE","i\u0307o");
- assertLower("İo", "UNICODE","i\u0307o");
- assertLower("İo", "UNICODE_CI","i\u0307o");
+ for (String collationName: testSupportedCollations) {
+ // Empty strings.
+ assertLower("", collationName, "");
+ // Basic tests.
+ assertLower("abcde", collationName, "abcde");
+ assertLower("AbCdE", collationName, "abcde");
+ assertLower("aBcDe", collationName, "abcde");
+ assertLower("ABCDE", collationName, "abcde");
+ // Advanced tests.
+ assertLower("aBćDe", collationName, "abćde");
+ assertLower("ab世De", collationName, "ab世de");
+ assertLower("äbćδe", collationName, "äbćδe");
+ assertLower("AbĆdE", collationName, "abćde");
+ assertLower("aB世De", collationName, "ab世de");
+ assertLower("ÄBĆΔE", collationName, "äbćδe");
+ // One-to-many case mapping (e.g. Turkish dotted I).
+ assertLower("İ", collationName, "i\u0307");
+ assertLower("I\u0307", collationName,"i\u0307");
+ assertLower("İonic", collationName, "i\u0307onic");
+ assertLower("i\u0307onic", collationName,"i\u0307onic");
+ assertLower("FIDELİO", collationName, "fideli\u0307o");
+ // Conditional case mapping (e.g. Greek sigmas).
+ assertLower("σ", collationName, "σ");
+ assertLower("ς", collationName, "ς");
+ assertLower("Σ", collationName, "σ");
+ assertLower("ΣΑΛΑΤΑ", collationName, "σαλατα");
+ assertLower("σαλατα", collationName, "σαλατα");
+ assertLower("ςαλατα", collationName, "ςαλατα");
+ assertLower("ΘΑΛΑΣΣΙΝΟΣ", collationName, "θαλασσινος");
+ assertLower("θαλασσινοσ", collationName, "θαλασσινοσ");
+ assertLower("θαλασσινος", collationName, "θαλασσινος");
+ // Surrogate pairs.
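+ // Note: non-BMP characters are treated as single units here: emoji and 𝔸 (which has no
+ // lowercase mapping) are unchanged, while Deseret 𐐅 lowercases to 𐐭.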
+ assertLower("a🙃B🙃c", collationName, "a🙃b🙃c"); + assertLower("😄 😆", collationName, "😄 😆"); + assertLower("😀😆😃😄", collationName, "😀😆😃😄"); + assertLower("𝔸", collationName, "𝔸"); + assertLower("𐐅", collationName, "𐐭"); + assertLower("𐐭", collationName, "𐐭"); + assertLower("𐐭𝔸", collationName, "𐐭𝔸"); + // Ligatures. + assertLower("ß fi ffi ff st ῗ", collationName,"ß fi ffi ff st ῗ"); + } } + /** + * Verify the behaviour of the `InitCap` collation support class. + */ + private void assertInitCap(String target, String collationName, String expected) - throws SparkException { + throws SparkException { UTF8String target_utf8 = UTF8String.fromString(target); UTF8String expected_utf8 = UTF8String.fromString(expected); int collationId = CollationFactory.collationNameToId(collationName); @@ -741,43 +1352,65 @@ private void assertInitCap(String target, String collationName, String expected) // Note: results should be the same in these tests for both ICU and JVM-based implementations. } + private void assertInitCap( + String target, + String collationName, + String expectedICU, + String expectedNonICU) throws SparkException { + UTF8String target_utf8 = UTF8String.fromString(target); + UTF8String expectedICU_utf8 = UTF8String.fromString(expectedICU); + UTF8String expectedNonICU_utf8 = UTF8String.fromString(expectedNonICU); + int collationId = CollationFactory.collationNameToId(collationName); + // Testing the new ICU-based implementation of the Lower function. + assertEquals(expectedICU_utf8, CollationSupport.InitCap.exec(target_utf8, collationId, true)); + // Testing the old JVM-based implementation of the Lower function. + assertEquals(expectedNonICU_utf8, CollationSupport.InitCap.exec(target_utf8, collationId, + false)); + // Note: results should be the same in these tests for both ICU and JVM-based implementations. + } + @Test public void testInitCap() throws SparkException { - // Edge cases - assertInitCap("", "UTF8_BINARY", ""); - assertInitCap("", "UTF8_LCASE", ""); - assertInitCap("", "UNICODE", ""); - assertInitCap("", "UNICODE_CI", ""); - // Basic tests - assertInitCap("ABCDE", "UTF8_BINARY", "Abcde"); - assertInitCap("ABCDE", "UTF8_LCASE", "Abcde"); - assertInitCap("ABCDE", "UNICODE", "Abcde"); - assertInitCap("ABCDE", "UNICODE_CI", "Abcde"); - // Uppercase present - assertInitCap("AbCdE", "UTF8_BINARY", "Abcde"); - assertInitCap("aBcDe", "UTF8_BINARY", "Abcde"); - assertInitCap("AbCdE", "UTF8_LCASE", "Abcde"); - assertInitCap("aBcDe", "UTF8_LCASE", "Abcde"); - assertInitCap("AbCdE", "UNICODE", "Abcde"); - assertInitCap("aBcDe", "UNICODE", "Abcde"); - assertInitCap("AbCdE", "UNICODE_CI", "Abcde"); - assertInitCap("aBcDe", "UNICODE_CI", "Abcde"); - // Accent letters - assertInitCap("AbĆdE", "UTF8_BINARY", "Abćde"); - assertInitCap("AbĆdE", "UTF8_LCASE", "Abćde"); - assertInitCap("AbĆdE", "UNICODE", "Abćde"); - assertInitCap("AbĆdE", "UNICODE_CI", "Abćde"); - // Variable byte length characters - assertInitCap("aB 世 De", "UTF8_BINARY", "Ab 世 De"); + for (String collationName: testSupportedCollations) { + // Empty strings. + assertInitCap("", collationName, ""); + // Basic tests. + assertInitCap("abcde", collationName, "Abcde"); + assertInitCap("AbCdE", collationName, "Abcde"); + assertInitCap("aBcDe", collationName, "Abcde"); + assertInitCap("ABCDE", collationName, "Abcde"); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertInitCap("σ", collationName, "Σ"); + assertInitCap("ς", collationName, "Σ"); + assertInitCap("Σ", collationName, "Σ"); + assertInitCap("ΣΑΛΑΤΑ", collationName, "Σαλατα"); + assertInitCap("σαλατα", collationName, "Σαλατα"); + assertInitCap("ςαλατα", collationName, "Σαλατα"); + assertInitCap("ΘΑΛΑΣΣΙΝΟΣ", collationName, "Θαλασσινος"); + assertInitCap("θαλασσινοσ", collationName, "Θαλασσινοσ"); + assertInitCap("θαλασσινος", collationName, "Θαλασσινος"); + } + // Advanced tests. + assertInitCap("aBćDe", "UTF8_BINARY", "Abćde"); + assertInitCap("aBćDe", "UTF8_LCASE", "Abćde"); + assertInitCap("aBćDe", "UNICODE", "Abćde"); + assertInitCap("aBćDe", "UNICODE_CI", "Abćde"); + assertInitCap("ab世De", "UTF8_BINARY", "Ab世de"); + assertInitCap("ab世De", "UTF8_LCASE", "Ab世De"); + assertInitCap("ab世De", "UNICODE", "Ab世De"); + assertInitCap("ab世De", "UNICODE_CI", "Ab世De"); + assertInitCap("äbćδe", "UTF8_BINARY", "Äbćδe"); + assertInitCap("äbćδe", "UTF8_LCASE", "Äbćδe"); + assertInitCap("äbćδe", "UNICODE", "Äbćδe"); + assertInitCap("äbćδe", "UNICODE_CI", "Äbćδe"); assertInitCap("ÄBĆΔE", "UTF8_BINARY", "Äbćδe"); - assertInitCap("aB 世 De", "UTF8_LCASE", "Ab 世 De"); assertInitCap("ÄBĆΔE", "UTF8_LCASE", "Äbćδe"); - assertInitCap("aB 世 De", "UNICODE", "Ab 世 De"); assertInitCap("ÄBĆΔE", "UNICODE", "Äbćδe"); - assertInitCap("aB 世 de", "UNICODE_CI", "Ab 世 De"); assertInitCap("ÄBĆΔE", "UNICODE_CI", "Äbćδe"); + assertInitCap("êéfgh", "AF_CI_AI", "Êéfgh"); + assertInitCap("öoAÄ", "DE_CI_AI", "Öoaä"); // Case-variable character length - assertInitCap("İo", "UTF8_BINARY", "I\u0307o"); + assertInitCap("İo", "UTF8_BINARY", "İo", "I\u0307o"); assertInitCap("İo", "UTF8_LCASE", "İo"); assertInitCap("İo", "UNICODE", "İo"); assertInitCap("İo", "UNICODE_CI", "İo"); @@ -786,6 +1419,67 @@ public void testInitCap() throws SparkException { assertInitCap("i\u0307o", "UNICODE", "I\u0307o"); assertInitCap("i\u0307o", "UNICODE_CI", "I\u0307o"); // Different possible word boundaries + assertInitCap("aB 世 de", "UTF8_BINARY", "Ab 世 De"); + assertInitCap("aB 世 de", "UTF8_LCASE", "Ab 世 De"); + assertInitCap("aB 世 de", "UNICODE", "Ab 世 De"); + assertInitCap("aB 世 de", "UNICODE_CI", "Ab 世 De"); + // One-to-many case mapping (e.g. Turkish dotted I). + assertInitCap("İ", "UTF8_BINARY", "İ", "I\u0307"); + assertInitCap("İ", "UTF8_LCASE", "İ"); + assertInitCap("İ", "UNICODE", "İ"); + assertInitCap("İ", "UNICODE_CI", "İ"); + assertInitCap("I\u0307", "UTF8_BINARY","I\u0307"); + assertInitCap("I\u0307", "UTF8_LCASE","I\u0307"); + assertInitCap("I\u0307", "UNICODE","I\u0307"); + assertInitCap("I\u0307", "UNICODE_CI","I\u0307"); + assertInitCap("İonic", "UTF8_BINARY", "İonic", "I\u0307onic"); + assertInitCap("İonic", "UTF8_LCASE", "İonic"); + assertInitCap("İonic", "UNICODE", "İonic"); + assertInitCap("İonic", "UNICODE_CI", "İonic"); + assertInitCap("i\u0307onic", "UTF8_BINARY","I\u0307onic"); + assertInitCap("i\u0307onic", "UTF8_LCASE","I\u0307onic"); + assertInitCap("i\u0307onic", "UNICODE","I\u0307onic"); + assertInitCap("i\u0307onic", "UNICODE_CI","I\u0307onic"); + assertInitCap("FIDELİO", "UTF8_BINARY", "Fideli\u0307o"); + assertInitCap("FIDELİO", "UTF8_LCASE", "Fideli\u0307o"); + assertInitCap("FIDELİO", "UNICODE", "Fideli\u0307o"); + assertInitCap("FIDELİO", "UNICODE_CI", "Fideli\u0307o"); + // Surrogate pairs. 
+ assertInitCap("a🙃B🙃c", "UTF8_BINARY", "A🙃b🙃c"); + assertInitCap("a🙃B🙃c", "UTF8_LCASE", "A🙃B🙃C"); + assertInitCap("a🙃B🙃c", "UNICODE", "A🙃B🙃C"); + assertInitCap("a🙃B🙃c", "UNICODE_CI", "A🙃B🙃C"); + assertInitCap("😄 😆", "UTF8_BINARY", "😄 😆"); + assertInitCap("😄 😆", "UTF8_LCASE", "😄 😆"); + assertInitCap("😄 😆", "UNICODE", "😄 😆"); + assertInitCap("😄 😆", "UNICODE_CI", "😄 😆"); + assertInitCap("😀😆😃😄", "UTF8_BINARY", "😀😆😃😄"); + assertInitCap("😀😆😃😄", "UTF8_LCASE", "😀😆😃😄"); + assertInitCap("😀😆😃😄", "UNICODE", "😀😆😃😄"); + assertInitCap("😀😆😃😄", "UNICODE_CI", "😀😆😃😄"); + assertInitCap("𝔸", "UTF8_BINARY", "𝔸"); + assertInitCap("𝔸", "UTF8_LCASE", "𝔸"); + assertInitCap("𝔸", "UNICODE", "𝔸"); + assertInitCap("𝔸", "UNICODE_CI", "𝔸"); + assertInitCap("𐐅", "UTF8_BINARY", "\uD801\uDC05", "𐐭"); + assertInitCap("𐐅", "UTF8_LCASE", "𐐅"); + assertInitCap("𐐅", "UNICODE", "𐐅"); + assertInitCap("𐐅", "UNICODE_CI", "𐐅"); + assertInitCap("𐐭", "UTF8_BINARY", "\uD801\uDC05", "𐐭"); + assertInitCap("𐐭", "UTF8_LCASE", "𐐅"); + assertInitCap("𐐭", "UNICODE", "𐐅"); + assertInitCap("𐐭", "UNICODE_CI", "𐐅"); + assertInitCap("𐐭𝔸", "UTF8_BINARY", "\uD801\uDC05\uD835\uDD38", "𐐭𝔸"); + assertInitCap("𐐭𝔸", "UTF8_LCASE", "𐐅𝔸"); + assertInitCap("𐐭𝔸", "UNICODE", "𐐅𝔸"); + assertInitCap("𐐭𝔸", "UNICODE_CI", "𐐅𝔸"); + // Ligatures. + assertInitCap("ß fi ffi ff st ῗ", "UTF8_BINARY", "Ss Fi Ffi Ff St Ϊ͂", "ß fi ffi ff st ῗ"); + assertInitCap("ß fi ffi ff st ῗ", "UTF8_LCASE", "Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("ß fi ffi ff st ῗ", "UNICODE", "Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("ß fi ffi ff st ῗ", "UNICODE", "Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("œ ǽ", "UTF8_BINARY", "Œ Ǽ", "Œ Ǽ"); + // Different possible word boundaries. assertInitCap("a b c", "UTF8_BINARY", "A B C"); assertInitCap("a b c", "UNICODE", "A B C"); assertInitCap("a b c", "UTF8_LCASE", "A B C"); @@ -802,7 +1496,7 @@ public void testInitCap() throws SparkException { assertInitCap("a?b世c", "UNICODE", "A?B世C"); assertInitCap("a?b世c", "UTF8_LCASE", "A?B世C"); assertInitCap("a?b世c", "UNICODE_CI", "A?B世C"); - // Titlecase characters that are different from uppercase characters + // Titlecase characters that are different from uppercase characters. assertInitCap("dzDZDz", "UTF8_BINARY", "Dzdzdz"); assertInitCap("dzDZDz", "UNICODE", "Dzdzdz"); assertInitCap("dzDZDz", "UTF8_LCASE", "Dzdzdz"); @@ -812,17 +1506,50 @@ public void testInitCap() throws SparkException { assertInitCap("džaba Ljubav NJegova", "UTF8_LCASE", "Džaba Ljubav Njegova"); assertInitCap("džaba Ljubav NJegova", "UNICODE_CI", "Džaba Ljubav Njegova"); assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UTF8_BINARY", - "ß fi ffi ff st Σημερινος Ασημενιος I\u0307ota"); + "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota","ß fi ffi ff st Σημερινος Ασημενιος I\u0307ota"); assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UTF8_LCASE", "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota"); assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UNICODE", "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota"); - assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UNICODE_CI", - "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota"); + assertInitCap("ß fi ffi ff st ΣΗΜΕΡςΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UNICODE_CI", + "Ss Fi Ffi Ff St Σημερςινος Ασημενιος İota"); + // Characters that map to multiple characters when titlecased and lowercased. 
+ assertInitCap("ß fi ffi ff st İOTA", "UTF8_BINARY", "Ss Fi Ffi Ff St İota", "ß fi ffi ff st İota"); + assertInitCap("ß fi ffi ff st OİOTA", "UTF8_BINARY", + "Ss Fi Ffi Ff St Oi\u0307ota", "ß fi ffi ff st Oi̇ota"); + // Lowercasing Greek letter sigma ('Σ') when case-ignorable character present. + assertInitCap("`Σ", "UTF8_BINARY", "`σ", "`σ"); + assertInitCap("1`Σ`` AΣ", "UTF8_BINARY", "1`σ`` Aς", "1`σ`` Aς"); + assertInitCap("a1`Σ``", "UTF8_BINARY", "A1`σ``", "A1`σ``"); + assertInitCap("a`Σ``", "UTF8_BINARY", "A`ς``", "A`σ``"); + assertInitCap("a`Σ``1", "UTF8_BINARY", "A`ς``1", "A`σ``1"); + assertInitCap("a`Σ``A", "UTF8_BINARY", "A`σ``a", "A`σ``a"); + assertInitCap("ΘΑ�Σ�ΟΣ�", "UTF8_BINARY", "Θα�σ�ος�", "Θα�σ�ος�"); + assertInitCap("ΘΑᵩΣ�ΟᵩΣᵩ�", "UTF8_BINARY", "Θαᵩς�οᵩςᵩ�", "Θαᵩς�οᵩςᵩ�"); + assertInitCap("ΘΑ�ᵩΣ�ΟᵩΣᵩ�", "UTF8_BINARY", "Θα�ᵩσ�οᵩςᵩ�", "Θα�ᵩσ�οᵩςᵩ�"); + assertInitCap("ΘΑ�ᵩΣᵩ�ΟᵩΣᵩ�", "UTF8_BINARY", "Θα�ᵩσᵩ�οᵩςᵩ�", "Θα�ᵩσᵩ�οᵩςᵩ�"); + assertInitCap("ΘΑ�Σ�Ο�Σ�", "UTF8_BINARY", "Θα�σ�ο�σ�", "Θα�σ�ο�σ�"); + // Disallowed bytes and invalid sequences. + assertInitCap(UTF8String.fromBytes(new byte[] { (byte)0xC0, (byte)0xC1, (byte)0xF5}).toString(), + "UTF8_BINARY", "���", "���"); + assertInitCap(UTF8String.fromBytes( + new byte[]{(byte)0xC0, (byte)0xC1, (byte)0xF5, 0x20, 0x61, 0x41, (byte)0xC0}).toString(), + "UTF8_BINARY", + "��� Aa�", "��� Aa�"); + assertInitCap(UTF8String.fromBytes(new byte[]{(byte)0xC2,(byte)0xC2}).toString(), + "UTF8_BINARY", "��", "��"); + assertInitCap(UTF8String.fromBytes( + new byte[]{0x61, 0x41, (byte)0xC2, (byte)0xC2, 0x41}).toString(), + "UTF8_BINARY", + "Aa��a", "Aa��a"); } - private void assertStringInstr(String string, String substring, String collationName, - Integer expected) throws SparkException { + /** + * Verify the behaviour of the `StringInstr` collation support class. + */ + + private void assertStringInstr(String string, String substring, + String collationName, int expected) throws SparkException { UTF8String str = UTF8String.fromString(string); UTF8String substr = UTF8String.fromString(substring); int collationId = CollationFactory.collationNameToId(collationName); @@ -831,143 +1558,402 @@ private void assertStringInstr(String string, String substring, String collation @Test public void testStringInstr() throws SparkException { - assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0); - assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0); + // Empty strings. + assertStringInstr("", "", "UTF8_BINARY", 1); + assertStringInstr("", "", "UTF8_LCASE", 1); + assertStringInstr("", "", "UNICODE_CI", 1); + assertStringInstr("", "", "UNICODE", 1); + assertStringInstr("a", "", "UTF8_BINARY", 1); + assertStringInstr("a", "", "UTF8_LCASE", 1); + assertStringInstr("a", "", "UNICODE", 1); + assertStringInstr("a", "", "UNICODE_CI", 1); + assertStringInstr("", "x", "UTF8_BINARY", 0); + assertStringInstr("", "x", "UTF8_LCASE", 0); + assertStringInstr("", "x", "UNICODE", 0); + assertStringInstr("", "x", "UNICODE_CI", 0); + // Basic tests. 
+ assertStringInstr("aaads", "aa", "UTF8_BINARY", 1); + assertStringInstr("aaads", "aa", "UTF8_LCASE", 1); + assertStringInstr("aaads", "aa", "UNICODE", 1); + assertStringInstr("aaads", "aa", "UNICODE_CI", 1); assertStringInstr("aaads", "ds", "UTF8_BINARY", 4); - assertStringInstr("xxxx", "", "UTF8_BINARY", 1); - assertStringInstr("", "xxxx", "UTF8_BINARY", 0); - assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5); - assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8); + assertStringInstr("aaads", "ds", "UTF8_LCASE", 4); + assertStringInstr("aaads", "ds", "UNICODE", 4); + assertStringInstr("aaads", "ds", "UNICODE_CI", 4); + assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0); assertStringInstr("aaads", "Aa", "UTF8_LCASE", 1); + assertStringInstr("aaads", "Aa", "UNICODE", 0); + assertStringInstr("aaads", "Aa", "UNICODE_CI", 1); + assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0); assertStringInstr("aaaDs", "de", "UTF8_LCASE", 0); + assertStringInstr("aaaDs", "de", "UNICODE", 0); + assertStringInstr("aaaDs", "de", "UNICODE_CI", 0); + assertStringInstr("aaaDs", "ds", "UTF8_BINARY", 0); assertStringInstr("aaaDs", "ds", "UTF8_LCASE", 4); - assertStringInstr("xxxx", "", "UTF8_LCASE", 1); - assertStringInstr("", "xxxx", "UTF8_LCASE", 0); + assertStringInstr("aaaDs", "ds", "UNICODE", 0); + assertStringInstr("aaaDs", "ds", "UNICODE_CI", 4); + assertStringInstr("aaadS", "Ds", "UTF8_BINARY", 0); + assertStringInstr("aaadS", "Ds", "UTF8_LCASE", 4); + assertStringInstr("aaadS", "Ds", "UNICODE", 0); + assertStringInstr("aaadS", "Ds", "UNICODE_CI", 4); + assertStringInstr("aaaČŠčšcs", "cs", "SR", 8); + assertStringInstr("aaaČŠčšcs", "cs", "SR_CI_AI", 4); + // Advanced tests. + assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5); assertStringInstr("test大千世界X大千世界", "大千", "UTF8_LCASE", 5); + assertStringInstr("test大千世界X大千世界", "大千", "UNICODE", 5); + assertStringInstr("test大千世界X大千世界", "大千", "UNICODE_CI", 5); + assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8); + assertStringInstr("test大千世界X大千世界", "界X", "UTF8_LCASE", 8); + assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8); + assertStringInstr("test大千世界X大千世界", "界X", "UNICODE_CI", 8); + assertStringInstr("test大千世界X大千世界", "界x", "UTF8_BINARY", 0); assertStringInstr("test大千世界X大千世界", "界x", "UTF8_LCASE", 8); - assertStringInstr("aaads", "Aa", "UNICODE", 0); - assertStringInstr("aaads", "aa", "UNICODE", 1); - assertStringInstr("aaads", "de", "UNICODE", 0); - assertStringInstr("xxxx", "", "UNICODE", 1); - assertStringInstr("", "xxxx", "UNICODE", 0); assertStringInstr("test大千世界X大千世界", "界x", "UNICODE", 0); - assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8); - assertStringInstr("xxxx", "", "UNICODE_CI", 1); - assertStringInstr("", "xxxx", "UNICODE_CI", 0); - assertStringInstr("aaads", "AD", "UNICODE_CI", 3); - assertStringInstr("aaads", "dS", "UNICODE_CI", 4); - assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0); assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8); - assertStringInstr("i̇", "i", "UNICODE_CI", 0); - assertStringInstr("i̇", "\u0307", "UNICODE_CI", 0); - assertStringInstr("i̇", "İ", "UNICODE_CI", 1); + assertStringInstr("test大千世界X大千世界", "界y", "UTF8_BINARY", 0); + assertStringInstr("test大千世界X大千世界", "界y", "UTF8_LCASE", 0); + assertStringInstr("test大千世界X大千世界", "界y", "UNICODE", 0); + assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertStringInstr("i\u0307", "i", "UNICODE_CI", 0); + assertStringInstr("i\u0307", "\u0307", "UNICODE_CI", 0); + assertStringInstr("i\u0307", "İ", "UNICODE_CI", 1); assertStringInstr("İ", "i", "UNICODE_CI", 0); - assertStringInstr("İoi̇o12", "i̇o", "UNICODE_CI", 1); + assertStringInstr("İoi̇o12", "i\u0307o", "UNICODE_CI", 1); assertStringInstr("i̇oİo12", "İo", "UNICODE_CI", 1); - assertStringInstr("abİoi̇o", "i̇o", "UNICODE_CI", 3); + assertStringInstr("abİoi̇o", "i\u0307o", "UNICODE_CI", 3); assertStringInstr("abi̇oİo", "İo", "UNICODE_CI", 3); assertStringInstr("ai̇oxXİo", "Xx", "UNICODE_CI", 5); assertStringInstr("aİoi̇oxx", "XX", "UNICODE_CI", 7); - assertStringInstr("i̇", "i", "UTF8_LCASE", 1); // != UNICODE_CI - assertStringInstr("i̇", "\u0307", "UTF8_LCASE", 2); // != UNICODE_CI - assertStringInstr("i̇", "İ", "UTF8_LCASE", 1); + assertStringInstr("i\u0307", "i", "UTF8_LCASE", 1); // != UNICODE_CI + assertStringInstr("i\u0307", "\u0307", "UTF8_LCASE", 2); // != UNICODE_CI + assertStringInstr("i\u0307", "İ", "UTF8_LCASE", 1); assertStringInstr("İ", "i", "UTF8_LCASE", 0); - assertStringInstr("İoi̇o12", "i̇o", "UTF8_LCASE", 1); + assertStringInstr("İoi̇o12", "i\u0307o", "UTF8_LCASE", 1); assertStringInstr("i̇oİo12", "İo", "UTF8_LCASE", 1); - assertStringInstr("abİoi̇o", "i̇o", "UTF8_LCASE", 3); + assertStringInstr("abİoi̇o", "i\u0307o", "UTF8_LCASE", 3); assertStringInstr("abi̇oİo", "İo", "UTF8_LCASE", 3); assertStringInstr("abI\u0307oi̇o", "İo", "UTF8_LCASE", 3); assertStringInstr("ai̇oxXİo", "Xx", "UTF8_LCASE", 5); assertStringInstr("abİoi̇o", "\u0307o", "UTF8_LCASE", 6); assertStringInstr("aİoi̇oxx", "XX", "UTF8_LCASE", 7); + // Conditional case mapping (e.g. Greek sigmas). + assertStringInstr("σ", "σ", "UTF8_BINARY", 1); + assertStringInstr("σ", "ς", "UTF8_BINARY", 0); + assertStringInstr("σ", "Σ", "UTF8_BINARY", 0); + assertStringInstr("ς", "σ", "UTF8_BINARY", 0); + assertStringInstr("ς", "ς", "UTF8_BINARY", 1); + assertStringInstr("ς", "Σ", "UTF8_BINARY", 0); + assertStringInstr("Σ", "σ", "UTF8_BINARY", 0); + assertStringInstr("Σ", "ς", "UTF8_BINARY", 0); + assertStringInstr("Σ", "Σ", "UTF8_BINARY", 1); + assertStringInstr("σ", "σ", "UTF8_LCASE", 1); + assertStringInstr("σ", "ς", "UTF8_LCASE", 1); + assertStringInstr("σ", "Σ", "UTF8_LCASE", 1); + assertStringInstr("ς", "σ", "UTF8_LCASE", 1); + assertStringInstr("ς", "ς", "UTF8_LCASE", 1); + assertStringInstr("ς", "Σ", "UTF8_LCASE", 1); + assertStringInstr("Σ", "σ", "UTF8_LCASE", 1); + assertStringInstr("Σ", "ς", "UTF8_LCASE", 1); + assertStringInstr("Σ", "Σ", "UTF8_LCASE", 1); + assertStringInstr("σ", "σ", "UNICODE", 1); + assertStringInstr("σ", "ς", "UNICODE", 0); + assertStringInstr("σ", "Σ", "UNICODE", 0); + assertStringInstr("ς", "σ", "UNICODE", 0); + assertStringInstr("ς", "ς", "UNICODE", 1); + assertStringInstr("ς", "Σ", "UNICODE", 0); + assertStringInstr("Σ", "σ", "UNICODE", 0); + assertStringInstr("Σ", "ς", "UNICODE", 0); + assertStringInstr("Σ", "Σ", "UNICODE", 1); + assertStringInstr("σ", "σ", "UNICODE_CI", 1); + assertStringInstr("σ", "ς", "UNICODE_CI", 1); + assertStringInstr("σ", "Σ", "UNICODE_CI", 1); + assertStringInstr("ς", "σ", "UNICODE_CI", 1); + assertStringInstr("ς", "ς", "UNICODE_CI", 1); + assertStringInstr("ς", "Σ", "UNICODE_CI", 1); + assertStringInstr("Σ", "σ", "UNICODE_CI", 1); + assertStringInstr("Σ", "ς", "UNICODE_CI", 1); + assertStringInstr("Σ", "Σ", "UNICODE_CI", 1); + // Surrogate pairs. 
+ assertStringInstr("a🙃b", "a", "UTF8_BINARY", 1); + assertStringInstr("a🙃b", "a", "UTF8_LCASE", 1); + assertStringInstr("a🙃b", "a", "UNICODE", 1); + assertStringInstr("a🙃b", "a", "UNICODE_CI", 1); + assertStringInstr("a🙃b", "🙃", "UTF8_BINARY", 2); + assertStringInstr("a🙃b", "🙃", "UTF8_LCASE", 2); + assertStringInstr("a🙃b", "🙃", "UNICODE", 2); + assertStringInstr("a🙃b", "🙃", "UNICODE_CI", 2); + assertStringInstr("a🙃b", "b", "UTF8_BINARY", 3); + assertStringInstr("a🙃b", "b", "UTF8_LCASE", 3); + assertStringInstr("a🙃b", "b", "UNICODE", 3); + assertStringInstr("a🙃b", "b", "UNICODE_CI", 3); + assertStringInstr("a🙃🙃b", "🙃", "UTF8_BINARY", 2); + assertStringInstr("a🙃🙃b", "🙃", "UTF8_LCASE", 2); + assertStringInstr("a🙃🙃b", "🙃", "UNICODE", 2); + assertStringInstr("a🙃🙃b", "🙃", "UNICODE_CI", 2); + assertStringInstr("a🙃🙃b", "b", "UTF8_BINARY", 4); + assertStringInstr("a🙃🙃b", "b", "UTF8_LCASE", 4); + assertStringInstr("a🙃🙃b", "b", "UNICODE", 4); + assertStringInstr("a🙃🙃b", "b", "UNICODE_CI", 4); + assertStringInstr("a🙃x🙃b", "b", "UTF8_BINARY", 5); + assertStringInstr("a🙃x🙃b", "b", "UTF8_LCASE", 5); + assertStringInstr("a🙃x🙃b", "b", "UNICODE", 5); + assertStringInstr("a🙃x🙃b", "b", "UNICODE_CI", 5); } + /** + * Verify the behaviour of the `FindInSet` collation support class. + */ + private void assertFindInSet(String word, UTF8String set, String collationName, - Integer expected) throws SparkException { + int expected) throws SparkException { UTF8String w = UTF8String.fromString(word); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected, CollationSupport.FindInSet.exec(w, set, collationId)); + int result = CollationSupport.FindInSet.exec(w, set, collationId); + assertEquals(expected, result); } @Test public void testFindInSet() throws SparkException { - assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); - assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 1); - assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 5); - assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + // Empty strings. 
+ assertFindInSet("", UTF8String.fromString(""), "UTF8_BINARY", 1); + assertFindInSet("", UTF8String.fromString(""), "UTF8_LCASE", 1); + assertFindInSet("", UTF8String.fromString(""), "UNICODE", 1); + assertFindInSet("", UTF8String.fromString(""), "UNICODE_CI", 1); assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); - assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_BINARY", 1); - assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_BINARY", 6); - assertFindInSet("", UTF8String.fromString("abc"), "UTF8_BINARY", 0); - assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); - assertFindInSet("c", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 4); - assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 3); - assertFindInSet("AbC", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 1); - assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); - assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); - assertFindInSet("XX", UTF8String.fromString("xx"), "UTF8_LCASE", 1); assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); + assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); + assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0); + assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_BINARY", 1); assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_LCASE", 1); - assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_LCASE", 6); - assertFindInSet("", UTF8String.fromString("abc"), "UTF8_LCASE", 0); - assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_LCASE", 4); - assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); - assertFindInSet("ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 3); - assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); - assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE", 1); + assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE_CI", 1); + assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_BINARY", 6); + assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_LCASE", 6); assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE", 6); + assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE_CI", 6); + assertFindInSet("", UTF8String.fromString("abc"), "UTF8_BINARY", 0); + assertFindInSet("", UTF8String.fromString("abc"), "UTF8_LCASE", 0); assertFindInSet("", UTF8String.fromString("abc"), "UNICODE", 0); + assertFindInSet("", UTF8String.fromString("abc"), "UNICODE_CI", 0); + // Basic tests. 
+ assertFindInSet("xx", UTF8String.fromString("xx"), "UTF8_BINARY", 1); + assertFindInSet("xx", UTF8String.fromString("xx"), "UTF8_LCASE", 1); assertFindInSet("xx", UTF8String.fromString("xx"), "UNICODE", 1); - assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 0); - assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 5); + assertFindInSet("xx", UTF8String.fromString("xx"), "UNICODE_CI", 1); + assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); + assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0); + assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 1); + assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 1); + assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 1); + assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 1); + assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); + assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); + assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0); + assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 5); + assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 5); + assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 5); + assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 5); + assertFindInSet("xyz", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + assertFindInSet("xyz", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); + assertFindInSet("xyz", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); + assertFindInSet("xyz", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0); + assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 3); + assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); + assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 3); + assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0); + assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); + assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0); + assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0); + assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 4); + assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0); assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 4); - assertFindInSet("DeF", UTF8String.fromString("abc,b,ab,c,dEf"), "UNICODE_CI", 5); - assertFindInSet("DEFG", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0); - assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE_CI", 1); - assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE_CI", 6); - assertFindInSet("", UTF8String.fromString("abc"), "UNICODE_CI", 0); - assertFindInSet("XX", 
UTF8String.fromString("xx"), "UNICODE_CI", 1); + // Advanced tests. + assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_BINARY", 5); + assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_LCASE", 5); + assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 5); + assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 5); + assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_BINARY", 0); + assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_LCASE", 4); + assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 0); assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 4); + assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UTF8_BINARY", 0); + assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UTF8_LCASE", 5); + assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UNICODE", 0); assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UNICODE_CI", 5); - assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 5); - assertFindInSet("i̇", UTF8String.fromString("İ"), "UNICODE_CI", 1); - assertFindInSet("i", UTF8String.fromString("İ"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("i̇"), "UNICODE_CI", 1); - assertFindInSet("i", UTF8String.fromString("i̇"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("İ,"), "UNICODE_CI", 1); - assertFindInSet("i", UTF8String.fromString("İ,"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UNICODE_CI", 1); - assertFindInSet("i", UTF8String.fromString("i̇,"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UNICODE_CI", 2); - assertFindInSet("i", UTF8String.fromString("ab,İ"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 2); - assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 2); - assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 2); - assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 0); - assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UNICODE_CI", 2); - assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UNICODE_CI", 2); - assertFindInSet("i̇", UTF8String.fromString("İ"), "UTF8_LCASE", 1); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertFindInSet("i\u0307", UTF8String.fromString("İ"), "UTF8_BINARY", 0); + assertFindInSet("i\u0307", UTF8String.fromString("İ"), "UTF8_LCASE", 1); + assertFindInSet("i\u0307", UTF8String.fromString("İ"), "UNICODE", 0); + assertFindInSet("i\u0307", UTF8String.fromString("İ"), "UNICODE_CI", 1); + assertFindInSet("i", UTF8String.fromString("İ"), "UTF8_BINARY", 0); assertFindInSet("i", UTF8String.fromString("İ"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("i̇"), "UTF8_LCASE", 1); - assertFindInSet("i", UTF8String.fromString("i̇"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("İ,"), "UTF8_LCASE", 1); + assertFindInSet("i", UTF8String.fromString("İ"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("İ"), "UNICODE_CI", 0); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307"), "UTF8_BINARY", 1); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307"), "UTF8_LCASE", 1); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307"), "UNICODE", 1); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307"), "UNICODE_CI", 1); + assertFindInSet("i", UTF8String.fromString("i\u0307"), "UTF8_BINARY", 0); + assertFindInSet("i", UTF8String.fromString("i\u0307"), "UTF8_LCASE", 0); + assertFindInSet("i", UTF8String.fromString("i\u0307"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("i\u0307"), "UNICODE_CI", 0); + assertFindInSet("i\u0307", UTF8String.fromString("İ,"), "UTF8_BINARY", 0); + assertFindInSet("i\u0307", UTF8String.fromString("İ,"), "UTF8_LCASE", 1); + assertFindInSet("i\u0307", UTF8String.fromString("İ,"), "UNICODE", 0); + assertFindInSet("i\u0307", UTF8String.fromString("İ,"), "UNICODE_CI", 1); + assertFindInSet("i", UTF8String.fromString("İ,"), "UTF8_BINARY", 0); assertFindInSet("i", UTF8String.fromString("İ,"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UTF8_LCASE", 1); - assertFindInSet("i", UTF8String.fromString("i̇,"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 2); + assertFindInSet("i", UTF8String.fromString("İ,"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("İ,"), "UNICODE_CI", 0); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307,"), "UTF8_BINARY", 1); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307,"), "UTF8_LCASE", 1); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307,"), "UNICODE", 1); + assertFindInSet("i\u0307", UTF8String.fromString("i\u0307,"), "UNICODE_CI", 1); + assertFindInSet("i", UTF8String.fromString("i\u0307,"), "UTF8_BINARY", 0); + assertFindInSet("i", UTF8String.fromString("i\u0307,"), "UTF8_LCASE", 0); + assertFindInSet("i", UTF8String.fromString("i\u0307,"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("i\u0307,"), "UNICODE_CI", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ"), "UTF8_BINARY", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ"), "UNICODE", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ"), "UNICODE_CI", 2); + assertFindInSet("i", UTF8String.fromString("ab,İ"), "UTF8_BINARY", 0); assertFindInSet("i", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UTF8_LCASE", 2); - assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 2); + assertFindInSet("i", UTF8String.fromString("ab,İ"), 
"UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("ab,İ"), "UNICODE_CI", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307"), "UTF8_BINARY", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307"), "UTF8_LCASE", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307"), "UNICODE", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307"), "UNICODE_CI", 2); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307"), "UTF8_BINARY", 0); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307"), "UTF8_LCASE", 0); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307"), "UNICODE_CI", 0); + assertFindInSet("İ", UTF8String.fromString("ab,i\u0307"), "UTF8_BINARY", 0); + assertFindInSet("İ", UTF8String.fromString("ab,i\u0307"), "UTF8_LCASE", 2); + assertFindInSet("İ", UTF8String.fromString("ab,i\u0307"), "UNICODE", 0); + assertFindInSet("İ", UTF8String.fromString("ab,i\u0307"), "UNICODE_CI", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ,12"), "UTF8_BINARY", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ,12"), "UNICODE", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 2); + assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UTF8_BINARY", 0); assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 0); - assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UTF8_LCASE", 2); - assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UTF8_LCASE", 0); - assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UTF8_LCASE", 2); - assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UTF8_LCASE", 2); + assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 0); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307,12"), "UTF8_BINARY", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307,12"), "UTF8_LCASE", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307,12"), "UNICODE", 2); + assertFindInSet("i\u0307", UTF8String.fromString("ab,i\u0307,12"), "UNICODE_CI", 2); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307,12"), "UTF8_BINARY", 0); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307,12"), "UTF8_LCASE", 0); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307,12"), "UNICODE", 0); + assertFindInSet("i", UTF8String.fromString("ab,i\u0307,12"), "UNICODE_CI", 0); + assertFindInSet("i\u0307o", UTF8String.fromString("ab,İo,12"), "UTF8_BINARY", 0); + assertFindInSet("i\u0307o", UTF8String.fromString("ab,İo,12"), "UTF8_LCASE", 2); + assertFindInSet("i\u0307o", UTF8String.fromString("ab,İo,12"), "UNICODE", 0); + assertFindInSet("i\u0307o", UTF8String.fromString("ab,İo,12"), "UNICODE_CI", 2); + assertFindInSet("İo", UTF8String.fromString("ab,i\u0307o,12"), "UTF8_BINARY", 0); + assertFindInSet("İo", UTF8String.fromString("ab,i\u0307o,12"), "UTF8_LCASE", 2); + assertFindInSet("İo", UTF8String.fromString("ab,i\u0307o,12"), "UNICODE", 0); + assertFindInSet("İo", UTF8String.fromString("ab,i\u0307o,12"), "UNICODE_CI", 2); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertFindInSet("σ", UTF8String.fromString("σ"), "UTF8_BINARY", 1); + assertFindInSet("σ", UTF8String.fromString("ς"), "UTF8_BINARY", 0); + assertFindInSet("σ", UTF8String.fromString("Σ"), "UTF8_BINARY", 0); + assertFindInSet("ς", UTF8String.fromString("σ"), "UTF8_BINARY", 0); + assertFindInSet("ς", UTF8String.fromString("ς"), "UTF8_BINARY", 1); + assertFindInSet("ς", UTF8String.fromString("Σ"), "UTF8_BINARY", 0); + assertFindInSet("Σ", UTF8String.fromString("σ"), "UTF8_BINARY", 0); + assertFindInSet("Σ", UTF8String.fromString("ς"), "UTF8_BINARY", 0); + assertFindInSet("Σ", UTF8String.fromString("Σ"), "UTF8_BINARY", 1); + assertFindInSet("σ", UTF8String.fromString("σ"), "UTF8_LCASE", 1); + assertFindInSet("σ", UTF8String.fromString("ς"), "UTF8_LCASE", 1); + assertFindInSet("σ", UTF8String.fromString("Σ"), "UTF8_LCASE", 1); + assertFindInSet("ς", UTF8String.fromString("σ"), "UTF8_LCASE", 1); + assertFindInSet("ς", UTF8String.fromString("ς"), "UTF8_LCASE", 1); + assertFindInSet("ς", UTF8String.fromString("Σ"), "UTF8_LCASE", 1); + assertFindInSet("Σ", UTF8String.fromString("σ"), "UTF8_LCASE", 1); + assertFindInSet("Σ", UTF8String.fromString("ς"), "UTF8_LCASE", 1); + assertFindInSet("Σ", UTF8String.fromString("Σ"), "UTF8_LCASE", 1); + assertFindInSet("σ", UTF8String.fromString("σ"), "UNICODE", 1); + assertFindInSet("σ", UTF8String.fromString("ς"), "UNICODE", 0); + assertFindInSet("σ", UTF8String.fromString("Σ"), "UNICODE", 0); + assertFindInSet("ς", UTF8String.fromString("σ"), "UNICODE", 0); + assertFindInSet("ς", UTF8String.fromString("ς"), "UNICODE", 1); + assertFindInSet("ς", UTF8String.fromString("Σ"), "UNICODE", 0); + assertFindInSet("Σ", UTF8String.fromString("σ"), "UNICODE", 0); + assertFindInSet("Σ", UTF8String.fromString("ς"), "UNICODE", 0); + assertFindInSet("Σ", UTF8String.fromString("Σ"), "UNICODE", 1); + assertFindInSet("σ", UTF8String.fromString("σ"), "UNICODE_CI", 1); + assertFindInSet("σ", UTF8String.fromString("ς"), "UNICODE_CI", 1); + assertFindInSet("σ", UTF8String.fromString("Σ"), "UNICODE_CI", 1); + assertFindInSet("ς", UTF8String.fromString("σ"), "UNICODE_CI", 1); + assertFindInSet("ς", UTF8String.fromString("ς"), "UNICODE_CI", 1); + assertFindInSet("ς", UTF8String.fromString("Σ"), "UNICODE_CI", 1); + assertFindInSet("Σ", UTF8String.fromString("σ"), "UNICODE_CI", 1); + assertFindInSet("Σ", UTF8String.fromString("ς"), "UNICODE_CI", 1); + assertFindInSet("Σ", UTF8String.fromString("Σ"), "UNICODE_CI", 1); + // Surrogate pairs. 
+ assertFindInSet("a", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_BINARY", 0); + assertFindInSet("a", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_LCASE", 0); + assertFindInSet("a", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE", 0); + assertFindInSet("a", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE_CI", 0); + assertFindInSet("a🙃", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_BINARY", 1); + assertFindInSet("a🙃", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_LCASE", 1); + assertFindInSet("a🙃", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE", 1); + assertFindInSet("a🙃", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE_CI", 1); + assertFindInSet("b", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_BINARY", 2); + assertFindInSet("b", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_LCASE", 2); + assertFindInSet("b", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE", 2); + assertFindInSet("b", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE_CI", 2); + assertFindInSet("🙃c", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_BINARY", 3); + assertFindInSet("🙃c", UTF8String.fromString("a🙃,b,🙃c"), "UTF8_LCASE", 3); + assertFindInSet("🙃c", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE", 3); + assertFindInSet("🙃c", UTF8String.fromString("a🙃,b,🙃c"), "UNICODE_CI", 3); + assertFindInSet("😄😆", UTF8String.fromString("😀😆,😃😄"), "UTF8_BINARY", 0); + assertFindInSet("😄😆", UTF8String.fromString("😀😆,😃😄"), "UTF8_LCASE", 0); + assertFindInSet("😄😆", UTF8String.fromString("😀😆,😃😄"), "UNICODE", 0); + assertFindInSet("😄😆", UTF8String.fromString("😀😆,😃😄"), "UNICODE_CI", 0); + assertFindInSet("😀😆", UTF8String.fromString("😀😆,😃😄"), "UTF8_BINARY", 1); + assertFindInSet("😀😆", UTF8String.fromString("😀😆,😃😄"), "UTF8_LCASE", 1); + assertFindInSet("😀😆", UTF8String.fromString("😀😆,😃😄"), "UNICODE", 1); + assertFindInSet("😀😆", UTF8String.fromString("😀😆,😃😄"), "UNICODE_CI", 1); + assertFindInSet("😃😄", UTF8String.fromString("😀😆,😃😄"), "UTF8_BINARY", 2); + assertFindInSet("😃😄", UTF8String.fromString("😀😆,😃😄"), "UTF8_LCASE", 2); + assertFindInSet("😃😄", UTF8String.fromString("😀😆,😃😄"), "UNICODE", 2); + assertFindInSet("😃😄", UTF8String.fromString("😀😆,😃😄"), "UNICODE_CI", 2); + assertFindInSet("x", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_BINARY", 0); + assertFindInSet("x", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_LCASE", 0); + assertFindInSet("x", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE", 0); + assertFindInSet("x", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE_CI", 0); + assertFindInSet("a", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_BINARY", 1); + assertFindInSet("a", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_LCASE", 1); + assertFindInSet("a", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE", 1); + assertFindInSet("a", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE_CI", 1); + assertFindInSet("A", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_BINARY", 0); + assertFindInSet("A", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_LCASE", 1); + assertFindInSet("A", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE", 0); + assertFindInSet("A", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE_CI", 1); + assertFindInSet("𝔸", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_BINARY", 3); + assertFindInSet("𝔸", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_LCASE", 3); + assertFindInSet("𝔸", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE", 3); + assertFindInSet("𝔸", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE_CI", 1); + assertFindInSet("𐐅", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_BINARY", 2); + assertFindInSet("𐐅", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_LCASE", 2); + assertFindInSet("𐐅", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE", 2); + assertFindInSet("𐐅", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE_CI", 2); + assertFindInSet("𐐭", 
UTF8String.fromString("a,𐐅,𝔸"), "UTF8_BINARY", 0); + assertFindInSet("𐐭", UTF8String.fromString("a,𐐅,𝔸"), "UTF8_LCASE", 2); + assertFindInSet("𐐭", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE", 0); + assertFindInSet("𐐭", UTF8String.fromString("a,𐐅,𝔸"), "UNICODE_CI", 2); // Invalid UTF8 strings assertFindInSet("C", UTF8String.fromBytes( new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }), @@ -983,147 +1969,581 @@ public void testFindInSet() throws SparkException { "UNICODE_CI", 2); } - private void assertReplace(String source, String search, String replace, String collationName, - String expected) throws SparkException { - UTF8String src = UTF8String.fromString(source); + /** + * Verify the behaviour of the `StringReplace` collation support class. + */ + + private void assertStringReplace(String source, String search, String replace, + String collationName, String expected) throws SparkException { + UTF8String src = UTF8String.fromString(source); UTF8String sear = UTF8String.fromString(search); UTF8String repl = UTF8String.fromString(replace); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected, CollationSupport.StringReplace - .exec(src, sear, repl, collationId).toString()); + UTF8String result = CollationSupport.StringReplace.exec(src, sear, repl, collationId); + assertEquals(UTF8String.fromString(expected), result); } @Test - public void testReplace() throws SparkException { - assertReplace("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"); - assertReplace("replace", "pl", "", "UTF8_BINARY", "reace"); - assertReplace("repl世ace", "Pl", "", "UTF8_BINARY", "repl世ace"); - assertReplace("replace", "", "123", "UTF8_BINARY", "replace"); - assertReplace("abcabc", "b", "12", "UTF8_BINARY", "a12ca12c"); - assertReplace("abcdabcd", "bc", "", "UTF8_BINARY", "adad"); - assertReplace("r世eplace", "pl", "xx", "UTF8_LCASE", "r世exxace"); - assertReplace("repl世ace", "PL", "AB", "UTF8_LCASE", "reAB世ace"); - assertReplace("Replace", "", "123", "UTF8_LCASE", "Replace"); - assertReplace("re世place", "世", "x", "UTF8_LCASE", "rexplace"); - assertReplace("abcaBc", "B", "12", "UTF8_LCASE", "a12ca12c"); - assertReplace("AbcdabCd", "Bc", "", "UTF8_LCASE", "Adad"); - assertReplace("re世place", "plx", "123", "UNICODE", "re世place"); - assertReplace("世Replace", "re", "", "UNICODE", "世Replace"); - assertReplace("replace世", "", "123", "UNICODE", "replace世"); - assertReplace("aBc世abc", "b", "12", "UNICODE", "aBc世a12c"); - assertReplace("abcdabcd", "bc", "", "UNICODE", "adad"); - assertReplace("replace", "plx", "123", "UNICODE_CI", "replace"); - assertReplace("Replace", "re", "", "UNICODE_CI", "place"); - assertReplace("replace", "", "123", "UNICODE_CI", "replace"); - assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"); - assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"); - assertReplace("abi̇12", "i", "X", "UNICODE_CI", "abi̇12"); - assertReplace("abi̇12", "\u0307", "X", "UNICODE_CI", "abi̇12"); - assertReplace("abi̇12", "İ", "X", "UNICODE_CI", "abX12"); - assertReplace("abİ12", "i", "X", "UNICODE_CI", "abİ12"); - assertReplace("İi̇İi̇İi̇", "i̇", "x", "UNICODE_CI", "xxxxxx"); - assertReplace("İi̇İi̇İi̇", "i", "x", "UNICODE_CI", "İi̇İi̇İi̇"); - assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx"); - assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"); - assertReplace("abi̇12", "i", "X", "UTF8_LCASE", "abX\u030712"); // != UNICODE_CI - assertReplace("abi̇12", "\u0307", "X", "UTF8_LCASE", "abiX12"); // != 
UNICODE_CI - assertReplace("abi̇12", "İ", "X", "UTF8_LCASE", "abX12"); - assertReplace("abİ12", "i", "X", "UTF8_LCASE", "abİ12"); - assertReplace("İi̇İi̇İi̇", "i̇", "x", "UTF8_LCASE", "xxxxxx"); - assertReplace("İi̇İi̇İi̇", "i", "x", "UTF8_LCASE", + public void testStringReplace() throws SparkException { + // Empty strings. + assertStringReplace("", "", "", "UTF8_BINARY", ""); + assertStringReplace("", "", "", "UTF8_LCASE", ""); + assertStringReplace("", "", "", "UNICODE", ""); + assertStringReplace("", "", "", "UNICODE_CI", ""); + assertStringReplace("abc", "", "", "UTF8_BINARY", "abc"); + assertStringReplace("abc", "", "", "UTF8_LCASE", "abc"); + assertStringReplace("abc", "", "", "UNICODE", "abc"); + assertStringReplace("abc", "", "", "UNICODE_CI", "abc"); + assertStringReplace("", "x", "", "UTF8_BINARY", ""); + assertStringReplace("", "x", "", "UTF8_LCASE", ""); + assertStringReplace("", "x", "", "UNICODE", ""); + assertStringReplace("", "x", "", "UNICODE_CI", ""); + assertStringReplace("", "", "x", "UTF8_BINARY", ""); + assertStringReplace("", "", "x", "UTF8_LCASE", ""); + assertStringReplace("", "", "x", "UNICODE", ""); + assertStringReplace("", "", "x", "UNICODE_CI", ""); + assertStringReplace("", "b", "x", "UTF8_BINARY", ""); + assertStringReplace("", "b", "x", "UTF8_LCASE", ""); + assertStringReplace("", "b", "x", "UNICODE", ""); + assertStringReplace("", "b", "x", "UNICODE_CI", ""); + assertStringReplace("abc", "b", "", "UTF8_BINARY", "ac"); + assertStringReplace("abc", "b", "", "UTF8_LCASE", "ac"); + assertStringReplace("abc", "b", "", "UNICODE", "ac"); + assertStringReplace("abc", "b", "", "UNICODE_CI", "ac"); + assertStringReplace("abc", "", "x", "UTF8_BINARY", "abc"); + assertStringReplace("abc", "", "x", "UTF8_LCASE", "abc"); + assertStringReplace("abc", "", "x", "UNICODE", "abc"); + assertStringReplace("abc", "", "x", "UNICODE_CI", "abc"); + // Basic tests. 
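+ // Note: replace is a no-op when the search string is empty or absent; under
+ // case-insensitive collations the match is found regardless of case while the rest of
+ // the input is preserved as-is.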
+ assertStringReplace("replace", "pl", "", "UTF8_BINARY", "reace"); + assertStringReplace("replace", "pl", "", "UTF8_LCASE", "reace"); + assertStringReplace("replace", "pl", "", "UNICODE", "reace"); + assertStringReplace("replace", "pl", "", "UNICODE_CI", "reace"); + assertStringReplace("replace", "", "123", "UTF8_BINARY", "replace"); + assertStringReplace("replace", "", "123", "UTF8_LCASE", "replace"); + assertStringReplace("replace", "", "123", "UNICODE", "replace"); + assertStringReplace("replace", "", "123", "UNICODE_CI", "replace"); + assertStringReplace("abcabc", "b", "12", "UTF8_BINARY", "a12ca12c"); + assertStringReplace("abcabc", "b", "12", "UTF8_LCASE", "a12ca12c"); + assertStringReplace("abcabc", "b", "12", "UNICODE", "a12ca12c"); + assertStringReplace("abcabc", "b", "12", "UNICODE_CI", "a12ca12c"); + assertStringReplace("replace", "plx", "123", "UTF8_BINARY", "replace"); + assertStringReplace("replace", "plx", "123", "UTF8_LCASE", "replace"); + assertStringReplace("replace", "plx", "123", "UNICODE", "replace"); + assertStringReplace("replace", "plx", "123", "UNICODE_CI", "replace"); + assertStringReplace("Replace", "re", "", "UTF8_BINARY", "Replace"); + assertStringReplace("Replace", "re", "", "UTF8_LCASE", "place"); + assertStringReplace("Replace", "re", "", "UNICODE", "Replace"); + assertStringReplace("Replace", "re", "", "UNICODE_CI", "place"); + assertStringReplace("abcdabcd", "Bc", "", "UTF8_BINARY", "abcdabcd"); + assertStringReplace("abcdabcd", "Bc", "", "UTF8_LCASE", "adad"); + assertStringReplace("abcdabcd", "Bc", "", "UNICODE", "abcdabcd"); + assertStringReplace("abcdabcd", "Bc", "", "UNICODE_CI", "adad"); + assertStringReplace("AbcdabCd", "Bc", "", "UTF8_BINARY", "AbcdabCd"); + assertStringReplace("AbcdabCd", "Bc", "", "UTF8_LCASE", "Adad"); + assertStringReplace("AbcdabCd", "Bc", "", "UNICODE", "AbcdabCd"); + assertStringReplace("AbcdabCd", "Bc", "", "UNICODE_CI", "Adad"); + // Advanced tests. + assertStringReplace("abcdabcd", "bc", "", "UTF8_BINARY", "adad"); + assertStringReplace("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"); + assertStringReplace("世Replace", "re", "", "UTF8_BINARY", "世Replace"); + assertStringReplace("r世eplace", "pl", "xx", "UTF8_LCASE", "r世exxace"); + assertStringReplace("repl世ace", "PL", "AB", "UTF8_LCASE", "reAB世ace"); + assertStringReplace("re世place", "世", "x", "UTF8_LCASE", "rexplace"); + assertStringReplace("re世place", "plx", "123", "UNICODE", "re世place"); + assertStringReplace("replace世", "", "123", "UNICODE", "replace世"); + assertStringReplace("aBc世abc", "b", "12", "UNICODE", "aBc世a12c"); + assertStringReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"); + assertStringReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"); + assertStringReplace("repl世ace", "Pl", "", "UNICODE_CI", "re世ace"); + assertStringReplace("abcčšdabĆŠscd", "cs", "", "SR_CI_AI", "abcdabscd"); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertStringReplace("abi̇12", "i", "X", "UNICODE_CI", "abi̇12"); + assertStringReplace("abi̇12", "\u0307", "X", "UNICODE_CI", "abi̇12"); + assertStringReplace("abi̇12", "İ", "X", "UNICODE_CI", "abX12"); + assertStringReplace("abİ12", "i", "X", "UNICODE_CI", "abİ12"); + assertStringReplace("İi̇İi̇İi̇", "i\u0307", "x", "UNICODE_CI", "xxxxxx"); + assertStringReplace("İi̇İi̇İi̇", "i", "x", "UNICODE_CI", "İi̇İi̇İi̇"); + assertStringReplace("abİo12i̇o", "i\u0307o", "xx", "UNICODE_CI", "abxx12xx"); + assertStringReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"); + assertStringReplace("abi̇12", "i", "X", "UTF8_LCASE", "abX\u030712"); // != UNICODE_CI + assertStringReplace("abi̇12", "\u0307", "X", "UTF8_LCASE", "abiX12"); // != UNICODE_CI + assertStringReplace("abi̇12", "İ", "X", "UTF8_LCASE", "abX12"); + assertStringReplace("abİ12", "i", "X", "UTF8_LCASE", "abİ12"); + assertStringReplace("İi̇İi̇İi̇", "i\u0307", "x", "UTF8_LCASE", "xxxxxx"); + assertStringReplace("İi̇İi̇İi̇", "i", "x", "UTF8_LCASE", "İx\u0307İx\u0307İx\u0307"); // != UNICODE_CI - assertReplace("abİo12i̇o", "i̇o", "xx", "UTF8_LCASE", "abxx12xx"); - assertReplace("abi̇o12i̇o", "İo", "yy", "UTF8_LCASE", "abyy12yy"); + assertStringReplace("abİo12i̇o", "i\u0307o", "xx", "UTF8_LCASE", "abxx12xx"); + assertStringReplace("abi̇o12i̇o", "İo", "yy", "UTF8_LCASE", "abyy12yy"); + // Conditional case mapping (e.g. Greek sigmas). + assertStringReplace("σ", "σ", "x", "UTF8_BINARY", "x"); + assertStringReplace("σ", "ς", "x", "UTF8_BINARY", "σ"); + assertStringReplace("σ", "Σ", "x", "UTF8_BINARY", "σ"); + assertStringReplace("ς", "σ", "x", "UTF8_BINARY", "ς"); + assertStringReplace("ς", "ς", "x", "UTF8_BINARY", "x"); + assertStringReplace("ς", "Σ", "x", "UTF8_BINARY", "ς"); + assertStringReplace("Σ", "σ", "x", "UTF8_BINARY", "Σ"); + assertStringReplace("Σ", "ς", "x", "UTF8_BINARY", "Σ"); + assertStringReplace("Σ", "Σ", "x", "UTF8_BINARY", "x"); + assertStringReplace("σ", "σ", "x", "UTF8_LCASE", "x"); + assertStringReplace("σ", "ς", "x", "UTF8_LCASE", "x"); + assertStringReplace("σ", "Σ", "x", "UTF8_LCASE", "x"); + assertStringReplace("ς", "σ", "x", "UTF8_LCASE", "x"); + assertStringReplace("ς", "ς", "x", "UTF8_LCASE", "x"); + assertStringReplace("ς", "Σ", "x", "UTF8_LCASE", "x"); + assertStringReplace("Σ", "σ", "x", "UTF8_LCASE", "x"); + assertStringReplace("Σ", "ς", "x", "UTF8_LCASE", "x"); + assertStringReplace("Σ", "Σ", "x", "UTF8_LCASE", "x"); + assertStringReplace("σ", "σ", "x", "UNICODE", "x"); + assertStringReplace("σ", "ς", "x", "UNICODE", "σ"); + assertStringReplace("σ", "Σ", "x", "UNICODE", "σ"); + assertStringReplace("ς", "σ", "x", "UNICODE", "ς"); + assertStringReplace("ς", "ς", "x", "UNICODE", "x"); + assertStringReplace("ς", "Σ", "x", "UNICODE", "ς"); + assertStringReplace("Σ", "σ", "x", "UNICODE", "Σ"); + assertStringReplace("Σ", "ς", "x", "UNICODE", "Σ"); + assertStringReplace("Σ", "Σ", "x", "UNICODE", "x"); + assertStringReplace("σ", "σ", "x", "UNICODE_CI", "x"); + assertStringReplace("σ", "ς", "x", "UNICODE_CI", "x"); + assertStringReplace("σ", "Σ", "x", "UNICODE_CI", "x"); + assertStringReplace("ς", "σ", "x", "UNICODE_CI", "x"); + assertStringReplace("ς", "ς", "x", "UNICODE_CI", "x"); + assertStringReplace("ς", "Σ", "x", "UNICODE_CI", "x"); + assertStringReplace("Σ", "σ", "x", "UNICODE_CI", "x"); + assertStringReplace("Σ", "ς", "x", "UNICODE_CI", "x"); + assertStringReplace("Σ", "Σ", "x", "UNICODE_CI", "x"); + // Surrogate pairs. 
+ assertStringReplace("a🙃b", "a", "x", "UTF8_BINARY", "x🙃b"); + assertStringReplace("a🙃b", "b", "x", "UTF8_BINARY", "a🙃x"); + assertStringReplace("a🙃b", "🙃", "x", "UTF8_BINARY", "axb"); + assertStringReplace("a🙃b", "b", "c", "UTF8_LCASE", "a🙃c"); + assertStringReplace("a🙃b", "b", "x", "UTF8_LCASE", "a🙃x"); + assertStringReplace("a🙃b", "🙃", "x", "UTF8_LCASE", "axb"); + assertStringReplace("a🙃b", "b", "c", "UNICODE", "a🙃c"); + assertStringReplace("a🙃b", "b", "x", "UNICODE", "a🙃x"); + assertStringReplace("a🙃b", "🙃", "x", "UNICODE", "axb"); + assertStringReplace("a🙃b", "b", "c", "UNICODE_CI", "a🙃c"); + assertStringReplace("a🙃b", "b", "x", "UNICODE_CI", "a🙃x"); + assertStringReplace("a🙃b", "🙃", "x", "UNICODE_CI", "axb"); } - private void assertLocate(String substring, String string, Integer start, String collationName, - Integer expected) throws SparkException { + /** + * Verify the behaviour of the `StringLocate` collation support class. + */ + + private void assertStringLocate(String substring, String string, int start, + String collationName, int expected) throws SparkException { + // Note: When using start < 1, be careful to understand the behavior of the `indexOf` + // method and the implications of using `indexOf` in the `StringLocate` case class. UTF8String substr = UTF8String.fromString(substring); UTF8String str = UTF8String.fromString(string); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected, CollationSupport.StringLocate.exec(str, substr, - start - 1, collationId) + 1); + int result = CollationSupport.StringLocate.exec(str, substr, start - 1, collationId) + 1; + assertEquals(expected, result); } @Test - public void testLocate() throws SparkException { - // If you add tests with start < 1 be careful to understand the behavior of the indexOf method - // and usage of indexOf in the StringLocate class. 
- assertLocate("aa", "aaads", 1, "UTF8_BINARY", 1); - assertLocate("aa", "aaads", 2, "UTF8_BINARY", 2); - assertLocate("aa", "aaads", 3, "UTF8_BINARY", 0); - assertLocate("Aa", "aaads", 1, "UTF8_BINARY", 0); - assertLocate("Aa", "aAads", 1, "UTF8_BINARY", 2); - assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY", 0); - assertLocate("界X", "test大千世界X大千世界", 1, "UTF8_BINARY", 8); - assertLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY", 13); - assertLocate("AA", "aaads", 1, "UTF8_LCASE", 1); - assertLocate("aa", "aAads", 2, "UTF8_LCASE", 2); - assertLocate("aa", "aaAds", 3, "UTF8_LCASE", 0); - assertLocate("abC", "abcabc", 1, "UTF8_LCASE", 1); - assertLocate("abC", "abCabc", 2, "UTF8_LCASE", 4); - assertLocate("abc", "abcabc", 4, "UTF8_LCASE", 4); - assertLocate("界x", "test大千世界X大千世界", 1, "UTF8_LCASE", 8); - assertLocate("界X", "test大千世界Xtest大千世界", 1, "UTF8_LCASE", 8); - assertLocate("界", "test大千世界X大千世界", 13, "UTF8_LCASE", 13); - assertLocate("大千", "test大千世界大千世界", 1, "UTF8_LCASE", 5); - assertLocate("大千", "test大千世界大千世界", 9, "UTF8_LCASE", 9); - assertLocate("大千", "大千世界大千世界", 1, "UTF8_LCASE", 1); - assertLocate("aa", "Aaads", 1, "UNICODE", 2); - assertLocate("AA", "aaads", 1, "UNICODE", 0); - assertLocate("aa", "aAads", 2, "UNICODE", 0); - assertLocate("aa", "aaAds", 3, "UNICODE", 0); - assertLocate("abC", "abcabc", 1, "UNICODE", 0); - assertLocate("abC", "abCabc", 2, "UNICODE", 0); - assertLocate("abC", "abCabC", 2, "UNICODE", 4); - assertLocate("abc", "abcabc", 1, "UNICODE", 1); - assertLocate("abc", "abcabc", 3, "UNICODE", 4); - assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE", 0); - assertLocate("界X", "test大千世界X大千世界", 1, "UNICODE", 8); - assertLocate("界", "test大千世界X大千世界", 13, "UNICODE", 13); - assertLocate("AA", "aaads", 1, "UNICODE_CI", 1); - assertLocate("aa", "aAads", 2, "UNICODE_CI", 2); - assertLocate("aa", "aaAds", 3, "UNICODE_CI", 0); - assertLocate("abC", "abcabc", 1, "UNICODE_CI", 1); - assertLocate("abC", "abCabc", 2, "UNICODE_CI", 4); - assertLocate("abc", "abcabc", 4, "UNICODE_CI", 4); - assertLocate("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8); - assertLocate("界", "test大千世界X大千世界", 13, "UNICODE_CI", 13); - assertLocate("大千", "test大千世界大千世界", 1, "UNICODE_CI", 5); - assertLocate("大千", "test大千世界大千世界", 9, "UNICODE_CI", 9); - assertLocate("大千", "大千世界大千世界", 1, "UNICODE_CI", 1); - // Case-variable character length - assertLocate("\u0307", "i̇", 1, "UTF8_BINARY", 2); - assertLocate("\u0307", "İ", 1, "UTF8_LCASE", 0); // != UTF8_BINARY - assertLocate("i", "i̇", 1, "UNICODE_CI", 0); - assertLocate("\u0307", "i̇", 1, "UNICODE_CI", 0); - assertLocate("i̇", "i", 1, "UNICODE_CI", 0); - assertLocate("İ", "i̇", 1, "UNICODE_CI", 1); - assertLocate("İ", "i", 1, "UNICODE_CI", 0); - assertLocate("i", "i̇", 1, "UTF8_LCASE", 1); // != UNICODE_CI - assertLocate("\u0307", "i̇", 1, "UTF8_LCASE", 2); // != UNICODE_CI - assertLocate("i̇", "i", 1, "UTF8_LCASE", 0); - assertLocate("İ", "i̇", 1, "UTF8_LCASE", 1); - assertLocate("İ", "i", 1, "UTF8_LCASE", 0); - assertLocate("i̇o", "İo世界大千世界", 1, "UNICODE_CI", 1); - assertLocate("i̇o", "大千İo世界大千世界", 1, "UNICODE_CI", 3); - assertLocate("i̇o", "世界İo大千世界大千İo", 4, "UNICODE_CI", 11); - assertLocate("İo", "i̇o世界大千世界", 1, "UNICODE_CI", 1); - assertLocate("İo", "大千i̇o世界大千世界", 1, "UNICODE_CI", 3); - assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); + public void testStringLocate() throws SparkException { + // Empty strings. 
+ assertStringLocate("", "", -1, "UTF8_BINARY", 1); + assertStringLocate("", "", -1, "UTF8_LCASE", 1); + assertStringLocate("", "", -1, "UNICODE", 1); + assertStringLocate("", "", -1, "UNICODE_CI", 1); + assertStringLocate("", "", 0, "UTF8_BINARY", 1); + assertStringLocate("", "", 0, "UTF8_LCASE", 1); + assertStringLocate("", "", 0, "UNICODE", 1); + assertStringLocate("", "", 0, "UNICODE_CI", 1); + assertStringLocate("", "", 1, "UTF8_BINARY", 1); + assertStringLocate("", "", 1, "UTF8_LCASE", 1); + assertStringLocate("", "", 1, "UNICODE", 1); + assertStringLocate("", "", 1, "UNICODE_CI", 1); + assertStringLocate("a", "", -1, "UTF8_BINARY", 0); + assertStringLocate("a", "", -1, "UTF8_LCASE", 0); + assertStringLocate("a", "", -1, "UNICODE", 0); + assertStringLocate("a", "", -1, "UNICODE_CI", 0); + assertStringLocate("a", "", 0, "UTF8_BINARY", 0); + assertStringLocate("a", "", 0, "UTF8_LCASE", 0); + assertStringLocate("a", "", 0, "UNICODE", 0); + assertStringLocate("a", "", 0, "UNICODE_CI", 0); + assertStringLocate("a", "", 1, "UTF8_BINARY", 0); + assertStringLocate("a", "", 1, "UTF8_LCASE", 0); + assertStringLocate("a", "", 1, "UNICODE", 0); + assertStringLocate("a", "", 1, "UNICODE_CI", 0); + assertStringLocate("", "x", -1, "UTF8_BINARY", 1); + assertStringLocate("", "x", -1, "UTF8_LCASE", 1); + assertStringLocate("", "x", -1, "UNICODE", 1); + assertStringLocate("", "x", -1, "UNICODE_CI", 1); + assertStringLocate("", "x", 0, "UTF8_BINARY", 1); + assertStringLocate("", "x", 0, "UTF8_LCASE", 1); + assertStringLocate("", "x", 0, "UNICODE", 1); + assertStringLocate("", "x", 0, "UNICODE_CI", 1); + assertStringLocate("", "x", 1, "UTF8_BINARY", 1); + assertStringLocate("", "x", 1, "UTF8_LCASE", 1); + assertStringLocate("", "x", 1, "UNICODE", 1); + assertStringLocate("", "x", 1, "UNICODE_CI", 1); + // Basic tests. 
+ assertStringLocate("aa", "aaads", 1, "UTF8_BINARY", 1); + assertStringLocate("aa", "aaads", 1, "UTF8_LCASE", 1); + assertStringLocate("aa", "aaads", 1, "UNICODE", 1); + assertStringLocate("aa", "aaads", 1, "UNICODE_CI", 1); + assertStringLocate("aa", "aaads", 2, "UTF8_BINARY", 2); + assertStringLocate("aa", "aaads", 2, "UTF8_LCASE", 2); + assertStringLocate("aa", "aaads", 2, "UNICODE", 2); + assertStringLocate("aa", "aaads", 2, "UNICODE_CI", 2); + assertStringLocate("aa", "aaads", 3, "UTF8_BINARY", 0); + assertStringLocate("aa", "aaads", 3, "UTF8_LCASE", 0); + assertStringLocate("aa", "aaads", 3, "UNICODE", 0); + assertStringLocate("aa", "aaads", 3, "UNICODE_CI", 0); + assertStringLocate("Aa", "aaads", 1, "UTF8_BINARY", 0); + assertStringLocate("Aa", "aaads", 1, "UTF8_LCASE", 1); + assertStringLocate("Aa", "aaads", 1, "UNICODE", 0); + assertStringLocate("Aa", "aaads", 1, "UNICODE_CI", 1); + assertStringLocate("Aa", "aaads", 2, "UTF8_BINARY", 0); + assertStringLocate("Aa", "aaads", 2, "UTF8_LCASE", 2); + assertStringLocate("Aa", "aaads", 2, "UNICODE", 0); + assertStringLocate("Aa", "aaads", 2, "UNICODE_CI", 2); + assertStringLocate("Aa", "aaads", 3, "UTF8_BINARY", 0); + assertStringLocate("Aa", "aaads", 3, "UTF8_LCASE", 0); + assertStringLocate("Aa", "aaads", 3, "UNICODE", 0); + assertStringLocate("Aa", "aaads", 3, "UNICODE_CI", 0); + assertStringLocate("Aa", "aAads", 1, "UTF8_BINARY", 2); + assertStringLocate("Aa", "aAads", 1, "UTF8_LCASE", 1); + assertStringLocate("Aa", "aAads", 1, "UNICODE", 2); + assertStringLocate("Aa", "aAads", 1, "UNICODE_CI", 1); + assertStringLocate("AA", "aaads", 1, "UTF8_BINARY", 0); + assertStringLocate("AA", "aaads", 1, "UTF8_LCASE", 1); + assertStringLocate("AA", "aaads", 1, "UNICODE", 0); + assertStringLocate("AA", "aaads", 1, "UNICODE_CI", 1); + assertStringLocate("aa", "aAads", 2, "UTF8_BINARY", 0); + assertStringLocate("aa", "aAads", 2, "UTF8_LCASE", 2); + assertStringLocate("aa", "aAads", 2, "UNICODE", 0); + assertStringLocate("aa", "aAads", 2, "UNICODE_CI", 2); + assertStringLocate("aa", "aaAds", 3, "UTF8_BINARY", 0); + assertStringLocate("aa", "aaAds", 3, "UTF8_LCASE", 0); + assertStringLocate("aa", "aaAds", 3, "UNICODE", 0); + assertStringLocate("aa", "aaAds", 3, "UNICODE_CI", 0); + assertStringLocate("abC", "abcabc", 1, "UTF8_BINARY", 0); + assertStringLocate("abC", "abcabc", 1, "UTF8_LCASE", 1); + assertStringLocate("abC", "abcabc", 1, "UNICODE", 0); + assertStringLocate("abC", "abcabc", 1, "UNICODE_CI", 1); + assertStringLocate("abC", "abCabc", 2, "UTF8_BINARY", 0); + assertStringLocate("abC", "abCabc", 2, "UTF8_LCASE", 4); + assertStringLocate("abC", "abCabc", 2, "UNICODE", 0); + assertStringLocate("abC", "abCabc", 2, "UNICODE_CI", 4); + assertStringLocate("abc", "abcabc", 1, "UTF8_BINARY", 1); + assertStringLocate("abc", "abcabc", 1, "UTF8_LCASE", 1); + assertStringLocate("abc", "abcabc", 1, "UNICODE", 1); + assertStringLocate("abc", "abcabc", 1, "UNICODE_CI", 1); + assertStringLocate("abc", "abcabc", 2, "UTF8_BINARY", 4); + assertStringLocate("abc", "abcabc", 2, "UTF8_LCASE", 4); + assertStringLocate("abc", "abcabc", 2, "UNICODE", 4); + assertStringLocate("abc", "abcabc", 2, "UNICODE_CI", 4); + assertStringLocate("abc", "abcabc", 3, "UTF8_BINARY", 4); + assertStringLocate("abc", "abcabc", 3, "UTF8_LCASE", 4); + assertStringLocate("abc", "abcabc", 3, "UNICODE", 4); + assertStringLocate("abc", "abcabc", 3, "UNICODE_CI", 4); + assertStringLocate("abc", "abcabc", 4, "UTF8_BINARY", 4); + assertStringLocate("abc", "abcabc", 4, "UTF8_LCASE", 4); + 
assertStringLocate("abc", "abcabc", 4, "UNICODE", 4); + assertStringLocate("abc", "abcabc", 4, "UNICODE_CI", 4); + assertStringLocate("aa", "Aaads", 1, "UTF8_BINARY", 2); + assertStringLocate("aa", "Aaads", 1, "UTF8_LCASE", 1); + assertStringLocate("aa", "Aaads", 1, "UNICODE", 2); + assertStringLocate("aa", "Aaads", 1, "UNICODE_CI", 1); + assertStringLocate("ćČ", "CćČČćCČĆČcČcććČč", 3, "SR", 14); + assertStringLocate("ćČ", "CćČČćCČĆČcČcććČč", 3, "SR_CI_AI", 3); + // Advanced tests. + assertStringLocate("界x", "test大千世界X大千世界", 1, "UTF8_BINARY", 0); + assertStringLocate("界X", "test大千世界X大千世界", 1, "UTF8_BINARY", 8); + assertStringLocate("界", "test大千世界X大千世界", 13, "UTF8_BINARY", 13); + assertStringLocate("界x", "test大千世界X大千世界", 1, "UTF8_LCASE", 8); + assertStringLocate("界X", "test大千世界Xtest大千世界", 1, "UTF8_LCASE", 8); + assertStringLocate("界", "test大千世界X大千世界", 13, "UTF8_LCASE", 13); + assertStringLocate("大千", "test大千世界大千世界", 1, "UTF8_LCASE", 5); + assertStringLocate("大千", "test大千世界大千世界", 9, "UTF8_LCASE", 9); + assertStringLocate("大千", "大千世界大千世界", 1, "UTF8_LCASE", 1); + assertStringLocate("界x", "test大千世界X大千世界", 1, "UNICODE", 0); + assertStringLocate("界X", "test大千世界X大千世界", 1, "UNICODE", 8); + assertStringLocate("界", "test大千世界X大千世界", 13, "UNICODE", 13); + assertStringLocate("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8); + assertStringLocate("界", "test大千世界X大千世界", 13, "UNICODE_CI", 13); + assertStringLocate("大千", "test大千世界大千世界", 1, "UNICODE_CI", 5); + assertStringLocate("大千", "test大千世界大千世界", 9, "UNICODE_CI", 9); + assertStringLocate("大千", "大千世界大千世界", 1, "UNICODE_CI", 1); + // One-to-many case mapping (e.g. Turkish dotted I). + assertStringLocate("\u0307", "i\u0307", 1, "UTF8_BINARY", 2); + assertStringLocate("\u0307", "İ", 1, "UTF8_LCASE", 0); // != UTF8_BINARY + assertStringLocate("i", "i\u0307", 1, "UNICODE_CI", 0); + assertStringLocate("\u0307", "i\u0307", 1, "UNICODE_CI", 0); + assertStringLocate("i\u0307", "i", 1, "UNICODE_CI", 0); + assertStringLocate("İ", "i\u0307", 1, "UNICODE_CI", 1); + assertStringLocate("İ", "i", 1, "UNICODE_CI", 0); + assertStringLocate("i", "i\u0307", 1, "UTF8_LCASE", 1); // != UNICODE_CI + assertStringLocate("\u0307", "i\u0307", 1, "UTF8_LCASE", 2); // != UNICODE_CI + assertStringLocate("i\u0307", "i", 1, "UTF8_LCASE", 0); + assertStringLocate("İ", "i\u0307", 1, "UTF8_LCASE", 1); + assertStringLocate("İ", "i", 1, "UTF8_LCASE", 0); + assertStringLocate("i\u0307o", "İo世界大千世界", 1, "UNICODE_CI", 1); + assertStringLocate("i\u0307o", "大千İo世界大千世界", 1, "UNICODE_CI", 3); + assertStringLocate("i\u0307o", "世界İo大千世界大千İo", 4, "UNICODE_CI", 11); + assertStringLocate("İo", "i̇o世界大千世界", 1, "UNICODE_CI", 1); + assertStringLocate("İo", "大千i̇o世界大千世界", 1, "UNICODE_CI", 3); + assertStringLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertStringLocate("σ", "σ", 1, "UTF8_BINARY", 1); + assertStringLocate("σ", "ς", 1, "UTF8_BINARY", 0); + assertStringLocate("σ", "Σ", 1, "UTF8_BINARY", 0); + assertStringLocate("ς", "σ", 1, "UTF8_BINARY", 0); + assertStringLocate("ς", "ς", 1, "UTF8_BINARY", 1); + assertStringLocate("ς", "Σ", 1, "UTF8_BINARY", 0); + assertStringLocate("Σ", "σ", 1, "UTF8_BINARY", 0); + assertStringLocate("Σ", "ς", 1, "UTF8_BINARY", 0); + assertStringLocate("Σ", "Σ", 1, "UTF8_BINARY", 1); + assertStringLocate("σ", "σ", 1, "UTF8_LCASE", 1); + assertStringLocate("σ", "ς", 1, "UTF8_LCASE", 1); + assertStringLocate("σ", "Σ", 1, "UTF8_LCASE", 1); + assertStringLocate("ς", "σ", 1, "UTF8_LCASE", 1); + assertStringLocate("ς", "ς", 1, "UTF8_LCASE", 1); + assertStringLocate("ς", "Σ", 1, "UTF8_LCASE", 1); + assertStringLocate("Σ", "σ", 1, "UTF8_LCASE", 1); + assertStringLocate("Σ", "ς", 1, "UTF8_LCASE", 1); + assertStringLocate("Σ", "Σ", 1, "UTF8_LCASE", 1); + assertStringLocate("σ", "σ", 1, "UNICODE", 1); + assertStringLocate("σ", "ς", 1, "UNICODE", 0); + assertStringLocate("σ", "Σ", 1, "UNICODE", 0); + assertStringLocate("ς", "σ", 1, "UNICODE", 0); + assertStringLocate("ς", "ς", 1, "UNICODE", 1); + assertStringLocate("ς", "Σ", 1, "UNICODE", 0); + assertStringLocate("Σ", "σ", 1, "UNICODE", 0); + assertStringLocate("Σ", "ς", 1, "UNICODE", 0); + assertStringLocate("Σ", "Σ", 1, "UNICODE", 1); + assertStringLocate("σ", "σ", 1, "UNICODE_CI", 1); + assertStringLocate("σ", "ς", 1, "UNICODE_CI", 1); + assertStringLocate("σ", "Σ", 1, "UNICODE_CI", 1); + assertStringLocate("ς", "σ", 1, "UNICODE_CI", 1); + assertStringLocate("ς", "ς", 1, "UNICODE_CI", 1); + assertStringLocate("ς", "Σ", 1, "UNICODE_CI", 1); + assertStringLocate("Σ", "σ", 1, "UNICODE_CI", 1); + assertStringLocate("Σ", "ς", 1, "UNICODE_CI", 1); + assertStringLocate("Σ", "Σ", 1, "UNICODE_CI", 1); + // Surrogate pairs. 
+ assertStringLocate("a", "a🙃b", 1, "UTF8_BINARY", 1); + assertStringLocate("a", "a🙃b", 1, "UTF8_LCASE", 1); + assertStringLocate("a", "a🙃b", 1, "UNICODE", 1); + assertStringLocate("a", "a🙃b", 1, "UNICODE_CI", 1); + assertStringLocate("a", "a🙃b", 2, "UTF8_BINARY", 0); + assertStringLocate("a", "a🙃b", 2, "UTF8_LCASE", 0); + assertStringLocate("a", "a🙃b", 2, "UNICODE", 0); + assertStringLocate("a", "a🙃b", 2, "UNICODE_CI", 0); + assertStringLocate("a", "a🙃b", 3, "UTF8_BINARY", 0); + assertStringLocate("a", "a🙃b", 3, "UTF8_LCASE", 0); + assertStringLocate("a", "a🙃b", 3, "UNICODE", 0); + assertStringLocate("a", "a🙃b", 3, "UNICODE_CI", 0); + assertStringLocate("🙃", "a🙃b", 1, "UTF8_BINARY", 2); + assertStringLocate("🙃", "a🙃b", 1, "UTF8_LCASE", 2); + assertStringLocate("🙃", "a🙃b", 1, "UNICODE", 2); + assertStringLocate("🙃", "a🙃b", 1, "UNICODE_CI", 2); + assertStringLocate("🙃", "a🙃b", 2, "UTF8_BINARY", 2); + assertStringLocate("🙃", "a🙃b", 2, "UTF8_LCASE", 2); + assertStringLocate("🙃", "a🙃b", 2, "UNICODE", 2); + assertStringLocate("🙃", "a🙃b", 2, "UNICODE_CI", 2); + assertStringLocate("🙃", "a🙃b", 3, "UTF8_BINARY", 0); + assertStringLocate("🙃", "a🙃b", 3, "UTF8_LCASE", 0); + assertStringLocate("🙃", "a🙃b", 3, "UNICODE", 0); + assertStringLocate("🙃", "a🙃b", 3, "UNICODE_CI", 0); + assertStringLocate("b", "a🙃b", 1, "UTF8_BINARY", 3); + assertStringLocate("b", "a🙃b", 1, "UTF8_LCASE", 3); + assertStringLocate("b", "a🙃b", 1, "UNICODE", 3); + assertStringLocate("b", "a🙃b", 1, "UNICODE_CI", 3); + assertStringLocate("b", "a🙃b", 2, "UTF8_BINARY", 3); + assertStringLocate("b", "a🙃b", 2, "UTF8_LCASE", 3); + assertStringLocate("b", "a🙃b", 2, "UNICODE", 3); + assertStringLocate("b", "a🙃b", 2, "UNICODE_CI", 3); + assertStringLocate("b", "a🙃b", 3, "UTF8_BINARY", 3); + assertStringLocate("b", "a🙃b", 3, "UTF8_LCASE", 3); + assertStringLocate("b", "a🙃b", 3, "UNICODE", 3); + assertStringLocate("b", "a🙃b", 3, "UNICODE_CI", 3); + assertStringLocate("🙃", "a🙃🙃b", 1, "UTF8_BINARY", 2); + assertStringLocate("🙃", "a🙃🙃b", 1, "UTF8_LCASE", 2); + assertStringLocate("🙃", "a🙃🙃b", 1, "UNICODE", 2); + assertStringLocate("🙃", "a🙃🙃b", 1, "UNICODE_CI", 2); + assertStringLocate("🙃", "a🙃🙃b", 2, "UTF8_BINARY", 2); + assertStringLocate("🙃", "a🙃🙃b", 2, "UTF8_LCASE", 2); + assertStringLocate("🙃", "a🙃🙃b", 2, "UNICODE", 2); + assertStringLocate("🙃", "a🙃🙃b", 2, "UNICODE_CI", 2); + assertStringLocate("🙃", "a🙃🙃b", 3, "UTF8_BINARY", 3); + assertStringLocate("🙃", "a🙃🙃b", 3, "UTF8_LCASE", 3); + assertStringLocate("🙃", "a🙃🙃b", 3, "UNICODE", 3); + assertStringLocate("🙃", "a🙃🙃b", 3, "UNICODE_CI", 3); + assertStringLocate("🙃", "a🙃🙃b", 4, "UTF8_BINARY", 0); + assertStringLocate("🙃", "a🙃🙃b", 4, "UTF8_LCASE", 0); + assertStringLocate("🙃", "a🙃🙃b", 4, "UNICODE", 0); + assertStringLocate("🙃", "a🙃🙃b", 4, "UNICODE_CI", 0); + assertStringLocate("b", "a🙃🙃b", 1, "UTF8_BINARY", 4); + assertStringLocate("b", "a🙃🙃b", 1, "UTF8_LCASE", 4); + assertStringLocate("b", "a🙃🙃b", 1, "UNICODE", 4); + assertStringLocate("b", "a🙃🙃b", 1, "UNICODE_CI", 4); + assertStringLocate("b", "a🙃🙃b", 2, "UTF8_BINARY", 4); + assertStringLocate("b", "a🙃🙃b", 2, "UTF8_LCASE", 4); + assertStringLocate("b", "a🙃🙃b", 2, "UNICODE", 4); + assertStringLocate("b", "a🙃🙃b", 2, "UNICODE_CI", 4); + assertStringLocate("b", "a🙃🙃b", 3, "UTF8_BINARY", 4); + assertStringLocate("b", "a🙃🙃b", 3, "UTF8_LCASE", 4); + assertStringLocate("b", "a🙃🙃b", 3, "UNICODE", 4); + assertStringLocate("b", "a🙃🙃b", 3, "UNICODE_CI", 4); + assertStringLocate("b", "a🙃🙃b", 4, "UTF8_BINARY", 4); + assertStringLocate("b", "a🙃🙃b", 4, 
"UTF8_LCASE", 4); + assertStringLocate("b", "a🙃🙃b", 4, "UNICODE", 4); + assertStringLocate("b", "a🙃🙃b", 4, "UNICODE_CI", 4); + assertStringLocate("b", "a🙃x🙃b", 1, "UTF8_BINARY", 5); + assertStringLocate("b", "a🙃x🙃b", 1, "UTF8_LCASE", 5); + assertStringLocate("b", "a🙃x🙃b", 1, "UNICODE", 5); + assertStringLocate("b", "a🙃x🙃b", 1, "UNICODE_CI", 5); + assertStringLocate("b", "a🙃x🙃b", 2, "UTF8_BINARY", 5); + assertStringLocate("b", "a🙃x🙃b", 2, "UTF8_LCASE", 5); + assertStringLocate("b", "a🙃x🙃b", 2, "UNICODE", 5); + assertStringLocate("b", "a🙃x🙃b", 2, "UNICODE_CI", 5); + assertStringLocate("b", "a🙃x🙃b", 3, "UTF8_BINARY", 5); + assertStringLocate("b", "a🙃x🙃b", 3, "UTF8_LCASE", 5); + assertStringLocate("b", "a🙃x🙃b", 3, "UNICODE", 5); + assertStringLocate("b", "a🙃x🙃b", 3, "UNICODE_CI", 5); + assertStringLocate("b", "a🙃x🙃b", 4, "UTF8_BINARY", 5); + assertStringLocate("b", "a🙃x🙃b", 4, "UTF8_LCASE", 5); + assertStringLocate("b", "a🙃x🙃b", 4, "UNICODE", 5); + assertStringLocate("b", "a🙃x🙃b", 4, "UNICODE_CI", 5); + // Out of bounds test cases. + assertStringLocate("a", "asd", 4, "UTF8_BINARY", 0); + assertStringLocate("a", "asd", 4, "UTF8_LCASE", 0); + assertStringLocate("a", "asd", 4, "UNICODE", 0); + assertStringLocate("a", "asd", 4, "UNICODE_CI", 0); + assertStringLocate("a", "asd", 100, "UTF8_BINARY", 0); + assertStringLocate("a", "asd", 100, "UTF8_LCASE", 0); + assertStringLocate("a", "asd", 100, "UNICODE", 0); + assertStringLocate("a", "asd", 100, "UNICODE_CI", 0); + assertStringLocate("a", "🙃🙃", 4, "UTF8_BINARY", 0); + assertStringLocate("a", "🙃🙃", 4, "UTF8_LCASE", 0); + assertStringLocate("a", "🙃🙃", 4, "UNICODE", 0); + assertStringLocate("a", "🙃🙃", 4, "UNICODE_CI", 0); + assertStringLocate("", "asd", 100, "UTF8_BINARY", 1); + assertStringLocate("", "asd", 100, "UTF8_LCASE", 1); + assertStringLocate("", "asd", 100, "UNICODE", 1); + assertStringLocate("", "asd", 100, "UNICODE_CI", 1); + assertStringLocate("asd", "", 100, "UTF8_BINARY", 0); + assertStringLocate("asd", "", 100, "UTF8_LCASE", 0); + assertStringLocate("asd", "", 100, "UNICODE", 0); + assertStringLocate("asd", "", 100, "UNICODE_CI", 0); } - private void assertSubstringIndex(String string, String delimiter, Integer count, - String collationName, String expected) throws SparkException { + /** + * Verify the behaviour of the `SubstringIndex` collation support class. + */ + + private void assertSubstringIndex(String string, String delimiter, int count, + String collationName, String expected) throws SparkException { UTF8String str = UTF8String.fromString(string); UTF8String delim = UTF8String.fromString(delimiter); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected, - CollationSupport.SubstringIndex.exec(str, delim, count, collationId).toString()); + UTF8String result = CollationSupport.SubstringIndex.exec(str, delim, count, collationId); + assertEquals(UTF8String.fromString(expected), result); } @Test public void testSubstringIndex() throws SparkException { + // Empty strings. 
+ assertSubstringIndex("", "", 0, "UTF8_BINARY", ""); + assertSubstringIndex("", "", 0, "UTF8_LCASE", ""); + assertSubstringIndex("", "", 0, "UNICODE", ""); + assertSubstringIndex("", "", 0, "UNICODE_CI", ""); + assertSubstringIndex("", "", 1, "UTF8_BINARY", ""); + assertSubstringIndex("", "", 1, "UTF8_LCASE", ""); + assertSubstringIndex("", "", 1, "UNICODE", ""); + assertSubstringIndex("", "", 1, "UNICODE_CI", ""); + assertSubstringIndex("", "", -1, "UTF8_BINARY", ""); + assertSubstringIndex("", "", -1, "UTF8_LCASE", ""); + assertSubstringIndex("", "", -1, "UNICODE", ""); + assertSubstringIndex("", "", -1, "UNICODE_CI", ""); + assertSubstringIndex("", "x", 0, "UTF8_BINARY", ""); + assertSubstringIndex("", "x", 0, "UTF8_LCASE", ""); + assertSubstringIndex("", "x", 0, "UNICODE", ""); + assertSubstringIndex("", "x", 0, "UNICODE_CI", ""); + assertSubstringIndex("", "x", 1, "UTF8_BINARY", ""); + assertSubstringIndex("", "x", 1, "UTF8_LCASE", ""); + assertSubstringIndex("", "x", 1, "UNICODE", ""); + assertSubstringIndex("", "x", 1, "UNICODE_CI", ""); + assertSubstringIndex("", "x", -1, "UTF8_BINARY", ""); + assertSubstringIndex("", "x", -1, "UTF8_LCASE", ""); + assertSubstringIndex("", "x", -1, "UNICODE", ""); + assertSubstringIndex("", "x", -1, "UNICODE_CI", ""); + assertSubstringIndex("abc", "", 0, "UTF8_BINARY", ""); + assertSubstringIndex("abc", "", 0, "UTF8_LCASE", ""); + assertSubstringIndex("abc", "", 0, "UNICODE", ""); + assertSubstringIndex("abc", "", 0, "UNICODE_CI", ""); + assertSubstringIndex("abc", "", 1, "UTF8_BINARY", ""); + assertSubstringIndex("abc", "", 1, "UTF8_LCASE", ""); + assertSubstringIndex("abc", "", 1, "UNICODE", ""); + assertSubstringIndex("abc", "", 1, "UNICODE_CI", ""); + assertSubstringIndex("abc", "", -1, "UTF8_BINARY", ""); + assertSubstringIndex("abc", "", -1, "UTF8_LCASE", ""); + assertSubstringIndex("abc", "", -1, "UNICODE", ""); + assertSubstringIndex("abc", "", -1, "UNICODE_CI", ""); + // Basic tests. 
+ assertSubstringIndex("axbxc", "a", 1, "UTF8_BINARY", ""); + assertSubstringIndex("axbxc", "a", 1, "UTF8_LCASE", ""); + assertSubstringIndex("axbxc", "a", 1, "UNICODE", ""); + assertSubstringIndex("axbxc", "a", 1, "UNICODE_CI", ""); + assertSubstringIndex("axbxc", "x", 1, "UTF8_BINARY", "a"); + assertSubstringIndex("axbxc", "x", 1, "UTF8_LCASE", "a"); + assertSubstringIndex("axbxc", "x", 1, "UNICODE", "a"); + assertSubstringIndex("axbxc", "x", 1, "UNICODE_CI", "a"); + assertSubstringIndex("axbxc", "b", 1, "UTF8_BINARY", "ax"); + assertSubstringIndex("axbxc", "b", 1, "UTF8_LCASE", "ax"); + assertSubstringIndex("axbxc", "b", 1, "UNICODE", "ax"); + assertSubstringIndex("axbxc", "b", 1, "UNICODE_CI", "ax"); + assertSubstringIndex("axbxc", "x", 2, "UTF8_BINARY", "axb"); + assertSubstringIndex("axbxc", "x", 2, "UTF8_LCASE", "axb"); + assertSubstringIndex("axbxc", "x", 2, "UNICODE", "axb"); + assertSubstringIndex("axbxc", "x", 2, "UNICODE_CI", "axb"); + assertSubstringIndex("axbxc", "c", 1, "UTF8_BINARY", "axbx"); + assertSubstringIndex("axbxc", "c", 1, "UTF8_LCASE", "axbx"); + assertSubstringIndex("axbxc", "c", 1, "UNICODE", "axbx"); + assertSubstringIndex("axbxc", "c", 1, "UNICODE_CI", "axbx"); + assertSubstringIndex("axbxc", "x", 3, "UTF8_BINARY", "axbxc"); + assertSubstringIndex("axbxc", "x", 3, "UTF8_LCASE", "axbxc"); + assertSubstringIndex("axbxc", "x", 3, "UNICODE", "axbxc"); + assertSubstringIndex("axbxc", "x", 3, "UNICODE_CI", "axbxc"); + assertSubstringIndex("axbxc", "d", 1, "UTF8_BINARY", "axbxc"); + assertSubstringIndex("axbxc", "d", 1, "UTF8_LCASE", "axbxc"); + assertSubstringIndex("axbxc", "d", 1, "UNICODE", "axbxc"); + assertSubstringIndex("axbxc", "d", 1, "UNICODE_CI", "axbxc"); + assertSubstringIndex("axbxc", "c", -1, "UTF8_BINARY", ""); + assertSubstringIndex("axbxc", "c", -1, "UTF8_LCASE", ""); + assertSubstringIndex("axbxc", "c", -1, "UNICODE", ""); + assertSubstringIndex("axbxc", "c", -1, "UNICODE_CI", ""); + assertSubstringIndex("axbxc", "x", -1, "UTF8_BINARY", "c"); + assertSubstringIndex("axbxc", "x", -1, "UTF8_LCASE", "c"); + assertSubstringIndex("axbxc", "x", -1, "UNICODE", "c"); + assertSubstringIndex("axbxc", "x", -1, "UNICODE_CI", "c"); + assertSubstringIndex("axbxc", "b", -1, "UTF8_BINARY", "xc"); + assertSubstringIndex("axbxc", "b", -1, "UTF8_LCASE", "xc"); + assertSubstringIndex("axbxc", "b", -1, "UNICODE", "xc"); + assertSubstringIndex("axbxc", "b", -1, "UNICODE_CI", "xc"); + assertSubstringIndex("axbxc", "x", -2, "UTF8_BINARY", "bxc"); + assertSubstringIndex("axbxc", "x", -2, "UTF8_LCASE", "bxc"); + assertSubstringIndex("axbxc", "x", -2, "UNICODE", "bxc"); + assertSubstringIndex("axbxc", "x", -2, "UNICODE_CI", "bxc"); + assertSubstringIndex("axbxc", "a", -1, "UTF8_BINARY", "xbxc"); + assertSubstringIndex("axbxc", "a", -1, "UTF8_LCASE", "xbxc"); + assertSubstringIndex("axbxc", "a", -1, "UNICODE", "xbxc"); + assertSubstringIndex("axbxc", "a", -1, "UNICODE_CI", "xbxc"); + assertSubstringIndex("axbxc", "x", -3, "UTF8_BINARY", "axbxc"); + assertSubstringIndex("axbxc", "x", -3, "UTF8_LCASE", "axbxc"); + assertSubstringIndex("axbxc", "x", -3, "UNICODE", "axbxc"); + assertSubstringIndex("axbxc", "x", -3, "UNICODE_CI", "axbxc"); + assertSubstringIndex("axbxc", "d", -1, "UTF8_BINARY", "axbxc"); + assertSubstringIndex("axbxc", "d", -1, "UTF8_LCASE", "axbxc"); + assertSubstringIndex("axbxc", "d", -1, "UNICODE", "axbxc"); + assertSubstringIndex("axbxc", "d", -1, "UNICODE_CI", "axbxc"); + // Advanced tests. 
assertSubstringIndex("wwwgapachegorg", "g", -3, "UTF8_BINARY", "apachegorg"); assertSubstringIndex("www||apache||org", "||", 2, "UTF8_BINARY", "www||apache"); assertSubstringIndex("aaaaaaaaaa", "aa", 2, "UTF8_BINARY", "a"); @@ -1182,8 +2602,10 @@ public void testSubstringIndex() throws SparkException { assertSubstringIndex("test大千世界X大千世界", "X", 1, "UNICODE_CI", "test大千世界"); assertSubstringIndex("test大千世界大千世界", "千", 2, "UNICODE_CI", "test大千世界大"); assertSubstringIndex("www||APACHE||org", "||", 2, "UNICODE_CI", "www||APACHE"); - assertSubstringIndex("abİo12", "i̇o", 1, "UNICODE_CI", "ab"); - assertSubstringIndex("abİo12", "i̇o", -1, "UNICODE_CI", "12"); + assertSubstringIndex("wwwèapacheËorg", "Ê", -3, "AF_CI_AI", "apacheËorg"); + // One-to-many case mapping (e.g. Turkish dotted I). + assertSubstringIndex("abİo12", "i\u0307o", 1, "UNICODE_CI", "ab"); + assertSubstringIndex("abİo12", "i\u0307o", -1, "UNICODE_CI", "12"); assertSubstringIndex("abi̇o12", "İo", 1, "UNICODE_CI", "ab"); assertSubstringIndex("abi̇o12", "İo", -1, "UNICODE_CI", "12"); assertSubstringIndex("ai̇bi̇o12", "İo", 1, "UNICODE_CI", "ai̇b"); @@ -1191,59 +2613,153 @@ public void testSubstringIndex() throws SparkException { assertSubstringIndex("ai̇bi̇o12i̇o", "İo", -1, "UNICODE_CI", ""); assertSubstringIndex("ai̇bi̇o12i̇o", "İo", -2, "UNICODE_CI", "12i̇o"); assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UNICODE_CI", "İo12İoi̇o"); - assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UNICODE_CI", "İo12İoi̇o"); + assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i\u0307o", -4, "UNICODE_CI", "İo12İoi̇o"); assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UNICODE_CI", "i̇o12i̇oİo"); - assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo"); + assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i\u0307o", -4, "UNICODE_CI", "i̇o12i̇oİo"); assertSubstringIndex("abi̇12", "i", 1, "UNICODE_CI", "abi̇12"); assertSubstringIndex("abi̇12", "\u0307", 1, "UNICODE_CI", "abi̇12"); assertSubstringIndex("abi̇12", "İ", 1, "UNICODE_CI", "ab"); assertSubstringIndex("abİ12", "i", 1, "UNICODE_CI", "abİ12"); assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UNICODE_CI", "İo12İoi̇o"); - assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UNICODE_CI", "İo12İoi̇o"); + assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i\u0307o", -4, "UNICODE_CI", "İo12İoi̇o"); assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UNICODE_CI", "i̇o12i̇oİo"); - assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo"); + assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i\u0307o", -4, "UNICODE_CI", "i̇o12i̇oİo"); assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", 3, "UNICODE_CI", "ai̇bi̇oİo12"); - assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", 3, "UNICODE_CI", "ai̇bi̇oİo12"); + assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i\u0307o", 3, "UNICODE_CI", "ai̇bi̇oİo12"); assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", 3, "UNICODE_CI", "ai̇bİoi̇o12"); - assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", 3, "UNICODE_CI", "ai̇bİoi̇o12"); + assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i\u0307o", 3, "UNICODE_CI", "ai̇bİoi̇o12"); assertSubstringIndex("abi̇12", "i", 1, "UTF8_LCASE", "ab"); // != UNICODE_CI assertSubstringIndex("abi̇12", "\u0307", 1, "UTF8_LCASE", "abi"); // != UNICODE_CI assertSubstringIndex("abi̇12", "İ", 1, "UTF8_LCASE", "ab"); assertSubstringIndex("abİ12", "i", 1, "UTF8_LCASE", "abİ12"); assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UTF8_LCASE", "İo12İoi̇o"); - assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UTF8_LCASE", "İo12İoi̇o"); + 
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i\u0307o", -4, "UTF8_LCASE", "İo12İoi̇o"); assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UTF8_LCASE", "i̇o12i̇oİo"); - assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UTF8_LCASE", "i̇o12i̇oİo"); + assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i\u0307o", -4, "UTF8_LCASE", "i̇o12i̇oİo"); assertSubstringIndex("bİoi̇o12i̇o", "\u0307oi", 1, "UTF8_LCASE", "bİoi̇o12i̇o"); assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", 3, "UTF8_LCASE", "ai̇bi̇oİo12"); - assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", 3, "UTF8_LCASE", "ai̇bi̇oİo12"); + assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i\u0307o", 3, "UTF8_LCASE", "ai̇bi̇oİo12"); assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", 3, "UTF8_LCASE", "ai̇bİoi̇o12"); - assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", 3, "UTF8_LCASE", "ai̇bİoi̇o12"); + assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i\u0307o", 3, "UTF8_LCASE", "ai̇bİoi̇o12"); assertSubstringIndex("bİoi̇o12i̇o", "\u0307oi", 1, "UTF8_LCASE", "bİoi̇o12i̇o"); + // Conditional case mapping (e.g. Greek sigmas). + assertSubstringIndex("σ", "σ", 1, "UTF8_BINARY", ""); + assertSubstringIndex("σ", "ς", 1, "UTF8_BINARY", "σ"); + assertSubstringIndex("σ", "Σ", 1, "UTF8_BINARY", "σ"); + assertSubstringIndex("ς", "σ", 1, "UTF8_BINARY", "ς"); + assertSubstringIndex("ς", "ς", 1, "UTF8_BINARY", ""); + assertSubstringIndex("ς", "Σ", 1, "UTF8_BINARY", "ς"); + assertSubstringIndex("Σ", "σ", 1, "UTF8_BINARY", "Σ"); + assertSubstringIndex("Σ", "ς", 1, "UTF8_BINARY", "Σ"); + assertSubstringIndex("Σ", "Σ", 1, "UTF8_BINARY", ""); + assertSubstringIndex("σ", "σ", 1, "UTF8_LCASE", ""); + assertSubstringIndex("σ", "ς", 1, "UTF8_LCASE", ""); + assertSubstringIndex("σ", "Σ", 1, "UTF8_LCASE", ""); + assertSubstringIndex("ς", "σ", 1, "UTF8_LCASE", ""); + assertSubstringIndex("ς", "ς", 1, "UTF8_LCASE", ""); + assertSubstringIndex("ς", "Σ", 1, "UTF8_LCASE", ""); + assertSubstringIndex("Σ", "σ", 1, "UTF8_LCASE", ""); + assertSubstringIndex("Σ", "ς", 1, "UTF8_LCASE", ""); + assertSubstringIndex("Σ", "Σ", 1, "UTF8_LCASE", ""); + assertSubstringIndex("σ", "σ", 1, "UNICODE", ""); + assertSubstringIndex("σ", "ς", 1, "UNICODE", "σ"); + assertSubstringIndex("σ", "Σ", 1, "UNICODE", "σ"); + assertSubstringIndex("ς", "σ", 1, "UNICODE", "ς"); + assertSubstringIndex("ς", "ς", 1, "UNICODE", ""); + assertSubstringIndex("ς", "Σ", 1, "UNICODE", "ς"); + assertSubstringIndex("Σ", "σ", 1, "UNICODE", "Σ"); + assertSubstringIndex("Σ", "ς", 1, "UNICODE", "Σ"); + assertSubstringIndex("Σ", "Σ", 1, "UNICODE", ""); + assertSubstringIndex("σ", "σ", 1, "UNICODE_CI", ""); + assertSubstringIndex("σ", "ς", 1, "UNICODE_CI", ""); + assertSubstringIndex("σ", "Σ", 1, "UNICODE_CI", ""); + assertSubstringIndex("ς", "σ", 1, "UNICODE_CI", ""); + assertSubstringIndex("ς", "ς", 1, "UNICODE_CI", ""); + assertSubstringIndex("ς", "Σ", 1, "UNICODE_CI", ""); + assertSubstringIndex("Σ", "σ", 1, "UNICODE_CI", ""); + assertSubstringIndex("Σ", "ς", 1, "UNICODE_CI", ""); + assertSubstringIndex("Σ", "Σ", 1, "UNICODE_CI", ""); + // Surrogate pairs. 
+ assertSubstringIndex("a🙃b🙃c", "a", 1, "UTF8_BINARY", ""); + assertSubstringIndex("a🙃b🙃c", "a", 1, "UTF8_LCASE", ""); + assertSubstringIndex("a🙃b🙃c", "a", 1, "UNICODE", ""); + assertSubstringIndex("a🙃b🙃c", "a", 1, "UNICODE_CI", ""); + assertSubstringIndex("a🙃b🙃c", "🙃", 1, "UTF8_BINARY", "a"); + assertSubstringIndex("a🙃b🙃c", "🙃", 1, "UTF8_LCASE", "a"); + assertSubstringIndex("a🙃b🙃c", "🙃", 1, "UNICODE", "a"); + assertSubstringIndex("a🙃b🙃c", "🙃", 1, "UNICODE_CI", "a"); + assertSubstringIndex("a🙃b🙃c", "b", 1, "UTF8_BINARY", "a🙃"); + assertSubstringIndex("a🙃b🙃c", "b", 1, "UTF8_LCASE", "a🙃"); + assertSubstringIndex("a🙃b🙃c", "b", 1, "UNICODE", "a🙃"); + assertSubstringIndex("a🙃b🙃c", "b", 1, "UNICODE_CI", "a🙃"); + assertSubstringIndex("a🙃b🙃c", "🙃", 2, "UTF8_BINARY", "a🙃b"); + assertSubstringIndex("a🙃b🙃c", "🙃", 2, "UTF8_LCASE", "a🙃b"); + assertSubstringIndex("a🙃b🙃c", "🙃", 2, "UNICODE", "a🙃b"); + assertSubstringIndex("a🙃b🙃c", "🙃", 2, "UNICODE_CI", "a🙃b"); + assertSubstringIndex("a🙃b🙃c", "c", 1, "UTF8_BINARY", "a🙃b🙃"); + assertSubstringIndex("a🙃b🙃c", "c", 1, "UTF8_LCASE", "a🙃b🙃"); + assertSubstringIndex("a🙃b🙃c", "c", 1, "UNICODE", "a🙃b🙃"); + assertSubstringIndex("a🙃b🙃c", "c", 1, "UNICODE_CI", "a🙃b🙃"); + assertSubstringIndex("a🙃b🙃c", "🙃", 3, "UTF8_BINARY", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", 3, "UTF8_LCASE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", 3, "UNICODE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", 3, "UNICODE_CI", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", 1, "UTF8_BINARY", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", 1, "UTF8_LCASE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", 1, "UNICODE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", 1, "UNICODE_CI", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "c", -1, "UTF8_BINARY", ""); + assertSubstringIndex("a🙃b🙃c", "c", -1, "UTF8_LCASE", ""); + assertSubstringIndex("a🙃b🙃c", "c", -1, "UNICODE", ""); + assertSubstringIndex("a🙃b🙃c", "c", -1, "UNICODE_CI", ""); + assertSubstringIndex("a🙃b🙃c", "🙃", -1, "UTF8_BINARY", "c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -1, "UTF8_LCASE", "c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -1, "UNICODE", "c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -1, "UNICODE_CI", "c"); + assertSubstringIndex("a🙃b🙃c", "b", -1, "UTF8_BINARY", "🙃c"); + assertSubstringIndex("a🙃b🙃c", "b", -1, "UTF8_LCASE", "🙃c"); + assertSubstringIndex("a🙃b🙃c", "b", -1, "UNICODE", "🙃c"); + assertSubstringIndex("a🙃b🙃c", "b", -1, "UNICODE_CI", "🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -2, "UTF8_BINARY", "b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -2, "UTF8_LCASE", "b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -2, "UNICODE", "b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -2, "UNICODE_CI", "b🙃c"); + assertSubstringIndex("a🙃b🙃c", "a", -1, "UTF8_BINARY", "🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "a", -1, "UTF8_LCASE", "🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "a", -1, "UNICODE", "🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "a", -1, "UNICODE_CI", "🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -3, "UTF8_BINARY", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -3, "UTF8_LCASE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -3, "UNICODE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "🙃", -3, "UNICODE_CI", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", -1, "UTF8_BINARY", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", -1, "UTF8_LCASE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", -1, "UNICODE", "a🙃b🙃c"); + assertSubstringIndex("a🙃b🙃c", "d", -1, "UNICODE_CI", "a🙃b🙃c"); } - private void assertStringTrim( - String collation, - String sourceString, - String 
trimString, - String expectedResultString) throws SparkException { + /** + * Verify the behaviour of the `StringTrim` collation support class. + */ + + private void assertStringTrim(String collationName, String sourceString, String trimString, + String expected) throws SparkException { // Prepare the input and expected result. - int collationId = CollationFactory.collationNameToId(collation); + int collationId = CollationFactory.collationNameToId(collationName); UTF8String src = UTF8String.fromString(sourceString); UTF8String trim = UTF8String.fromString(trimString); - UTF8String resultTrimLeftRight, resultTrimRightLeft; - String resultTrim; + UTF8String result, resultTrimLeftRight, resultTrimRightLeft; if (trimString == null) { // Trim string is ASCII space. - resultTrim = CollationSupport.StringTrim.exec(src).toString(); + result = CollationSupport.StringTrim.exec(src); UTF8String trimLeft = CollationSupport.StringTrimLeft.exec(src); resultTrimLeftRight = CollationSupport.StringTrimRight.exec(trimLeft); UTF8String trimRight = CollationSupport.StringTrimRight.exec(src); resultTrimRightLeft = CollationSupport.StringTrimLeft.exec(trimRight); } else { // Trim string is specified. - resultTrim = CollationSupport.StringTrim.exec(src, trim, collationId).toString(); + result = CollationSupport.StringTrim.exec(src, trim, collationId); UTF8String trimLeft = CollationSupport.StringTrimLeft.exec(src, trim, collationId); resultTrimLeftRight = CollationSupport.StringTrimRight.exec(trimLeft, trim, collationId); UTF8String trimRight = CollationSupport.StringTrimRight.exec(src, trim, collationId); @@ -1251,61 +2767,15 @@ private void assertStringTrim( } // Test that StringTrim result is as expected. - assertEquals(expectedResultString, resultTrim); + assertEquals(UTF8String.fromString(expected), result); // Test that the order of the trims is not important. - assertEquals(resultTrimLeftRight.toString(), resultTrim); - assertEquals(resultTrimRightLeft.toString(), resultTrim); - } - - private void assertStringTrimLeft( - String collation, - String sourceString, - String trimString, - String expectedResultString) throws SparkException { - // Prepare the input and expected result. - int collationId = CollationFactory.collationNameToId(collation); - UTF8String src = UTF8String.fromString(sourceString); - UTF8String trim = UTF8String.fromString(trimString); - String result; - - if (trimString == null) { - // Trim string is ASCII space. - result = CollationSupport.StringTrimLeft.exec(src).toString(); - } else { - // Trim string is specified. - result = CollationSupport.StringTrimLeft.exec(src, trim, collationId).toString(); - } - - // Test that StringTrimLeft result is as expected. - assertEquals(expectedResultString, result); - } - - private void assertStringTrimRight( - String collation, - String sourceString, - String trimString, - String expectedResultString) throws SparkException { - // Prepare the input and expected result. - int collationId = CollationFactory.collationNameToId(collation); - UTF8String src = UTF8String.fromString(sourceString); - UTF8String trim = UTF8String.fromString(trimString); - String result; - - if (trimString == null) { - // Trim string is ASCII space. - result = CollationSupport.StringTrimRight.exec(src).toString(); - } else { - // Trim string is specified. - result = CollationSupport.StringTrimRight.exec(src, trim, collationId).toString(); - } - - // Test that StringTrimRight result is as expected. 
- assertEquals(expectedResultString, result); + assertEquals(resultTrimLeftRight, result); + assertEquals(resultTrimRightLeft, result); } @Test public void testStringTrim() throws SparkException { - // Basic tests - UTF8_BINARY. + // Basic tests. assertStringTrim("UTF8_BINARY", "", "", ""); assertStringTrim("UTF8_BINARY", "", "xyz", ""); assertStringTrim("UTF8_BINARY", "asd", "", "asd"); @@ -1315,25 +2785,6 @@ public void testStringTrim() throws SparkException { assertStringTrim("UTF8_BINARY", "asd", "x", "asd"); assertStringTrim("UTF8_BINARY", "xxasdxx", "x", "asd"); assertStringTrim("UTF8_BINARY", "xa世ax", "x", "a世a"); - assertStringTrimLeft("UTF8_BINARY", "", "", ""); - assertStringTrimLeft("UTF8_BINARY", "", "xyz", ""); - assertStringTrimLeft("UTF8_BINARY", "asd", "", "asd"); - assertStringTrimLeft("UTF8_BINARY", "asd", null, "asd"); - assertStringTrimLeft("UTF8_BINARY", " asd ", null, "asd "); - assertStringTrimLeft("UTF8_BINARY", " a世a ", null, "a世a "); - assertStringTrimLeft("UTF8_BINARY", "asd", "x", "asd"); - assertStringTrimLeft("UTF8_BINARY", "xxasdxx", "x", "asdxx"); - assertStringTrimLeft("UTF8_BINARY", "xa世ax", "x", "a世ax"); - assertStringTrimRight("UTF8_BINARY", "", "", ""); - assertStringTrimRight("UTF8_BINARY", "", "xyz", ""); - assertStringTrimRight("UTF8_BINARY", "asd", "", "asd"); - assertStringTrimRight("UTF8_BINARY", "asd", null, "asd"); - assertStringTrimRight("UTF8_BINARY", " asd ", null, " asd"); - assertStringTrimRight("UTF8_BINARY", " a世a ", null, " a世a"); - assertStringTrimRight("UTF8_BINARY", "asd", "x", "asd"); - assertStringTrimRight("UTF8_BINARY", "xxasdxx", "x", "xxasd"); - assertStringTrimRight("UTF8_BINARY", "xa世ax", "x", "xa世a"); - // Basic tests - UTF8_LCASE. assertStringTrim("UTF8_LCASE", "", "", ""); assertStringTrim("UTF8_LCASE", "", "xyz", ""); assertStringTrim("UTF8_LCASE", "asd", "", "asd"); @@ -1343,25 +2794,6 @@ public void testStringTrim() throws SparkException { assertStringTrim("UTF8_LCASE", "asd", "x", "asd"); assertStringTrim("UTF8_LCASE", "xxasdxx", "x", "asd"); assertStringTrim("UTF8_LCASE", "xa世ax", "x", "a世a"); - assertStringTrimLeft("UTF8_LCASE", "", "", ""); - assertStringTrimLeft("UTF8_LCASE", "", "xyz", ""); - assertStringTrimLeft("UTF8_LCASE", "asd", "", "asd"); - assertStringTrimLeft("UTF8_LCASE", "asd", null, "asd"); - assertStringTrimLeft("UTF8_LCASE", " asd ", null, "asd "); - assertStringTrimLeft("UTF8_LCASE", " a世a ", null, "a世a "); - assertStringTrimLeft("UTF8_LCASE", "asd", "x", "asd"); - assertStringTrimLeft("UTF8_LCASE", "xxasdxx", "x", "asdxx"); - assertStringTrimLeft("UTF8_LCASE", "xa世ax", "x", "a世ax"); - assertStringTrimRight("UTF8_LCASE", "", "", ""); - assertStringTrimRight("UTF8_LCASE", "", "xyz", ""); - assertStringTrimRight("UTF8_LCASE", "asd", "", "asd"); - assertStringTrimRight("UTF8_LCASE", "asd", null, "asd"); - assertStringTrimRight("UTF8_LCASE", " asd ", null, " asd"); - assertStringTrimRight("UTF8_LCASE", " a世a ", null, " a世a"); - assertStringTrimRight("UTF8_LCASE", "asd", "x", "asd"); - assertStringTrimRight("UTF8_LCASE", "xxasdxx", "x", "xxasd"); - assertStringTrimRight("UTF8_LCASE", "xa世ax", "x", "xa世a"); - // Basic tests - UNICODE. 
assertStringTrim("UNICODE", "", "", ""); assertStringTrim("UNICODE", "", "xyz", ""); assertStringTrim("UNICODE", "asd", "", "asd"); @@ -1371,25 +2803,6 @@ public void testStringTrim() throws SparkException { assertStringTrim("UNICODE", "asd", "x", "asd"); assertStringTrim("UNICODE", "xxasdxx", "x", "asd"); assertStringTrim("UNICODE", "xa世ax", "x", "a世a"); - assertStringTrimLeft("UNICODE", "", "", ""); - assertStringTrimLeft("UNICODE", "", "xyz", ""); - assertStringTrimLeft("UNICODE", "asd", "", "asd"); - assertStringTrimLeft("UNICODE", "asd", null, "asd"); - assertStringTrimLeft("UNICODE", " asd ", null, "asd "); - assertStringTrimLeft("UNICODE", " a世a ", null, "a世a "); - assertStringTrimLeft("UNICODE", "asd", "x", "asd"); - assertStringTrimLeft("UNICODE", "xxasdxx", "x", "asdxx"); - assertStringTrimLeft("UNICODE", "xa世ax", "x", "a世ax"); - assertStringTrimRight("UNICODE", "", "", ""); - assertStringTrimRight("UNICODE", "", "xyz", ""); - assertStringTrimRight("UNICODE", "asd", "", "asd"); - assertStringTrimRight("UNICODE", "asd", null, "asd"); - assertStringTrimRight("UNICODE", " asd ", null, " asd"); - assertStringTrimRight("UNICODE", " a世a ", null, " a世a"); - assertStringTrimRight("UNICODE", "asd", "x", "asd"); - assertStringTrimRight("UNICODE", "xxasdxx", "x", "xxasd"); - assertStringTrimRight("UNICODE", "xa世ax", "x", "xa世a"); - // Basic tests - UNICODE_CI. assertStringTrim("UNICODE_CI", "", "", ""); assertStringTrim("UNICODE_CI", "", "xyz", ""); assertStringTrim("UNICODE_CI", "asd", "", "asd"); @@ -1399,98 +2812,44 @@ public void testStringTrim() throws SparkException { assertStringTrim("UNICODE_CI", "asd", "x", "asd"); assertStringTrim("UNICODE_CI", "xxasdxx", "x", "asd"); assertStringTrim("UNICODE_CI", "xa世ax", "x", "a世a"); - assertStringTrimLeft("UNICODE_CI", "", "", ""); - assertStringTrimLeft("UNICODE_CI", "", "xyz", ""); - assertStringTrimLeft("UNICODE_CI", "asd", "", "asd"); - assertStringTrimLeft("UNICODE_CI", "asd", null, "asd"); - assertStringTrimLeft("UNICODE_CI", " asd ", null, "asd "); - assertStringTrimLeft("UNICODE_CI", " a世a ", null, "a世a "); - assertStringTrimLeft("UNICODE_CI", "asd", "x", "asd"); - assertStringTrimLeft("UNICODE_CI", "xxasdxx", "x", "asdxx"); - assertStringTrimLeft("UNICODE_CI", "xa世ax", "x", "a世ax"); - assertStringTrimRight("UNICODE_CI", "", "", ""); - assertStringTrimRight("UNICODE_CI", "", "xyz", ""); - assertStringTrimRight("UNICODE_CI", "asd", "", "asd"); - assertStringTrimRight("UNICODE_CI", "asd", null, "asd"); - assertStringTrimRight("UNICODE_CI", " asd ", null, " asd"); - assertStringTrimRight("UNICODE_CI", " a世a ", null, " a世a"); - assertStringTrimRight("UNICODE_CI", "asd", "x", "asd"); - assertStringTrimRight("UNICODE_CI", "xxasdxx", "x", "xxasd"); - assertStringTrimRight("UNICODE_CI", "xa世ax", "x", "xa世a"); - - // Case variation - UTF8_BINARY. + // Case variation. assertStringTrim("UTF8_BINARY", "asd", "A", "asd"); assertStringTrim("UTF8_BINARY", "ddsXXXaa", "asd", "XXX"); assertStringTrim("UTF8_BINARY", "ASD", "a", "ASD"); - assertStringTrimLeft("UTF8_BINARY", "ddsXXXaa", "asd", "XXXaa"); - assertStringTrimRight("UTF8_BINARY", "ddsXXXaa", "asd", "ddsXXX"); - // Case variation - UTF8_LCASE. assertStringTrim("UTF8_LCASE", "asd", "A", "sd"); assertStringTrim("UTF8_LCASE", "ASD", "a", "SD"); assertStringTrim("UTF8_LCASE", "ddsXXXaa", "ASD", "XXX"); - assertStringTrimLeft("UTF8_LCASE", "ddsXXXaa", "aSd", "XXXaa"); - assertStringTrimRight("UTF8_LCASE", "ddsXXXaa", "AsD", "ddsXXX"); - // Case variation - UNICODE. 
assertStringTrim("UNICODE", "asd", "A", "asd"); assertStringTrim("UNICODE", "ASD", "a", "ASD"); assertStringTrim("UNICODE", "ddsXXXaa", "asd", "XXX"); - assertStringTrimLeft("UNICODE", "ddsXXXaa", "asd", "XXXaa"); - assertStringTrimRight("UNICODE", "ddsXXXaa", "asd", "ddsXXX"); - // Case variation - UNICODE_CI. assertStringTrim("UNICODE_CI", "asd", "A", "sd"); assertStringTrim("UNICODE_CI", "ASD", "a", "SD"); assertStringTrim("UNICODE_CI", "ddsXXXaa", "ASD", "XXX"); - assertStringTrimLeft("UNICODE_CI", "ddsXXXaa", "aSd", "XXXaa"); - assertStringTrimRight("UNICODE_CI", "ddsXXXaa", "AsD", "ddsXXX"); - - // Case-variable character length - UTF8_BINARY. + assertStringTrim("SR_CI_AI", "cSCšćČXXXsčšČŠsć", "čš", "XXX"); + // One-to-many case mapping (e.g. Turkish dotted I).. assertStringTrim("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); - assertStringTrimLeft("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); - assertStringTrimRight("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); assertStringTrim("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); - assertStringTrimLeft("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); - assertStringTrimRight("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); assertStringTrim("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaa"); - assertStringTrimLeft("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaaẞ"); - assertStringTrimRight("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "Ëaaa"); - // Case-variable character length - UTF8_LCASE. assertStringTrim("UTF8_LCASE", "ẞaaaẞ", "ß", "aaa"); - assertStringTrimLeft("UTF8_LCASE", "ẞaaaẞ", "ß", "aaaẞ"); - assertStringTrimRight("UTF8_LCASE", "ẞaaaẞ", "ß", "ẞaaa"); assertStringTrim("UTF8_LCASE", "ßaaaß", "ẞ", "aaa"); - assertStringTrimLeft("UTF8_LCASE", "ßaaaß", "ẞ", "aaaß"); - assertStringTrimRight("UTF8_LCASE", "ßaaaß", "ẞ", "ßaaa"); assertStringTrim("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "aaa"); - assertStringTrimLeft("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); - assertStringTrimRight("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); - // Case-variable character length - UNICODE. assertStringTrim("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); - assertStringTrimLeft("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); - assertStringTrimRight("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); assertStringTrim("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); - assertStringTrimLeft("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); - assertStringTrimRight("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); assertStringTrim("UNICODE", "Ëaaaẞ", "Ëẞ", "aaa"); - assertStringTrimLeft("UNICODE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); - assertStringTrimRight("UNICODE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); - // Case-variable character length - UNICODE_CI. assertStringTrim("UNICODE_CI", "ẞaaaẞ", "ß", "aaa"); - assertStringTrimLeft("UNICODE_CI", "ẞaaaẞ", "ß", "aaaẞ"); - assertStringTrimRight("UNICODE_CI", "ẞaaaẞ", "ß", "ẞaaa"); assertStringTrim("UNICODE_CI", "ßaaaß", "ẞ", "aaa"); - assertStringTrimLeft("UNICODE_CI", "ßaaaß", "ẞ", "aaaß"); - assertStringTrimRight("UNICODE_CI", "ßaaaß", "ẞ", "ßaaa"); assertStringTrim("UNICODE_CI", "Ëaaaẞ", "Ëẞ", "aaa"); - assertStringTrimLeft("UNICODE_CI", "Ëaaaẞ", "Ëẞ", "aaaẞ"); - assertStringTrimRight("UNICODE_CI", "Ëaaaẞ", "Ëẞ", "Ëaaa"); - - // One-to-many case mapping - UTF8_BINARY. + // One-to-many case mapping (e.g. Turkish dotted I). 
assertStringTrim("UTF8_BINARY", "i", "i", ""); assertStringTrim("UTF8_BINARY", "iii", "I", "iii"); assertStringTrim("UTF8_BINARY", "I", "iii", "I"); assertStringTrim("UTF8_BINARY", "ixi", "i", "x"); assertStringTrim("UTF8_BINARY", "i", "İ", "i"); assertStringTrim("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); + assertStringTrim("UTF8_BINARY", "ii\u0307", "İi", "\u0307"); + assertStringTrim("UTF8_BINARY", "iii\u0307", "İi", "\u0307"); + assertStringTrim("UTF8_BINARY", "iiii\u0307", "iİ", "\u0307"); + assertStringTrim("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "\u0307ii\u0307"); assertStringTrim("UTF8_BINARY", "i\u0307", "i", "\u0307"); assertStringTrim("UTF8_BINARY", "i\u0307", "\u0307", "i"); assertStringTrim("UTF8_BINARY", "i\u0307", "i\u0307", ""); @@ -1510,63 +2869,16 @@ public void testStringTrim() throws SparkException { assertStringTrim("UTF8_BINARY", "Ixİ", "i\u0307", "Ixİ"); assertStringTrim("UTF8_BINARY", "IXİ", "ix\u0307", "IXİ"); assertStringTrim("UTF8_BINARY", "xi\u0307", "\u0307IX", "xi"); - assertStringTrimLeft("UTF8_BINARY", "i", "i", ""); - assertStringTrimLeft("UTF8_BINARY", "iii", "I", "iii"); - assertStringTrimLeft("UTF8_BINARY", "I", "iii", "I"); - assertStringTrimLeft("UTF8_BINARY", "ixi", "i", "xi"); - assertStringTrimLeft("UTF8_BINARY", "i", "İ", "i"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i", "\u0307"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307", "\u0307", "i\u0307"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i\u0307", ""); - assertStringTrimLeft("UTF8_BINARY", "i\u0307i\u0307", "i\u0307", ""); - assertStringTrimLeft("UTF8_BINARY", "i\u0307\u0307", "i\u0307", ""); - assertStringTrimLeft("UTF8_BINARY", "i\u0307i", "i\u0307", ""); - assertStringTrimLeft("UTF8_BINARY", "i\u0307i", "İ", "i\u0307i"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307İ", "i\u0307", "İ"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307İ", "İ", "i\u0307İ"); - assertStringTrimLeft("UTF8_BINARY", "İ", "İ", ""); - assertStringTrimLeft("UTF8_BINARY", "IXi", "İ", "IXi"); - assertStringTrimLeft("UTF8_BINARY", "ix\u0307", "Ixİ", "ix\u0307"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307x", "IXİ", "i\u0307x"); - assertStringTrimLeft("UTF8_BINARY", "i\u0307x", "ix\u0307İ", ""); - assertStringTrimLeft("UTF8_BINARY", "İ", "i", "İ"); - assertStringTrimLeft("UTF8_BINARY", "İ", "\u0307", "İ"); - assertStringTrimLeft("UTF8_BINARY", "Ixİ", "i\u0307", "Ixİ"); - assertStringTrimLeft("UTF8_BINARY", "IXİ", "ix\u0307", "IXİ"); - assertStringTrimLeft("UTF8_BINARY", "xi\u0307", "\u0307IX", "xi\u0307"); - assertStringTrimRight("UTF8_BINARY", "i", "i", ""); - assertStringTrimRight("UTF8_BINARY", "iii", "I", "iii"); - assertStringTrimRight("UTF8_BINARY", "I", "iii", "I"); - assertStringTrimRight("UTF8_BINARY", "ixi", "i", "ix"); - assertStringTrimRight("UTF8_BINARY", "i", "İ", "i"); - assertStringTrimRight("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); - assertStringTrimRight("UTF8_BINARY", "i\u0307", "i", "i\u0307"); - assertStringTrimRight("UTF8_BINARY", "i\u0307", "\u0307", "i"); - assertStringTrimRight("UTF8_BINARY", "i\u0307", "i\u0307", ""); - assertStringTrimRight("UTF8_BINARY", "i\u0307i\u0307", "i\u0307", ""); - assertStringTrimRight("UTF8_BINARY", "i\u0307\u0307", "i\u0307", ""); - assertStringTrimRight("UTF8_BINARY", "i\u0307i", "i\u0307", ""); - assertStringTrimRight("UTF8_BINARY", "i\u0307i", "İ", "i\u0307i"); - assertStringTrimRight("UTF8_BINARY", "i\u0307İ", "i\u0307", "i\u0307İ"); - assertStringTrimRight("UTF8_BINARY", 
"i\u0307İ", "İ", "i\u0307"); - assertStringTrimRight("UTF8_BINARY", "İ", "İ", ""); - assertStringTrimRight("UTF8_BINARY", "IXi", "İ", "IXi"); - assertStringTrimRight("UTF8_BINARY", "ix\u0307", "Ixİ", "ix\u0307"); - assertStringTrimRight("UTF8_BINARY", "i\u0307x", "IXİ", "i\u0307x"); - assertStringTrimRight("UTF8_BINARY", "i\u0307x", "ix\u0307İ", ""); - assertStringTrimRight("UTF8_BINARY", "İ", "i", "İ"); - assertStringTrimRight("UTF8_BINARY", "İ", "\u0307", "İ"); - assertStringTrimRight("UTF8_BINARY", "Ixİ", "i\u0307", "Ixİ"); - assertStringTrimRight("UTF8_BINARY", "IXİ", "ix\u0307", "IXİ"); - assertStringTrimRight("UTF8_BINARY", "xi\u0307", "\u0307IX", "xi"); - // One-to-many case mapping - UTF8_LCASE. assertStringTrim("UTF8_LCASE", "i", "i", ""); assertStringTrim("UTF8_LCASE", "iii", "I", ""); assertStringTrim("UTF8_LCASE", "I", "iii", ""); assertStringTrim("UTF8_LCASE", "ixi", "i", "x"); assertStringTrim("UTF8_LCASE", "i", "İ", "i"); assertStringTrim("UTF8_LCASE", "i\u0307", "İ", ""); + assertStringTrim("UTF8_LCASE", "ii\u0307", "İi", ""); + assertStringTrim("UTF8_LCASE", "iii\u0307", "İi", ""); + assertStringTrim("UTF8_LCASE", "iiii\u0307", "iİ", ""); + assertStringTrim("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", ""); assertStringTrim("UTF8_LCASE", "i\u0307", "i", "\u0307"); assertStringTrim("UTF8_LCASE", "i\u0307", "\u0307", "i"); assertStringTrim("UTF8_LCASE", "i\u0307", "i\u0307", ""); @@ -1586,63 +2898,16 @@ public void testStringTrim() throws SparkException { assertStringTrim("UTF8_LCASE", "Ixİ", "i\u0307", "xİ"); assertStringTrim("UTF8_LCASE", "IXİ", "ix\u0307", "İ"); assertStringTrim("UTF8_LCASE", "xi\u0307", "\u0307IX", ""); - assertStringTrimLeft("UTF8_LCASE", "i", "i", ""); - assertStringTrimLeft("UTF8_LCASE", "iii", "I", ""); - assertStringTrimLeft("UTF8_LCASE", "I", "iii", ""); - assertStringTrimLeft("UTF8_LCASE", "ixi", "i", "xi"); - assertStringTrimLeft("UTF8_LCASE", "i", "İ", "i"); - assertStringTrimLeft("UTF8_LCASE", "i\u0307", "İ", ""); - assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i", "\u0307"); - assertStringTrimLeft("UTF8_LCASE", "i\u0307", "\u0307", "i\u0307"); - assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i\u0307", ""); - assertStringTrimLeft("UTF8_LCASE", "i\u0307i\u0307", "i\u0307", ""); - assertStringTrimLeft("UTF8_LCASE", "i\u0307\u0307", "i\u0307", ""); - assertStringTrimLeft("UTF8_LCASE", "i\u0307i", "i\u0307", ""); - assertStringTrimLeft("UTF8_LCASE", "i\u0307i", "İ", "i"); - assertStringTrimLeft("UTF8_LCASE", "i\u0307İ", "i\u0307", "İ"); - assertStringTrimLeft("UTF8_LCASE", "i\u0307İ", "İ", ""); - assertStringTrimLeft("UTF8_LCASE", "İ", "İ", ""); - assertStringTrimLeft("UTF8_LCASE", "IXi", "İ", "IXi"); - assertStringTrimLeft("UTF8_LCASE", "ix\u0307", "Ixİ", "\u0307"); - assertStringTrimLeft("UTF8_LCASE", "i\u0307x", "IXİ", ""); - assertStringTrimLeft("UTF8_LCASE", "i\u0307x", "I\u0307xİ", ""); - assertStringTrimLeft("UTF8_LCASE", "İ", "i", "İ"); - assertStringTrimLeft("UTF8_LCASE", "İ", "\u0307", "İ"); - assertStringTrimLeft("UTF8_LCASE", "Ixİ", "i\u0307", "xİ"); - assertStringTrimLeft("UTF8_LCASE", "IXİ", "ix\u0307", "İ"); - assertStringTrimLeft("UTF8_LCASE", "xi\u0307", "\u0307IX", ""); - assertStringTrimRight("UTF8_LCASE", "i", "i", ""); - assertStringTrimRight("UTF8_LCASE", "iii", "I", ""); - assertStringTrimRight("UTF8_LCASE", "I", "iii", ""); - assertStringTrimRight("UTF8_LCASE", "ixi", "i", "ix"); - assertStringTrimRight("UTF8_LCASE", "i", "İ", "i"); - assertStringTrimRight("UTF8_LCASE", "i\u0307", "İ", ""); - assertStringTrimRight("UTF8_LCASE", 
"i\u0307", "i", "i\u0307"); - assertStringTrimRight("UTF8_LCASE", "i\u0307", "\u0307", "i"); - assertStringTrimRight("UTF8_LCASE", "i\u0307", "i\u0307", ""); - assertStringTrimRight("UTF8_LCASE", "i\u0307i\u0307", "i\u0307", ""); - assertStringTrimRight("UTF8_LCASE", "i\u0307\u0307", "i\u0307", ""); - assertStringTrimRight("UTF8_LCASE", "i\u0307i", "i\u0307", ""); - assertStringTrimRight("UTF8_LCASE", "i\u0307i", "İ", "i\u0307i"); - assertStringTrimRight("UTF8_LCASE", "i\u0307İ", "i\u0307", "i\u0307İ"); - assertStringTrimRight("UTF8_LCASE", "i\u0307İ", "İ", ""); - assertStringTrimRight("UTF8_LCASE", "İ", "İ", ""); - assertStringTrimRight("UTF8_LCASE", "IXi", "İ", "IXi"); - assertStringTrimRight("UTF8_LCASE", "ix\u0307", "Ixİ", "ix\u0307"); - assertStringTrimRight("UTF8_LCASE", "i\u0307x", "IXİ", ""); - assertStringTrimRight("UTF8_LCASE", "i\u0307x", "I\u0307xİ", ""); - assertStringTrimRight("UTF8_LCASE", "İ", "i", "İ"); - assertStringTrimRight("UTF8_LCASE", "İ", "\u0307", "İ"); - assertStringTrimRight("UTF8_LCASE", "Ixİ", "i\u0307", "Ixİ"); - assertStringTrimRight("UTF8_LCASE", "IXİ", "ix\u0307", "IXİ"); - assertStringTrimRight("UTF8_LCASE", "xi\u0307", "\u0307IX", ""); - // One-to-many case mapping - UNICODE. assertStringTrim("UNICODE", "i", "i", ""); assertStringTrim("UNICODE", "iii", "I", "iii"); assertStringTrim("UNICODE", "I", "iii", "I"); assertStringTrim("UNICODE", "ixi", "i", "x"); assertStringTrim("UNICODE", "i", "İ", "i"); assertStringTrim("UNICODE", "i\u0307", "İ", "i\u0307"); + assertStringTrim("UNICODE", "ii\u0307", "İi", "i\u0307"); + assertStringTrim("UNICODE", "iii\u0307", "İi", "i\u0307"); + assertStringTrim("UNICODE", "iiii\u0307", "iİ", "i\u0307"); + assertStringTrim("UNICODE", "ii\u0307ii\u0307", "iİ", "i\u0307ii\u0307"); assertStringTrim("UNICODE", "i\u0307", "i", "i\u0307"); assertStringTrim("UNICODE", "i\u0307", "\u0307", "i\u0307"); assertStringTrim("UNICODE", "i\u0307", "i\u0307", "i\u0307"); @@ -1663,65 +2928,16 @@ public void testStringTrim() throws SparkException { assertStringTrim("UNICODE", "Ixİ", "i\u0307", "Ixİ"); assertStringTrim("UNICODE", "IXİ", "ix\u0307", "IXİ"); assertStringTrim("UNICODE", "xi\u0307", "\u0307IX", "xi\u0307"); - assertStringTrimLeft("UNICODE", "i", "i", ""); - assertStringTrimLeft("UNICODE", "iii", "I", "iii"); - assertStringTrimLeft("UNICODE", "I", "iii", "I"); - assertStringTrimLeft("UNICODE", "ixi", "i", "xi"); - assertStringTrimLeft("UNICODE", "i", "İ", "i"); - assertStringTrimLeft("UNICODE", "i\u0307", "İ", "i\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307", "i", "i\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307", "\u0307", "i\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307", "i\u0307", "i\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307i\u0307", "i\u0307", "i\u0307i\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307\u0307", "i\u0307", "i\u0307\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307i", "i\u0307", "i\u0307i"); - assertStringTrimLeft("UNICODE", "i\u0307i", "İ", "i\u0307i"); - assertStringTrimLeft("UNICODE", "i\u0307İ", "i\u0307", "i\u0307İ"); - assertStringTrimLeft("UNICODE", "i\u0307İ", "İ", "i\u0307İ"); - assertStringTrimLeft("UNICODE", "İ", "İ", ""); - assertStringTrimLeft("UNICODE", "IXi", "İ", "IXi"); - assertStringTrimLeft("UNICODE", "ix\u0307", "Ixİ", "ix\u0307"); - assertStringTrimLeft("UNICODE", "i\u0307x", "IXİ", "i\u0307x"); - assertStringTrimLeft("UNICODE", "i\u0307x", "ix\u0307İ", "i\u0307x"); - assertStringTrimLeft("UNICODE", "İ", "i", "İ"); - assertStringTrimLeft("UNICODE", "İ", "\u0307", 
"İ"); - assertStringTrimLeft("UNICODE", "i\u0307", "i\u0307", "i\u0307"); - assertStringTrimLeft("UNICODE", "Ixİ", "i\u0307", "Ixİ"); - assertStringTrimLeft("UNICODE", "IXİ", "ix\u0307", "IXİ"); - assertStringTrimLeft("UNICODE", "xi\u0307", "\u0307IX", "xi\u0307"); - assertStringTrimRight("UNICODE", "i", "i", ""); - assertStringTrimRight("UNICODE", "iii", "I", "iii"); - assertStringTrimRight("UNICODE", "I", "iii", "I"); - assertStringTrimRight("UNICODE", "ixi", "i", "ix"); - assertStringTrimRight("UNICODE", "i", "İ", "i"); - assertStringTrimRight("UNICODE", "i\u0307", "İ", "i\u0307"); - assertStringTrimRight("UNICODE", "i\u0307", "i", "i\u0307"); - assertStringTrimRight("UNICODE", "i\u0307", "\u0307", "i\u0307"); - assertStringTrimRight("UNICODE", "i\u0307", "i\u0307", "i\u0307"); - assertStringTrimRight("UNICODE", "i\u0307i\u0307", "i\u0307", "i\u0307i\u0307"); - assertStringTrimRight("UNICODE", "i\u0307\u0307", "i\u0307", "i\u0307\u0307"); - assertStringTrimRight("UNICODE", "i\u0307i", "i\u0307", "i\u0307"); - assertStringTrimRight("UNICODE", "i\u0307i", "İ", "i\u0307i"); - assertStringTrimRight("UNICODE", "i\u0307İ", "i\u0307", "i\u0307İ"); - assertStringTrimRight("UNICODE", "i\u0307İ", "İ", "i\u0307"); - assertStringTrimRight("UNICODE", "İ", "İ", ""); - assertStringTrimRight("UNICODE", "IXi", "İ", "IXi"); - assertStringTrimRight("UNICODE", "ix\u0307", "Ixİ", "ix\u0307"); - assertStringTrimRight("UNICODE", "i\u0307x", "IXİ", "i\u0307x"); - assertStringTrimRight("UNICODE", "i\u0307x", "ix\u0307İ", "i\u0307"); - assertStringTrimRight("UNICODE", "İ", "i", "İ"); - assertStringTrimRight("UNICODE", "İ", "\u0307", "İ"); - assertStringTrimRight("UNICODE", "i\u0307", "i\u0307", "i\u0307"); - assertStringTrimRight("UNICODE", "Ixİ", "i\u0307", "Ixİ"); - assertStringTrimRight("UNICODE", "IXİ", "ix\u0307", "IXİ"); - assertStringTrimRight("UNICODE", "xi\u0307", "\u0307IX", "xi\u0307"); - // One-to-many case mapping - UNICODE_CI. assertStringTrim("UNICODE_CI", "i", "i", ""); assertStringTrim("UNICODE_CI", "iii", "I", ""); assertStringTrim("UNICODE_CI", "I", "iii", ""); assertStringTrim("UNICODE_CI", "ixi", "i", "x"); assertStringTrim("UNICODE_CI", "i", "İ", "i"); assertStringTrim("UNICODE_CI", "i\u0307", "İ", ""); + assertStringTrim("UNICODE_CI", "ii\u0307", "İi", ""); + assertStringTrim("UNICODE_CI", "iii\u0307", "İi", ""); + assertStringTrim("UNICODE_CI", "iiii\u0307", "iİ", ""); + assertStringTrim("UNICODE_CI", "ii\u0307ii\u0307", "iİ", ""); assertStringTrim("UNICODE_CI", "i\u0307", "i", "i\u0307"); assertStringTrim("UNICODE_CI", "i\u0307", "\u0307", "i\u0307"); assertStringTrim("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307"); @@ -1742,12 +2958,282 @@ public void testStringTrim() throws SparkException { assertStringTrim("UNICODE_CI", "Ixİ", "i\u0307", "xİ"); assertStringTrim("UNICODE_CI", "IXİ", "ix\u0307", "İ"); assertStringTrim("UNICODE_CI", "xi\u0307", "\u0307IX", "i\u0307"); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertStringTrim("UTF8_BINARY", "ςxς", "σ", "ςxς"); + assertStringTrim("UTF8_BINARY", "ςxς", "ς", "x"); + assertStringTrim("UTF8_BINARY", "ςxς", "Σ", "ςxς"); + assertStringTrim("UTF8_BINARY", "σxσ", "σ", "x"); + assertStringTrim("UTF8_BINARY", "σxσ", "ς", "σxσ"); + assertStringTrim("UTF8_BINARY", "σxσ", "Σ", "σxσ"); + assertStringTrim("UTF8_BINARY", "ΣxΣ", "σ", "ΣxΣ"); + assertStringTrim("UTF8_BINARY", "ΣxΣ", "ς", "ΣxΣ"); + assertStringTrim("UTF8_BINARY", "ΣxΣ", "Σ", "x"); + assertStringTrim("UTF8_LCASE", "ςxς", "σ", "x"); + assertStringTrim("UTF8_LCASE", "ςxς", "ς", "x"); + assertStringTrim("UTF8_LCASE", "ςxς", "Σ", "x"); + assertStringTrim("UTF8_LCASE", "σxσ", "σ", "x"); + assertStringTrim("UTF8_LCASE", "σxσ", "ς", "x"); + assertStringTrim("UTF8_LCASE", "σxσ", "Σ", "x"); + assertStringTrim("UTF8_LCASE", "ΣxΣ", "σ", "x"); + assertStringTrim("UTF8_LCASE", "ΣxΣ", "ς", "x"); + assertStringTrim("UTF8_LCASE", "ΣxΣ", "Σ", "x"); + assertStringTrim("UNICODE", "ςxς", "σ", "ςxς"); + assertStringTrim("UNICODE", "ςxς", "ς", "x"); + assertStringTrim("UNICODE", "ςxς", "Σ", "ςxς"); + assertStringTrim("UNICODE", "σxσ", "σ", "x"); + assertStringTrim("UNICODE", "σxσ", "ς", "σxσ"); + assertStringTrim("UNICODE", "σxσ", "Σ", "σxσ"); + assertStringTrim("UNICODE", "ΣxΣ", "σ", "ΣxΣ"); + assertStringTrim("UNICODE", "ΣxΣ", "ς", "ΣxΣ"); + assertStringTrim("UNICODE", "ΣxΣ", "Σ", "x"); + assertStringTrim("UNICODE_CI", "ςxς", "σ", "x"); + assertStringTrim("UNICODE_CI", "ςxς", "ς", "x"); + assertStringTrim("UNICODE_CI", "ςxς", "Σ", "x"); + assertStringTrim("UNICODE_CI", "σxσ", "σ", "x"); + assertStringTrim("UNICODE_CI", "σxσ", "ς", "x"); + assertStringTrim("UNICODE_CI", "σxσ", "Σ", "x"); + assertStringTrim("UNICODE_CI", "ΣxΣ", "σ", "x"); + assertStringTrim("UNICODE_CI", "ΣxΣ", "ς", "x"); + assertStringTrim("UNICODE_CI", "ΣxΣ", "Σ", "x"); + // Unicode normalization. + assertStringTrim("UTF8_BINARY", "åβγδa\u030A", "å", "βγδa\u030A"); + assertStringTrim("UTF8_LCASE", "åβγδa\u030A", "Å", "βγδa\u030A"); + assertStringTrim("UNICODE", "åβγδa\u030A", "å", "βγδ"); + assertStringTrim("UNICODE_CI", "åβγδa\u030A", "Å", "βγδ"); + // Surrogate pairs. 
+ assertStringTrim("UTF8_BINARY", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrim("UTF8_LCASE", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrim("UNICODE", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrim("UNICODE_CI", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrim("UTF8_BINARY", "a🙃b🙃c", "ac", "🙃b🙃"); + assertStringTrim("UTF8_LCASE", "a🙃b🙃c", "ac", "🙃b🙃"); + assertStringTrim("UNICODE", "a🙃b🙃c", "ac", "🙃b🙃"); + assertStringTrim("UNICODE_CI", "a🙃b🙃c", "ac", "🙃b🙃"); + assertStringTrim("UTF8_BINARY", "a🙃b🙃c", "a🙃c", "b"); + assertStringTrim("UTF8_LCASE", "a🙃b🙃c", "a🙃c", "b"); + assertStringTrim("UNICODE", "a🙃b🙃c", "a🙃c", "b"); + assertStringTrim("UNICODE_CI", "a🙃b🙃c", "a🙃c", "b"); + assertStringTrim("UTF8_BINARY", "a🙃b🙃c", "abc🙃", ""); + assertStringTrim("UTF8_LCASE", "a🙃b🙃c", "abc🙃", ""); + assertStringTrim("UNICODE", "a🙃b🙃c", "abc🙃", ""); + assertStringTrim("UNICODE_CI", "a🙃b🙃c", "abc🙃", ""); + assertStringTrim("UTF8_BINARY", "😀😆😃😄", "😀😄", "😆😃"); + assertStringTrim("UTF8_LCASE", "😀😆😃😄", "😀😄", "😆😃"); + assertStringTrim("UNICODE", "😀😆😃😄", "😀😄", "😆😃"); + assertStringTrim("UNICODE_CI", "😀😆😃😄", "😀😄", "😆😃"); + assertStringTrim("UTF8_BINARY", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrim("UTF8_LCASE", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrim("UNICODE", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrim("UNICODE_CI", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrim("UTF8_BINARY", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrim("UTF8_LCASE", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrim("UNICODE", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrim("UNICODE_CI", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrim("UTF8_BINARY", "𐐅", "𐐅", ""); + assertStringTrim("UTF8_LCASE", "𐐅", "𐐅", ""); + assertStringTrim("UNICODE", "𐐅", "𐐅", ""); + assertStringTrim("UNICODE_CI", "𐐅", "𐐅", ""); + assertStringTrim("UTF8_BINARY", "𐐅", "𐐭", "𐐅"); + assertStringTrim("UTF8_LCASE", "𐐅", "𐐭", ""); + assertStringTrim("UNICODE", "𐐅", "𐐭", "𐐅"); + assertStringTrim("UNICODE_CI", "𐐅", "𐐭", ""); + assertStringTrim("UTF8_BINARY", "𝔸", "𝔸", ""); + assertStringTrim("UTF8_LCASE", "𝔸", "𝔸", ""); + assertStringTrim("UNICODE", "𝔸", "𝔸", ""); + assertStringTrim("UNICODE_CI", "𝔸", "𝔸", ""); + assertStringTrim("UTF8_BINARY", "𝔸", "A", "𝔸"); + assertStringTrim("UTF8_LCASE", "𝔸", "A", "𝔸"); + assertStringTrim("UNICODE", "𝔸", "A", "𝔸"); + assertStringTrim("UNICODE_CI", "𝔸", "A", ""); + assertStringTrim("UTF8_BINARY", "𝔸", "a", "𝔸"); + assertStringTrim("UTF8_LCASE", "𝔸", "a", "𝔸"); + assertStringTrim("UNICODE", "𝔸", "a", "𝔸"); + assertStringTrim("UNICODE_CI", "𝔸", "a", ""); + } + + /** + * Verify the behaviour of the `StringTrimLeft` collation support class. + */ + + private void assertStringTrimLeft(String collationName, String sourceString, String trimString, + String expected) throws SparkException { + // Prepare the input and expected result. + int collationId = CollationFactory.collationNameToId(collationName); + UTF8String src = UTF8String.fromString(sourceString); + UTF8String trim = UTF8String.fromString(trimString); + UTF8String result; + + if (trimString == null) { + // Trim string is ASCII space. + result = CollationSupport.StringTrimLeft.exec(src); + } else { + // Trim string is specified. + result = CollationSupport.StringTrimLeft.exec(src, trim, collationId); + } + + // Test that StringTrimLeft result is as expected. + assertEquals(UTF8String.fromString(expected), result); + } + + @Test + public void testStringTrimLeft() throws SparkException { + // Basic tests - UTF8_BINARY. 
+ assertStringTrimLeft("UTF8_BINARY", "", "", ""); + assertStringTrimLeft("UTF8_BINARY", "", "xyz", ""); + assertStringTrimLeft("UTF8_BINARY", "asd", "", "asd"); + assertStringTrimLeft("UTF8_BINARY", "asd", null, "asd"); + assertStringTrimLeft("UTF8_BINARY", " asd ", null, "asd "); + assertStringTrimLeft("UTF8_BINARY", " a世a ", null, "a世a "); + assertStringTrimLeft("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrimLeft("UTF8_BINARY", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UTF8_BINARY", "xa世ax", "x", "a世ax"); + // Basic tests - UTF8_LCASE. + assertStringTrimLeft("UTF8_LCASE", "", "", ""); + assertStringTrimLeft("UTF8_LCASE", "", "xyz", ""); + assertStringTrimLeft("UTF8_LCASE", "asd", "", "asd"); + assertStringTrimLeft("UTF8_LCASE", "asd", null, "asd"); + assertStringTrimLeft("UTF8_LCASE", " asd ", null, "asd "); + assertStringTrimLeft("UTF8_LCASE", " a世a ", null, "a世a "); + assertStringTrimLeft("UTF8_LCASE", "asd", "x", "asd"); + assertStringTrimLeft("UTF8_LCASE", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UTF8_LCASE", "xa世ax", "x", "a世ax"); + // Basic tests - UNICODE. + assertStringTrimLeft("UNICODE", "", "", ""); + assertStringTrimLeft("UNICODE", "", "xyz", ""); + assertStringTrimLeft("UNICODE", "asd", "", "asd"); + assertStringTrimLeft("UNICODE", "asd", null, "asd"); + assertStringTrimLeft("UNICODE", " asd ", null, "asd "); + assertStringTrimLeft("UNICODE", " a世a ", null, "a世a "); + assertStringTrimLeft("UNICODE", "asd", "x", "asd"); + assertStringTrimLeft("UNICODE", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UNICODE", "xa世ax", "x", "a世ax"); + // Basic tests - UNICODE_CI. + assertStringTrimLeft("UNICODE_CI", "", "", ""); + assertStringTrimLeft("UNICODE_CI", "", "xyz", ""); + assertStringTrimLeft("UNICODE_CI", "asd", "", "asd"); + assertStringTrimLeft("UNICODE_CI", "asd", null, "asd"); + assertStringTrimLeft("UNICODE_CI", " asd ", null, "asd "); + assertStringTrimLeft("UNICODE_CI", " a世a ", null, "a世a "); + assertStringTrimLeft("UNICODE_CI", "asd", "x", "asd"); + assertStringTrimLeft("UNICODE_CI", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UNICODE_CI", "xa世ax", "x", "a世ax"); + // Case variation. + assertStringTrimLeft("UTF8_BINARY", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimLeft("UTF8_LCASE", "ddsXXXaa", "aSd", "XXXaa"); + assertStringTrimLeft("UNICODE", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimLeft("UNICODE_CI", "ddsXXXaa", "aSd", "XXXaa"); + // One-to-many case mapping (e.g. Turkish dotted I).. + assertStringTrimLeft("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimLeft("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimLeft("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimLeft("UTF8_LCASE", "ẞaaaẞ", "ß", "aaaẞ"); + assertStringTrimLeft("UTF8_LCASE", "ßaaaß", "ẞ", "aaaß"); + assertStringTrimLeft("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimLeft("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimLeft("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimLeft("UNICODE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimLeft("UNICODE_CI", "ẞaaaẞ", "ß", "aaaẞ"); + assertStringTrimLeft("UNICODE_CI", "ßaaaß", "ẞ", "aaaß"); + assertStringTrimLeft("UNICODE_CI", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + // One-to-many case mapping (e.g. Turkish dotted I). 
+ assertStringTrimLeft("UTF8_BINARY", "i", "i", ""); + assertStringTrimLeft("UTF8_BINARY", "iii", "I", "iii"); + assertStringTrimLeft("UTF8_BINARY", "I", "iii", "I"); + assertStringTrimLeft("UTF8_BINARY", "ixi", "i", "xi"); + assertStringTrimLeft("UTF8_BINARY", "i", "İ", "i"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); + assertStringTrimLeft("UTF8_BINARY", "ii\u0307", "İi", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "iii\u0307", "İi", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "iiii\u0307", "iİ", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "\u0307ii\u0307"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307", "\u0307", "i\u0307"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i\u0307", ""); + assertStringTrimLeft("UTF8_BINARY", "i\u0307i\u0307", "i\u0307", ""); + assertStringTrimLeft("UTF8_BINARY", "i\u0307\u0307", "i\u0307", ""); + assertStringTrimLeft("UTF8_BINARY", "i\u0307i", "i\u0307", ""); + assertStringTrimLeft("UTF8_BINARY", "i\u0307i", "İ", "i\u0307i"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307İ", "i\u0307", "İ"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307İ", "İ", "i\u0307İ"); + assertStringTrimLeft("UTF8_BINARY", "İ", "İ", ""); + assertStringTrimLeft("UTF8_BINARY", "IXi", "İ", "IXi"); + assertStringTrimLeft("UTF8_BINARY", "ix\u0307", "Ixİ", "ix\u0307"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307x", "IXİ", "i\u0307x"); + assertStringTrimLeft("UTF8_BINARY", "i\u0307x", "ix\u0307İ", ""); + assertStringTrimLeft("UTF8_BINARY", "İ", "i", "İ"); + assertStringTrimLeft("UTF8_BINARY", "İ", "\u0307", "İ"); + assertStringTrimLeft("UTF8_BINARY", "Ixİ", "i\u0307", "Ixİ"); + assertStringTrimLeft("UTF8_BINARY", "IXİ", "ix\u0307", "IXİ"); + assertStringTrimLeft("UTF8_BINARY", "xi\u0307", "\u0307IX", "xi\u0307"); + assertStringTrimLeft("UTF8_LCASE", "i", "i", ""); + assertStringTrimLeft("UTF8_LCASE", "iii", "I", ""); + assertStringTrimLeft("UTF8_LCASE", "I", "iii", ""); + assertStringTrimLeft("UTF8_LCASE", "ixi", "i", "xi"); + assertStringTrimLeft("UTF8_LCASE", "i", "İ", "i"); + assertStringTrimLeft("UTF8_LCASE", "i\u0307", "İ", ""); + assertStringTrimLeft("UTF8_LCASE", "ii\u0307", "İi", ""); + assertStringTrimLeft("UTF8_LCASE", "iii\u0307", "İi", ""); + assertStringTrimLeft("UTF8_LCASE", "iiii\u0307", "iİ", ""); + assertStringTrimLeft("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", ""); + assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i", "\u0307"); + assertStringTrimLeft("UTF8_LCASE", "i\u0307", "\u0307", "i\u0307"); + assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i\u0307", ""); + assertStringTrimLeft("UTF8_LCASE", "i\u0307i\u0307", "i\u0307", ""); + assertStringTrimLeft("UTF8_LCASE", "i\u0307\u0307", "i\u0307", ""); + assertStringTrimLeft("UTF8_LCASE", "i\u0307i", "i\u0307", ""); + assertStringTrimLeft("UTF8_LCASE", "i\u0307i", "İ", "i"); + assertStringTrimLeft("UTF8_LCASE", "i\u0307İ", "i\u0307", "İ"); + assertStringTrimLeft("UTF8_LCASE", "i\u0307İ", "İ", ""); + assertStringTrimLeft("UTF8_LCASE", "İ", "İ", ""); + assertStringTrimLeft("UTF8_LCASE", "IXi", "İ", "IXi"); + assertStringTrimLeft("UTF8_LCASE", "ix\u0307", "Ixİ", "\u0307"); + assertStringTrimLeft("UTF8_LCASE", "i\u0307x", "IXİ", ""); + assertStringTrimLeft("UTF8_LCASE", "i\u0307x", "I\u0307xİ", ""); + assertStringTrimLeft("UTF8_LCASE", "İ", "i", "İ"); + assertStringTrimLeft("UTF8_LCASE", "İ", "\u0307", "İ"); + assertStringTrimLeft("UTF8_LCASE", "Ixİ", "i\u0307", "xİ"); + 
assertStringTrimLeft("UTF8_LCASE", "IXİ", "ix\u0307", "İ"); + assertStringTrimLeft("UTF8_LCASE", "xi\u0307", "\u0307IX", ""); + assertStringTrimLeft("UNICODE", "i", "i", ""); + assertStringTrimLeft("UNICODE", "iii", "I", "iii"); + assertStringTrimLeft("UNICODE", "I", "iii", "I"); + assertStringTrimLeft("UNICODE", "ixi", "i", "xi"); + assertStringTrimLeft("UNICODE", "i", "İ", "i"); + assertStringTrimLeft("UNICODE", "i\u0307", "İ", "i\u0307"); + assertStringTrimLeft("UNICODE", "ii\u0307", "İi", "i\u0307"); + assertStringTrimLeft("UNICODE", "iii\u0307", "İi", "i\u0307"); + assertStringTrimLeft("UNICODE", "iiii\u0307", "iİ", "i\u0307"); + assertStringTrimLeft("UNICODE", "ii\u0307ii\u0307", "iİ", "i\u0307ii\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307", "i", "i\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307", "\u0307", "i\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307", "i\u0307", "i\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307i\u0307", "i\u0307", "i\u0307i\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307\u0307", "i\u0307", "i\u0307\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307i", "i\u0307", "i\u0307i"); + assertStringTrimLeft("UNICODE", "i\u0307i", "İ", "i\u0307i"); + assertStringTrimLeft("UNICODE", "i\u0307İ", "i\u0307", "i\u0307İ"); + assertStringTrimLeft("UNICODE", "i\u0307İ", "İ", "i\u0307İ"); + assertStringTrimLeft("UNICODE", "İ", "İ", ""); + assertStringTrimLeft("UNICODE", "IXi", "İ", "IXi"); + assertStringTrimLeft("UNICODE", "ix\u0307", "Ixİ", "ix\u0307"); + assertStringTrimLeft("UNICODE", "i\u0307x", "IXİ", "i\u0307x"); + assertStringTrimLeft("UNICODE", "i\u0307x", "ix\u0307İ", "i\u0307x"); + assertStringTrimLeft("UNICODE", "İ", "i", "İ"); + assertStringTrimLeft("UNICODE", "İ", "\u0307", "İ"); + assertStringTrimLeft("UNICODE", "i\u0307", "i\u0307", "i\u0307"); + assertStringTrimLeft("UNICODE", "Ixİ", "i\u0307", "Ixİ"); + assertStringTrimLeft("UNICODE", "IXİ", "ix\u0307", "IXİ"); + assertStringTrimLeft("UNICODE", "xi\u0307", "\u0307IX", "xi\u0307"); assertStringTrimLeft("UNICODE_CI", "i", "i", ""); assertStringTrimLeft("UNICODE_CI", "iii", "I", ""); assertStringTrimLeft("UNICODE_CI", "I", "iii", ""); assertStringTrimLeft("UNICODE_CI", "ixi", "i", "xi"); assertStringTrimLeft("UNICODE_CI", "i", "İ", "i"); assertStringTrimLeft("UNICODE_CI", "i\u0307", "İ", ""); + assertStringTrimLeft("UNICODE_CI", "ii\u0307", "İi", ""); + assertStringTrimLeft("UNICODE_CI", "iii\u0307", "İi", ""); + assertStringTrimLeft("UNICODE_CI", "iiii\u0307", "iİ", ""); + assertStringTrimLeft("UNICODE_CI", "ii\u0307ii\u0307", "iİ", ""); assertStringTrimLeft("UNICODE_CI", "i\u0307", "i", "i\u0307"); assertStringTrimLeft("UNICODE_CI", "i\u0307", "\u0307", "i\u0307"); assertStringTrimLeft("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307"); @@ -1768,12 +3254,283 @@ public void testStringTrim() throws SparkException { assertStringTrimLeft("UNICODE_CI", "Ixİ", "i\u0307", "xİ"); assertStringTrimLeft("UNICODE_CI", "IXİ", "ix\u0307", "İ"); assertStringTrimLeft("UNICODE_CI", "xi\u0307", "\u0307IX", "i\u0307"); + // Conditional case mapping (e.g. Greek sigmas). 
+ assertStringTrimLeft("UTF8_BINARY", "ςxς", "σ", "ςxς"); + assertStringTrimLeft("UTF8_BINARY", "ςxς", "ς", "xς"); + assertStringTrimLeft("UTF8_BINARY", "ςxς", "Σ", "ςxς"); + assertStringTrimLeft("UTF8_BINARY", "σxσ", "σ", "xσ"); + assertStringTrimLeft("UTF8_BINARY", "σxσ", "ς", "σxσ"); + assertStringTrimLeft("UTF8_BINARY", "σxσ", "Σ", "σxσ"); + assertStringTrimLeft("UTF8_BINARY", "ΣxΣ", "σ", "ΣxΣ"); + assertStringTrimLeft("UTF8_BINARY", "ΣxΣ", "ς", "ΣxΣ"); + assertStringTrimLeft("UTF8_BINARY", "ΣxΣ", "Σ", "xΣ"); + assertStringTrimLeft("UTF8_LCASE", "ςxς", "σ", "xς"); + assertStringTrimLeft("UTF8_LCASE", "ςxς", "ς", "xς"); + assertStringTrimLeft("UTF8_LCASE", "ςxς", "Σ", "xς"); + assertStringTrimLeft("UTF8_LCASE", "σxσ", "σ", "xσ"); + assertStringTrimLeft("UTF8_LCASE", "σxσ", "ς", "xσ"); + assertStringTrimLeft("UTF8_LCASE", "σxσ", "Σ", "xσ"); + assertStringTrimLeft("UTF8_LCASE", "ΣxΣ", "σ", "xΣ"); + assertStringTrimLeft("UTF8_LCASE", "ΣxΣ", "ς", "xΣ"); + assertStringTrimLeft("UTF8_LCASE", "ΣxΣ", "Σ", "xΣ"); + assertStringTrimLeft("UNICODE", "ςxς", "σ", "ςxς"); + assertStringTrimLeft("UNICODE", "ςxς", "ς", "xς"); + assertStringTrimLeft("UNICODE", "ςxς", "Σ", "ςxς"); + assertStringTrimLeft("UNICODE", "σxσ", "σ", "xσ"); + assertStringTrimLeft("UNICODE", "σxσ", "ς", "σxσ"); + assertStringTrimLeft("UNICODE", "σxσ", "Σ", "σxσ"); + assertStringTrimLeft("UNICODE", "ΣxΣ", "σ", "ΣxΣ"); + assertStringTrimLeft("UNICODE", "ΣxΣ", "ς", "ΣxΣ"); + assertStringTrimLeft("UNICODE", "ΣxΣ", "Σ", "xΣ"); + assertStringTrimLeft("UNICODE_CI", "ςxς", "σ", "xς"); + assertStringTrimLeft("UNICODE_CI", "ςxς", "ς", "xς"); + assertStringTrimLeft("UNICODE_CI", "ςxς", "Σ", "xς"); + assertStringTrimLeft("UNICODE_CI", "σxσ", "σ", "xσ"); + assertStringTrimLeft("UNICODE_CI", "σxσ", "ς", "xσ"); + assertStringTrimLeft("UNICODE_CI", "σxσ", "Σ", "xσ"); + assertStringTrimLeft("UNICODE_CI", "ΣxΣ", "σ", "xΣ"); + assertStringTrimLeft("UNICODE_CI", "ΣxΣ", "ς", "xΣ"); + assertStringTrimLeft("UNICODE_CI", "ΣxΣ", "Σ", "xΣ"); + // Unicode normalization. + assertStringTrimLeft("UTF8_BINARY", "åβγδa\u030A", "å", "βγδa\u030A"); + assertStringTrimLeft("UTF8_LCASE", "åβγδa\u030A", "Å", "βγδa\u030A"); + assertStringTrimLeft("UNICODE", "åβγδa\u030A", "å", "βγδa\u030A"); + assertStringTrimLeft("UNICODE_CI", "åβγδa\u030A", "Å", "βγδa\u030A"); + // Surrogate pairs. 
+ assertStringTrimLeft("UTF8_BINARY", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimLeft("UTF8_LCASE", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimLeft("UNICODE", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimLeft("UNICODE_CI", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimLeft("UTF8_BINARY", "a🙃b🙃c", "a", "🙃b🙃c"); + assertStringTrimLeft("UTF8_LCASE", "a🙃b🙃c", "a", "🙃b🙃c"); + assertStringTrimLeft("UNICODE", "a🙃b🙃c", "a", "🙃b🙃c"); + assertStringTrimLeft("UNICODE_CI", "a🙃b🙃c", "a", "🙃b🙃c"); + assertStringTrimLeft("UTF8_BINARY", "a🙃b🙃c", "a🙃", "b🙃c"); + assertStringTrimLeft("UTF8_LCASE", "a🙃b🙃c", "a🙃", "b🙃c"); + assertStringTrimLeft("UNICODE", "a🙃b🙃c", "a🙃", "b🙃c"); + assertStringTrimLeft("UNICODE_CI", "a🙃b🙃c", "a🙃", "b🙃c"); + assertStringTrimLeft("UTF8_BINARY", "a🙃b🙃c", "a🙃b", "c"); + assertStringTrimLeft("UTF8_LCASE", "a🙃b🙃c", "a🙃b", "c"); + assertStringTrimLeft("UNICODE", "a🙃b🙃c", "a🙃b", "c"); + assertStringTrimLeft("UNICODE_CI", "a🙃b🙃c", "a🙃b", "c"); + assertStringTrimLeft("UTF8_BINARY", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimLeft("UTF8_LCASE", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimLeft("UNICODE", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimLeft("UNICODE_CI", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimLeft("UTF8_BINARY", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimLeft("UTF8_LCASE", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimLeft("UNICODE", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimLeft("UNICODE_CI", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimLeft("UTF8_BINARY", "😀😆😃😄", "😀😆", "😃😄"); + assertStringTrimLeft("UTF8_LCASE", "😀😆😃😄", "😀😆", "😃😄"); + assertStringTrimLeft("UNICODE", "😀😆😃😄", "😀😆", "😃😄"); + assertStringTrimLeft("UNICODE_CI", "😀😆😃😄", "😀😆", "😃😄"); + assertStringTrimLeft("UTF8_BINARY", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimLeft("UTF8_LCASE", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimLeft("UNICODE", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimLeft("UNICODE_CI", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimLeft("UTF8_BINARY", "𐐅", "𐐅", ""); + assertStringTrimLeft("UTF8_LCASE", "𐐅", "𐐅", ""); + assertStringTrimLeft("UNICODE", "𐐅", "𐐅", ""); + assertStringTrimLeft("UNICODE_CI", "𐐅", "𐐅", ""); + assertStringTrimLeft("UTF8_BINARY", "𐐅", "𐐭", "𐐅"); + assertStringTrimLeft("UTF8_LCASE", "𐐅", "𐐭", ""); + assertStringTrimLeft("UNICODE", "𐐅", "𐐭", "𐐅"); + assertStringTrimLeft("UNICODE_CI", "𐐅", "𐐭", ""); + assertStringTrimLeft("UTF8_BINARY", "𝔸", "𝔸", ""); + assertStringTrimLeft("UTF8_LCASE", "𝔸", "𝔸", ""); + assertStringTrimLeft("UNICODE", "𝔸", "𝔸", ""); + assertStringTrimLeft("UNICODE_CI", "𝔸", "𝔸", ""); + assertStringTrimLeft("UTF8_BINARY", "𝔸", "A", "𝔸"); + assertStringTrimLeft("UTF8_LCASE", "𝔸", "A", "𝔸"); + assertStringTrimLeft("UNICODE", "𝔸", "A", "𝔸"); + assertStringTrimLeft("UNICODE_CI", "𝔸", "A", ""); + assertStringTrimLeft("UTF8_BINARY", "𝔸", "a", "𝔸"); + assertStringTrimLeft("UTF8_LCASE", "𝔸", "a", "𝔸"); + assertStringTrimLeft("UNICODE", "𝔸", "a", "𝔸"); + assertStringTrimLeft("UNICODE_CI", "𝔸", "a", ""); + } + + /** + * Verify the behaviour of the `StringTrimRight` collation support class. + */ + + private void assertStringTrimRight(String collationName, String sourceString, String trimString, + String expected) throws SparkException { + // Prepare the input and expected result. + int collationId = CollationFactory.collationNameToId(collationName); + UTF8String src = UTF8String.fromString(sourceString); + UTF8String trim = UTF8String.fromString(trimString); + UTF8String result; + + if (trimString == null) { + // Trim string is ASCII space. 
+ result = CollationSupport.StringTrimRight.exec(src); + } else { + // Trim string is specified. + result = CollationSupport.StringTrimRight.exec(src, trim, collationId); + } + + // Test that StringTrimRight result is as expected. + assertEquals(UTF8String.fromString(expected), result); + } + + @Test + public void testStringTrimRight() throws SparkException { + // Basic tests. + assertStringTrimRight("UTF8_BINARY", "", "", ""); + assertStringTrimRight("UTF8_BINARY", "", "xyz", ""); + assertStringTrimRight("UTF8_BINARY", "asd", "", "asd"); + assertStringTrimRight("UTF8_BINARY", "asd", null, "asd"); + assertStringTrimRight("UTF8_BINARY", " asd ", null, " asd"); + assertStringTrimRight("UTF8_BINARY", " a世a ", null, " a世a"); + assertStringTrimRight("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrimRight("UTF8_BINARY", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UTF8_BINARY", "xa世ax", "x", "xa世a"); + assertStringTrimRight("UTF8_LCASE", "", "", ""); + assertStringTrimRight("UTF8_LCASE", "", "xyz", ""); + assertStringTrimRight("UTF8_LCASE", "asd", "", "asd"); + assertStringTrimRight("UTF8_LCASE", "asd", null, "asd"); + assertStringTrimRight("UTF8_LCASE", " asd ", null, " asd"); + assertStringTrimRight("UTF8_LCASE", " a世a ", null, " a世a"); + assertStringTrimRight("UTF8_LCASE", "asd", "x", "asd"); + assertStringTrimRight("UTF8_LCASE", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UTF8_LCASE", "xa世ax", "x", "xa世a"); + assertStringTrimRight("UNICODE", "", "", ""); + assertStringTrimRight("UNICODE", "", "xyz", ""); + assertStringTrimRight("UNICODE", "asd", "", "asd"); + assertStringTrimRight("UNICODE", "asd", null, "asd"); + assertStringTrimRight("UNICODE", " asd ", null, " asd"); + assertStringTrimRight("UNICODE", " a世a ", null, " a世a"); + assertStringTrimRight("UNICODE", "asd", "x", "asd"); + assertStringTrimRight("UNICODE", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UNICODE", "xa世ax", "x", "xa世a"); + assertStringTrimRight("UNICODE_CI", "", "", ""); + assertStringTrimRight("UNICODE_CI", "", "xyz", ""); + assertStringTrimRight("UNICODE_CI", "asd", "", "asd"); + assertStringTrimRight("UNICODE_CI", "asd", null, "asd"); + assertStringTrimRight("UNICODE_CI", " asd ", null, " asd"); + assertStringTrimRight("UNICODE_CI", " a世a ", null, " a世a"); + assertStringTrimRight("UNICODE_CI", "asd", "x", "asd"); + assertStringTrimRight("UNICODE_CI", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UNICODE_CI", "xa世ax", "x", "xa世a"); + // Case variation. + assertStringTrimRight("UTF8_BINARY", "ddsXXXaa", "asd", "ddsXXX"); + assertStringTrimRight("UTF8_LCASE", "ddsXXXaa", "AsD", "ddsXXX"); + assertStringTrimRight("UNICODE", "ddsXXXaa", "asd", "ddsXXX"); + assertStringTrimRight("UNICODE_CI", "ddsXXXaa", "AsD", "ddsXXX"); + // One-to-many case mapping (e.g. Turkish dotted I).. 
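+ // ẞ (U+1E9E) and ß (U+00DF) are matched only by the case-insensitive collations
+ // (UTF8_LCASE, UNICODE_CI); UTF8_BINARY and UNICODE keep them distinct.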
+ assertStringTrimRight("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimRight("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimRight("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + assertStringTrimRight("UTF8_LCASE", "ẞaaaẞ", "ß", "ẞaaa"); + assertStringTrimRight("UTF8_LCASE", "ßaaaß", "ẞ", "ßaaa"); + assertStringTrimRight("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + assertStringTrimRight("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimRight("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimRight("UNICODE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + assertStringTrimRight("UNICODE_CI", "ẞaaaẞ", "ß", "ẞaaa"); + assertStringTrimRight("UNICODE_CI", "ßaaaß", "ẞ", "ßaaa"); + assertStringTrimRight("UNICODE_CI", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + // One-to-many case mapping (e.g. Turkish dotted I). + assertStringTrimRight("UTF8_BINARY", "i", "i", ""); + assertStringTrimRight("UTF8_BINARY", "iii", "I", "iii"); + assertStringTrimRight("UTF8_BINARY", "I", "iii", "I"); + assertStringTrimRight("UTF8_BINARY", "ixi", "i", "ix"); + assertStringTrimRight("UTF8_BINARY", "i", "İ", "i"); + assertStringTrimRight("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); + assertStringTrimRight("UTF8_BINARY", "ii\u0307", "İi", "ii\u0307"); + assertStringTrimRight("UTF8_BINARY", "iii\u0307", "İi", "iii\u0307"); + assertStringTrimRight("UTF8_BINARY", "iiii\u0307", "iİ", "iiii\u0307"); + assertStringTrimRight("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "ii\u0307ii\u0307"); + assertStringTrimRight("UTF8_BINARY", "i\u0307", "i", "i\u0307"); + assertStringTrimRight("UTF8_BINARY", "i\u0307", "\u0307", "i"); + assertStringTrimRight("UTF8_BINARY", "i\u0307", "i\u0307", ""); + assertStringTrimRight("UTF8_BINARY", "i\u0307i\u0307", "i\u0307", ""); + assertStringTrimRight("UTF8_BINARY", "i\u0307\u0307", "i\u0307", ""); + assertStringTrimRight("UTF8_BINARY", "i\u0307i", "i\u0307", ""); + assertStringTrimRight("UTF8_BINARY", "i\u0307i", "İ", "i\u0307i"); + assertStringTrimRight("UTF8_BINARY", "i\u0307İ", "i\u0307", "i\u0307İ"); + assertStringTrimRight("UTF8_BINARY", "i\u0307İ", "İ", "i\u0307"); + assertStringTrimRight("UTF8_BINARY", "İ", "İ", ""); + assertStringTrimRight("UTF8_BINARY", "IXi", "İ", "IXi"); + assertStringTrimRight("UTF8_BINARY", "ix\u0307", "Ixİ", "ix\u0307"); + assertStringTrimRight("UTF8_BINARY", "i\u0307x", "IXİ", "i\u0307x"); + assertStringTrimRight("UTF8_BINARY", "i\u0307x", "ix\u0307İ", ""); + assertStringTrimRight("UTF8_BINARY", "İ", "i", "İ"); + assertStringTrimRight("UTF8_BINARY", "İ", "\u0307", "İ"); + assertStringTrimRight("UTF8_BINARY", "Ixİ", "i\u0307", "Ixİ"); + assertStringTrimRight("UTF8_BINARY", "IXİ", "ix\u0307", "IXİ"); + assertStringTrimRight("UTF8_BINARY", "xi\u0307", "\u0307IX", "xi"); + assertStringTrimRight("UTF8_LCASE", "i", "i", ""); + assertStringTrimRight("UTF8_LCASE", "iii", "I", ""); + assertStringTrimRight("UTF8_LCASE", "I", "iii", ""); + assertStringTrimRight("UTF8_LCASE", "ixi", "i", "ix"); + assertStringTrimRight("UTF8_LCASE", "i", "İ", "i"); + assertStringTrimRight("UTF8_LCASE", "i\u0307", "İ", ""); + assertStringTrimRight("UTF8_LCASE", "ii\u0307", "İi", ""); + assertStringTrimRight("UTF8_LCASE", "iii\u0307", "İi", ""); + assertStringTrimRight("UTF8_LCASE", "iiii\u0307", "iİ", ""); + assertStringTrimRight("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", ""); + assertStringTrimRight("UTF8_LCASE", "i\u0307", "i", "i\u0307"); + assertStringTrimRight("UTF8_LCASE", "i\u0307", "\u0307", "i"); + assertStringTrimRight("UTF8_LCASE", "i\u0307", "i\u0307", ""); + assertStringTrimRight("UTF8_LCASE", "i\u0307i\u0307", "i\u0307", 
""); + assertStringTrimRight("UTF8_LCASE", "i\u0307\u0307", "i\u0307", ""); + assertStringTrimRight("UTF8_LCASE", "i\u0307i", "i\u0307", ""); + assertStringTrimRight("UTF8_LCASE", "i\u0307i", "İ", "i\u0307i"); + assertStringTrimRight("UTF8_LCASE", "i\u0307İ", "i\u0307", "i\u0307İ"); + assertStringTrimRight("UTF8_LCASE", "i\u0307İ", "İ", ""); + assertStringTrimRight("UTF8_LCASE", "İ", "İ", ""); + assertStringTrimRight("UTF8_LCASE", "IXi", "İ", "IXi"); + assertStringTrimRight("UTF8_LCASE", "ix\u0307", "Ixİ", "ix\u0307"); + assertStringTrimRight("UTF8_LCASE", "i\u0307x", "IXİ", ""); + assertStringTrimRight("UTF8_LCASE", "i\u0307x", "I\u0307xİ", ""); + assertStringTrimRight("UTF8_LCASE", "İ", "i", "İ"); + assertStringTrimRight("UTF8_LCASE", "İ", "\u0307", "İ"); + assertStringTrimRight("UTF8_LCASE", "Ixİ", "i\u0307", "Ixİ"); + assertStringTrimRight("UTF8_LCASE", "IXİ", "ix\u0307", "IXİ"); + assertStringTrimRight("UTF8_LCASE", "xi\u0307", "\u0307IX", ""); + assertStringTrimRight("UNICODE", "i", "i", ""); + assertStringTrimRight("UNICODE", "iii", "I", "iii"); + assertStringTrimRight("UNICODE", "I", "iii", "I"); + assertStringTrimRight("UNICODE", "ixi", "i", "ix"); + assertStringTrimRight("UNICODE", "i", "İ", "i"); + assertStringTrimRight("UNICODE", "i\u0307", "İ", "i\u0307"); + assertStringTrimRight("UTF8_BINARY", "ii\u0307", "İi", "ii\u0307"); + assertStringTrimRight("UTF8_BINARY", "iii\u0307", "İi", "iii\u0307"); + assertStringTrimRight("UTF8_BINARY", "iiii\u0307", "iİ", "iiii\u0307"); + assertStringTrimRight("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "ii\u0307ii\u0307"); + assertStringTrimRight("UNICODE", "i\u0307", "i", "i\u0307"); + assertStringTrimRight("UNICODE", "i\u0307", "\u0307", "i\u0307"); + assertStringTrimRight("UNICODE", "i\u0307", "i\u0307", "i\u0307"); + assertStringTrimRight("UNICODE", "i\u0307i\u0307", "i\u0307", "i\u0307i\u0307"); + assertStringTrimRight("UNICODE", "i\u0307\u0307", "i\u0307", "i\u0307\u0307"); + assertStringTrimRight("UNICODE", "i\u0307i", "i\u0307", "i\u0307"); + assertStringTrimRight("UNICODE", "i\u0307i", "İ", "i\u0307i"); + assertStringTrimRight("UNICODE", "i\u0307İ", "i\u0307", "i\u0307İ"); + assertStringTrimRight("UNICODE", "i\u0307İ", "İ", "i\u0307"); + assertStringTrimRight("UNICODE", "İ", "İ", ""); + assertStringTrimRight("UNICODE", "IXi", "İ", "IXi"); + assertStringTrimRight("UNICODE", "ix\u0307", "Ixİ", "ix\u0307"); + assertStringTrimRight("UNICODE", "i\u0307x", "IXİ", "i\u0307x"); + assertStringTrimRight("UNICODE", "i\u0307x", "ix\u0307İ", "i\u0307"); + assertStringTrimRight("UNICODE", "İ", "i", "İ"); + assertStringTrimRight("UNICODE", "İ", "\u0307", "İ"); + assertStringTrimRight("UNICODE", "i\u0307", "i\u0307", "i\u0307"); + assertStringTrimRight("UNICODE", "Ixİ", "i\u0307", "Ixİ"); + assertStringTrimRight("UNICODE", "IXİ", "ix\u0307", "IXİ"); + assertStringTrimRight("UNICODE", "xi\u0307", "\u0307IX", "xi\u0307"); assertStringTrimRight("UNICODE_CI", "i", "i", ""); assertStringTrimRight("UNICODE_CI", "iii", "I", ""); assertStringTrimRight("UNICODE_CI", "I", "iii", ""); assertStringTrimRight("UNICODE_CI", "ixi", "i", "ix"); assertStringTrimRight("UNICODE_CI", "i", "İ", "i"); assertStringTrimRight("UNICODE_CI", "i\u0307", "İ", ""); + assertStringTrimRight("UNICODE_CI", "ii\u0307", "İi", ""); + assertStringTrimRight("UNICODE_CI", "iii\u0307", "İi", ""); + assertStringTrimRight("UNICODE_CI", "iiii\u0307", "iİ", ""); + assertStringTrimRight("UNICODE_CI", "ii\u0307ii\u0307", "iİ", ""); assertStringTrimRight("UNICODE_CI", "i\u0307", "i", "i\u0307"); 
assertStringTrimRight("UNICODE_CI", "i\u0307", "\u0307", "i\u0307"); assertStringTrimRight("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307"); @@ -1791,29 +3548,10 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UNICODE_CI", "İ", "i", "İ"); assertStringTrimRight("UNICODE_CI", "İ", "\u0307", "İ"); assertStringTrimRight("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307"); - assertStringTrimRight("UNICODE_CI", "Ixİ", "i\u0307", "Ixİ"); - assertStringTrimRight("UNICODE_CI", "IXİ", "ix\u0307", "IXİ"); - assertStringTrimRight("UNICODE_CI", "xi\u0307", "\u0307IX", "xi\u0307"); - - // Greek sigmas - UTF8_BINARY. - assertStringTrim("UTF8_BINARY", "ςxς", "σ", "ςxς"); - assertStringTrim("UTF8_BINARY", "ςxς", "ς", "x"); - assertStringTrim("UTF8_BINARY", "ςxς", "Σ", "ςxς"); - assertStringTrim("UTF8_BINARY", "σxσ", "σ", "x"); - assertStringTrim("UTF8_BINARY", "σxσ", "ς", "σxσ"); - assertStringTrim("UTF8_BINARY", "σxσ", "Σ", "σxσ"); - assertStringTrim("UTF8_BINARY", "ΣxΣ", "σ", "ΣxΣ"); - assertStringTrim("UTF8_BINARY", "ΣxΣ", "ς", "ΣxΣ"); - assertStringTrim("UTF8_BINARY", "ΣxΣ", "Σ", "x"); - assertStringTrimLeft("UTF8_BINARY", "ςxς", "σ", "ςxς"); - assertStringTrimLeft("UTF8_BINARY", "ςxς", "ς", "xς"); - assertStringTrimLeft("UTF8_BINARY", "ςxς", "Σ", "ςxς"); - assertStringTrimLeft("UTF8_BINARY", "σxσ", "σ", "xσ"); - assertStringTrimLeft("UTF8_BINARY", "σxσ", "ς", "σxσ"); - assertStringTrimLeft("UTF8_BINARY", "σxσ", "Σ", "σxσ"); - assertStringTrimLeft("UTF8_BINARY", "ΣxΣ", "σ", "ΣxΣ"); - assertStringTrimLeft("UTF8_BINARY", "ΣxΣ", "ς", "ΣxΣ"); - assertStringTrimLeft("UTF8_BINARY", "ΣxΣ", "Σ", "xΣ"); + assertStringTrimRight("UNICODE_CI", "Ixİ", "i\u0307", "Ixİ"); + assertStringTrimRight("UNICODE_CI", "IXİ", "ix\u0307", "IXİ"); + assertStringTrimRight("UNICODE_CI", "xi\u0307", "\u0307IX", "xi\u0307"); + // Conditional case mapping (e.g. Greek sigmas). assertStringTrimRight("UTF8_BINARY", "ςxς", "σ", "ςxς"); assertStringTrimRight("UTF8_BINARY", "ςxς", "ς", "ςx"); assertStringTrimRight("UTF8_BINARY", "ςxς", "Σ", "ςxς"); @@ -1823,25 +3561,6 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UTF8_BINARY", "ΣxΣ", "σ", "ΣxΣ"); assertStringTrimRight("UTF8_BINARY", "ΣxΣ", "ς", "ΣxΣ"); assertStringTrimRight("UTF8_BINARY", "ΣxΣ", "Σ", "Σx"); - // Greek sigmas - UTF8_LCASE. 
- assertStringTrim("UTF8_LCASE", "ςxς", "σ", "x"); - assertStringTrim("UTF8_LCASE", "ςxς", "ς", "x"); - assertStringTrim("UTF8_LCASE", "ςxς", "Σ", "x"); - assertStringTrim("UTF8_LCASE", "σxσ", "σ", "x"); - assertStringTrim("UTF8_LCASE", "σxσ", "ς", "x"); - assertStringTrim("UTF8_LCASE", "σxσ", "Σ", "x"); - assertStringTrim("UTF8_LCASE", "ΣxΣ", "σ", "x"); - assertStringTrim("UTF8_LCASE", "ΣxΣ", "ς", "x"); - assertStringTrim("UTF8_LCASE", "ΣxΣ", "Σ", "x"); - assertStringTrimLeft("UTF8_LCASE", "ςxς", "σ", "xς"); - assertStringTrimLeft("UTF8_LCASE", "ςxς", "ς", "xς"); - assertStringTrimLeft("UTF8_LCASE", "ςxς", "Σ", "xς"); - assertStringTrimLeft("UTF8_LCASE", "σxσ", "σ", "xσ"); - assertStringTrimLeft("UTF8_LCASE", "σxσ", "ς", "xσ"); - assertStringTrimLeft("UTF8_LCASE", "σxσ", "Σ", "xσ"); - assertStringTrimLeft("UTF8_LCASE", "ΣxΣ", "σ", "xΣ"); - assertStringTrimLeft("UTF8_LCASE", "ΣxΣ", "ς", "xΣ"); - assertStringTrimLeft("UTF8_LCASE", "ΣxΣ", "Σ", "xΣ"); assertStringTrimRight("UTF8_LCASE", "ςxς", "σ", "ςx"); assertStringTrimRight("UTF8_LCASE", "ςxς", "ς", "ςx"); assertStringTrimRight("UTF8_LCASE", "ςxς", "Σ", "ςx"); @@ -1851,25 +3570,6 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UTF8_LCASE", "ΣxΣ", "σ", "Σx"); assertStringTrimRight("UTF8_LCASE", "ΣxΣ", "ς", "Σx"); assertStringTrimRight("UTF8_LCASE", "ΣxΣ", "Σ", "Σx"); - // Greek sigmas - UNICODE. - assertStringTrim("UNICODE", "ςxς", "σ", "ςxς"); - assertStringTrim("UNICODE", "ςxς", "ς", "x"); - assertStringTrim("UNICODE", "ςxς", "Σ", "ςxς"); - assertStringTrim("UNICODE", "σxσ", "σ", "x"); - assertStringTrim("UNICODE", "σxσ", "ς", "σxσ"); - assertStringTrim("UNICODE", "σxσ", "Σ", "σxσ"); - assertStringTrim("UNICODE", "ΣxΣ", "σ", "ΣxΣ"); - assertStringTrim("UNICODE", "ΣxΣ", "ς", "ΣxΣ"); - assertStringTrim("UNICODE", "ΣxΣ", "Σ", "x"); - assertStringTrimLeft("UNICODE", "ςxς", "σ", "ςxς"); - assertStringTrimLeft("UNICODE", "ςxς", "ς", "xς"); - assertStringTrimLeft("UNICODE", "ςxς", "Σ", "ςxς"); - assertStringTrimLeft("UNICODE", "σxσ", "σ", "xσ"); - assertStringTrimLeft("UNICODE", "σxσ", "ς", "σxσ"); - assertStringTrimLeft("UNICODE", "σxσ", "Σ", "σxσ"); - assertStringTrimLeft("UNICODE", "ΣxΣ", "σ", "ΣxΣ"); - assertStringTrimLeft("UNICODE", "ΣxΣ", "ς", "ΣxΣ"); - assertStringTrimLeft("UNICODE", "ΣxΣ", "Σ", "xΣ"); assertStringTrimRight("UNICODE", "ςxς", "σ", "ςxς"); assertStringTrimRight("UNICODE", "ςxς", "ς", "ςx"); assertStringTrimRight("UNICODE", "ςxς", "Σ", "ςxς"); @@ -1879,25 +3579,6 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UNICODE", "ΣxΣ", "σ", "ΣxΣ"); assertStringTrimRight("UNICODE", "ΣxΣ", "ς", "ΣxΣ"); assertStringTrimRight("UNICODE", "ΣxΣ", "Σ", "Σx"); - // Greek sigmas - UNICODE_CI. 
- assertStringTrim("UNICODE_CI", "ςxς", "σ", "x"); - assertStringTrim("UNICODE_CI", "ςxς", "ς", "x"); - assertStringTrim("UNICODE_CI", "ςxς", "Σ", "x"); - assertStringTrim("UNICODE_CI", "σxσ", "σ", "x"); - assertStringTrim("UNICODE_CI", "σxσ", "ς", "x"); - assertStringTrim("UNICODE_CI", "σxσ", "Σ", "x"); - assertStringTrim("UNICODE_CI", "ΣxΣ", "σ", "x"); - assertStringTrim("UNICODE_CI", "ΣxΣ", "ς", "x"); - assertStringTrim("UNICODE_CI", "ΣxΣ", "Σ", "x"); - assertStringTrimLeft("UNICODE_CI", "ςxς", "σ", "xς"); - assertStringTrimLeft("UNICODE_CI", "ςxς", "ς", "xς"); - assertStringTrimLeft("UNICODE_CI", "ςxς", "Σ", "xς"); - assertStringTrimLeft("UNICODE_CI", "σxσ", "σ", "xσ"); - assertStringTrimLeft("UNICODE_CI", "σxσ", "ς", "xσ"); - assertStringTrimLeft("UNICODE_CI", "σxσ", "Σ", "xσ"); - assertStringTrimLeft("UNICODE_CI", "ΣxΣ", "σ", "xΣ"); - assertStringTrimLeft("UNICODE_CI", "ΣxΣ", "ς", "xΣ"); - assertStringTrimLeft("UNICODE_CI", "ΣxΣ", "Σ", "xΣ"); assertStringTrimRight("UNICODE_CI", "ςxς", "σ", "ςx"); assertStringTrimRight("UNICODE_CI", "ςxς", "ς", "ςx"); assertStringTrimRight("UNICODE_CI", "ςxς", "Σ", "ςx"); @@ -1907,186 +3588,287 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UNICODE_CI", "ΣxΣ", "σ", "Σx"); assertStringTrimRight("UNICODE_CI", "ΣxΣ", "ς", "Σx"); assertStringTrimRight("UNICODE_CI", "ΣxΣ", "Σ", "Σx"); - - // Unicode normalization - UTF8_BINARY. - assertStringTrim("UTF8_BINARY", "åβγδa\u030A", "å", "βγδa\u030A"); - assertStringTrimLeft("UTF8_BINARY", "åβγδa\u030A", "å", "βγδa\u030A"); + // Unicode normalization. assertStringTrimRight("UTF8_BINARY", "åβγδa\u030A", "å", "åβγδa\u030A"); - // Unicode normalization - UTF8_LCASE. - assertStringTrim("UTF8_LCASE", "åβγδa\u030A", "Å", "βγδa\u030A"); - assertStringTrimLeft("UTF8_LCASE", "åβγδa\u030A", "Å", "βγδa\u030A"); assertStringTrimRight("UTF8_LCASE", "åβγδa\u030A", "Å", "åβγδa\u030A"); - // Unicode normalization - UNICODE. - assertStringTrim("UNICODE", "åβγδa\u030A", "å", "βγδ"); - assertStringTrimLeft("UNICODE", "åβγδa\u030A", "å", "βγδa\u030A"); assertStringTrimRight("UNICODE", "åβγδa\u030A", "å", "åβγδ"); - // Unicode normalization - UNICODE_CI. - assertStringTrim("UNICODE_CI", "åβγδa\u030A", "Å", "βγδ"); - assertStringTrimLeft("UNICODE_CI", "åβγδa\u030A", "Å", "βγδa\u030A"); assertStringTrimRight("UNICODE_CI", "åβγδa\u030A", "Å", "åβγδ"); + // Surrogate pairs. 
+ assertStringTrimRight("UTF8_BINARY", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimRight("UTF8_LCASE", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimRight("UNICODE", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimRight("UNICODE_CI", "a🙃b🙃c", "🙃", "a🙃b🙃c"); + assertStringTrimRight("UTF8_BINARY", "a🙃b🙃c", "c", "a🙃b🙃"); + assertStringTrimRight("UTF8_LCASE", "a🙃b🙃c", "c", "a🙃b🙃"); + assertStringTrimRight("UNICODE", "a🙃b🙃c", "c", "a🙃b🙃"); + assertStringTrimRight("UNICODE_CI", "a🙃b🙃c", "c", "a🙃b🙃"); + assertStringTrimRight("UTF8_BINARY", "a🙃b🙃c", "c🙃", "a🙃b"); + assertStringTrimRight("UTF8_LCASE", "a🙃b🙃c", "c🙃", "a🙃b"); + assertStringTrimRight("UNICODE", "a🙃b🙃c", "c🙃", "a🙃b"); + assertStringTrimRight("UNICODE_CI", "a🙃b🙃c", "c🙃", "a🙃b"); + assertStringTrimRight("UTF8_BINARY", "a🙃b🙃c", "c🙃b", "a"); + assertStringTrimRight("UTF8_LCASE", "a🙃b🙃c", "c🙃b", "a"); + assertStringTrimRight("UNICODE", "a🙃b🙃c", "c🙃b", "a"); + assertStringTrimRight("UNICODE_CI", "a🙃b🙃c", "c🙃b", "a"); + assertStringTrimRight("UTF8_BINARY", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimRight("UTF8_LCASE", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimRight("UNICODE", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimRight("UNICODE_CI", "a🙃b🙃c", "abc🙃", ""); + assertStringTrimRight("UTF8_BINARY", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimRight("UTF8_LCASE", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimRight("UNICODE", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimRight("UNICODE_CI", "😀😆😃😄", "😆😃", "😀😆😃😄"); + assertStringTrimRight("UTF8_BINARY", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrimRight("UTF8_LCASE", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrimRight("UNICODE", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrimRight("UNICODE_CI", "😀😆😃😄", "😃😄", "😀😆"); + assertStringTrimRight("UTF8_BINARY", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimRight("UTF8_LCASE", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimRight("UNICODE", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimRight("UNICODE_CI", "😀😆😃😄", "😀😆😃😄", ""); + assertStringTrimRight("UTF8_BINARY", "𐐅", "𐐅", ""); + assertStringTrimRight("UTF8_LCASE", "𐐅", "𐐅", ""); + assertStringTrimRight("UNICODE", "𐐅", "𐐅", ""); + assertStringTrimRight("UNICODE_CI", "𐐅", "𐐅", ""); + assertStringTrimRight("UTF8_BINARY", "𐐅", "𐐭", "𐐅"); + assertStringTrimRight("UTF8_LCASE", "𐐅", "𐐭", ""); + assertStringTrimRight("UNICODE", "𐐅", "𐐭", "𐐅"); + assertStringTrimRight("UNICODE_CI", "𐐅", "𐐭", ""); + assertStringTrimRight("UTF8_BINARY", "𝔸", "𝔸", ""); + assertStringTrimRight("UTF8_LCASE", "𝔸", "𝔸", ""); + assertStringTrimRight("UNICODE", "𝔸", "𝔸", ""); + assertStringTrimRight("UNICODE_CI", "𝔸", "𝔸", ""); + assertStringTrimRight("UTF8_BINARY", "𝔸", "A", "𝔸"); + assertStringTrimRight("UTF8_LCASE", "𝔸", "A", "𝔸"); + assertStringTrimRight("UNICODE", "𝔸", "A", "𝔸"); + assertStringTrimRight("UNICODE_CI", "𝔸", "A", ""); + assertStringTrimRight("UTF8_BINARY", "𝔸", "a", "𝔸"); + assertStringTrimRight("UTF8_LCASE", "𝔸", "a", "𝔸"); + assertStringTrimRight("UNICODE", "𝔸", "a", "𝔸"); + assertStringTrimRight("UNICODE_CI", "𝔸", "a", ""); } - private void assertStringTranslate( - String inputString, - String matchingString, - String replaceString, - String collationName, - String expectedResultString) throws SparkException { + /** + * Verify the behaviour of the `StringTranslate` collation support class. 
+ */ + + private void assertStringTranslate(String inputString, String matchingString, + String replaceString, String collationName, String expected) throws SparkException { int collationId = CollationFactory.collationNameToId(collationName); Map dict = buildDict(matchingString, replaceString); UTF8String source = UTF8String.fromString(inputString); UTF8String result = CollationSupport.StringTranslate.exec(source, dict, collationId); - assertEquals(expectedResultString, result.toString()); + assertEquals(UTF8String.fromString(expected), result); } @Test public void testStringTranslate() throws SparkException { - // Basic tests - UTF8_BINARY. + // Empty strings. + assertStringTranslate("", "", "", "UTF8_BINARY", ""); + assertStringTranslate("", "", "", "UTF8_LCASE", ""); + assertStringTranslate("", "", "", "UNICODE", ""); + assertStringTranslate("", "", "", "UNICODE_CI", ""); + assertStringTranslate("abc", "", "", "UTF8_BINARY", "abc"); + assertStringTranslate("abc", "", "", "UTF8_LCASE", "abc"); + assertStringTranslate("abc", "", "", "UNICODE", "abc"); + assertStringTranslate("abc", "", "", "UNICODE_CI", "abc"); + assertStringTranslate("", "b", "", "UTF8_BINARY", ""); + assertStringTranslate("", "b", "", "UTF8_LCASE", ""); + assertStringTranslate("", "b", "", "UNICODE", ""); + assertStringTranslate("", "b", "", "UNICODE_CI", ""); + assertStringTranslate("", "", "x", "UTF8_BINARY", ""); + assertStringTranslate("", "", "x", "UTF8_LCASE", ""); + assertStringTranslate("", "", "x", "UNICODE", ""); + assertStringTranslate("", "", "x", "UNICODE_CI", ""); + assertStringTranslate("abc", "b", "", "UTF8_BINARY", "ac"); + assertStringTranslate("abc", "b", "", "UTF8_LCASE", "ac"); + assertStringTranslate("abc", "b", "", "UNICODE", "ac"); + assertStringTranslate("abc", "b", "", "UNICODE_CI", "ac"); + assertStringTranslate("abc", "", "x", "UTF8_BINARY", "abc"); + assertStringTranslate("abc", "", "x", "UTF8_LCASE", "abc"); + assertStringTranslate("abc", "", "x", "UNICODE", "abc"); + assertStringTranslate("abc", "", "x", "UNICODE_CI", "abc"); + assertStringTranslate("", "b", "x", "UTF8_BINARY", ""); + assertStringTranslate("", "b", "x", "UTF8_LCASE", ""); + assertStringTranslate("", "b", "x", "UNICODE", ""); + assertStringTranslate("", "b", "x", "UNICODE_CI", ""); + // Basic tests. 
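+ // translate() maps the i-th character of the matching string to the i-th character of
+ // the replacement string; matching characters without a replacement are deleted, extra
+ // replacement characters are ignored, and unmatched input passes through unchanged
+ // (e.g. "Translate" with "Rnlt" -> "12" yields "Tra2sae" under UTF8_BINARY).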
+ assertStringTranslate("abc", "b", "x", "UTF8_BINARY", "axc"); + assertStringTranslate("abc", "b", "x", "UTF8_LCASE", "axc"); + assertStringTranslate("abc", "b", "x", "UNICODE", "axc"); + assertStringTranslate("abc", "b", "x", "UNICODE_CI", "axc"); assertStringTranslate("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"); - assertStringTranslate("Translate", "Rn", "1234", "UTF8_BINARY", "Tra2slate"); - assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_BINARY", "Tra2s3a4e"); - assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", "TRaxsXaxe"); - assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", "TxaxsXaxeX"); - assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", "TXaxsXaxex"); - assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", "test大千世AX大千世A"); - assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", "大千世界test大千世界"); - assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"); - assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"); - assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"); - assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"); - assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"); - assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f"); - // Basic tests - UTF8_LCASE. assertStringTranslate("Translate", "Rnlt", "12", "UTF8_LCASE", "1a2sae"); - assertStringTranslate("Translate", "Rn", "1234", "UTF8_LCASE", "T1a2slate"); - assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); - assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"); - assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"); - assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"); - assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"); - assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"); - assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"); - assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"); - assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"); - assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"); - assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"); - assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f"); - // Basic tests - UNICODE. 
assertStringTranslate("Translate", "Rnlt", "12", "UNICODE", "Tra2sae"); - assertStringTranslate("Translate", "Rn", "1234", "UNICODE", "Tra2slate"); - assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"); - assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"); - assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"); - assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"); - assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"); - assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE", "大千世界test大千世界"); - assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"); - assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"); - assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE", "世世世界世世世界tesT"); - assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"); - assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"); - assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f"); - // Basic tests - UNICODE_CI. assertStringTranslate("Translate", "Rnlt", "12", "UNICODE_CI", "1a2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UTF8_BINARY", "Tra2slate"); + assertStringTranslate("Translate", "Rn", "1234", "UTF8_LCASE", "T1a2slate"); + assertStringTranslate("Translate", "Rn", "1234", "UNICODE", "Tra2slate"); assertStringTranslate("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate"); + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"); assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", "TRaxsXaxe"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"); assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", "TxaxsXaxeX"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"); assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", "TXaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"); assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", "test大千世AX大千世A"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"); assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", "大千世界test大千世界"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE", "大千世界test大千世界"); assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"); + 
assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"); assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"); assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE", "世世世界世世世界tesT"); assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"); assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"); assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f"); + assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f"); assertStringTranslate("abcdef", "abcde", "123", "UNICODE_CI", "123f"); - - // One-to-many case mapping - UTF8_BINARY. + assertStringTranslate("abcdëÈêf", "ÊèË", "123", "AF_CI", "abcd321f"); + // One-to-many case mapping (e.g. Turkish dotted I). assertStringTranslate("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"); - assertStringTranslate("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307"); - assertStringTranslate("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A"); - assertStringTranslate("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A"); - assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"); - assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"); - assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"); - assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"); - assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"); - assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"); - // One-to-many case mapping - UTF8_LCASE. 
assertStringTranslate("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"); - assertStringTranslate("i\u0307", "İ", "xy", "UTF8_LCASE", "x"); - assertStringTranslate("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A"); - assertStringTranslate("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A"); - assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"); - assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"); - assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"); - assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", "12bc3"); - assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"); - assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"); - // One-to-many case mapping - UNICODE. assertStringTranslate("İ", "i\u0307", "xy", "UNICODE", "İ"); - assertStringTranslate("i\u0307", "İ", "xy", "UNICODE", "i\u0307"); - assertStringTranslate("i\u030A", "İ", "x", "UNICODE", "i\u030A"); - assertStringTranslate("i\u030A", "İi", "xy", "UNICODE", "i\u030A"); - assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"); - assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"); - assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"); - assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"); - assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"); - assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"); - // One-to-many case mapping - UNICODE_CI. assertStringTranslate("İ", "i\u0307", "xy", "UNICODE_CI", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307"); + assertStringTranslate("i\u0307", "İ", "xy", "UTF8_LCASE", "x"); + assertStringTranslate("i\u0307", "İ", "xy", "UNICODE", "i\u0307"); assertStringTranslate("i\u0307", "İ", "xy", "UNICODE_CI", "x"); + assertStringTranslate("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A"); + assertStringTranslate("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A"); + assertStringTranslate("i\u030A", "İ", "x", "UNICODE", "i\u030A"); assertStringTranslate("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UNICODE", "i\u030A"); assertStringTranslate("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"); assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11"); + assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"); + assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"); + assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"); assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE_CI", "11"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"); assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İi\u0307"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", 
"12bc3"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"); assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", "3bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"); assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"); + assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"); + assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"); + assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"); assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2"); - - // Greek sigmas - UTF8_BINARY. + // Conditional case mapping (e.g. Greek sigmas). assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_BINARY", "σΥσΤΗΜΑΤΙΚΟσ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_BINARY", "ςΥςΤΗΜΑΤΙΚΟς"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY", "σιστιματικος"); - assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY", "σιστιματικος"); - assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY", "σιστιματικοσ"); - // Greek sigmas - UTF8_LCASE. assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); - assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); - assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); - assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); - // Greek sigmas - UNICODE. assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE", "σΥσΤΗΜΑΤΙΚΟσ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE", "ςΥςΤΗΜΑΤΙΚΟς"); - assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); - assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE", "σιστιματικος"); - assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE", "σιστιματικος"); - assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE", "σιστιματικοσ"); - // Greek sigmas - UNICODE_CI. 
assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_BINARY", "ςΥςΤΗΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE", "ςΥςΤΗΜΑΤΙΚΟς"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY", "σιστιματικος"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE", "σιστιματικος"); assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY", "σιστιματικος"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE", "σιστιματικος"); assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE", "σιστιματικοσ"); assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + // Surrogate pairs. 
+ assertStringTranslate("a🙃b🙃c", "a", "x", "UTF8_BINARY", "x🙃b🙃c"); + assertStringTranslate("a🙃b🙃c", "a🙃", "xy", "UTF8_BINARY", "xybyc"); + assertStringTranslate("a🙃b🙃c", "a🙃b", "xyz", "UTF8_BINARY", "xyzyc"); + assertStringTranslate("a🙃b🙃c", "a🙃bc", "xyzw", "UTF8_BINARY", "xyzyw"); + assertStringTranslate("😀😆😃😄", "😄😆", "😅😂", "UTF8_BINARY", "😀😂😃😅"); + assertStringTranslate("😀😆😃😄", "😄😆", "😅😂", "UTF8_LCASE", "😀😂😃😅"); + assertStringTranslate("😀😆😃😄", "😄😆", "😅😂", "UNICODE", "😀😂😃😅"); + assertStringTranslate("😀😆😃😄", "😄😆", "😅😂", "UNICODE_CI", "😀😂😃😅"); + assertStringTranslate("𐐅", "𐐅", "x", "UTF8_BINARY", "x"); + assertStringTranslate("𐐅", "𐐅", "x", "UTF8_LCASE", "x"); + assertStringTranslate("𐐅", "𐐅", "x", "UNICODE", "x"); + assertStringTranslate("𐐅", "𐐅", "x", "UNICODE_CI", "x"); + assertStringTranslate("𐐅", "𐐭", "x", "UTF8_BINARY", "𐐅"); + assertStringTranslate("𐐅", "𐐭", "x", "UTF8_LCASE", "x"); + assertStringTranslate("𐐅", "𐐭", "x", "UNICODE", "𐐅"); + assertStringTranslate("𐐅", "𐐭", "x", "UNICODE_CI", "x"); + assertStringTranslate("A", "A", "𐐅", "UTF8_BINARY", "𐐅"); + assertStringTranslate("A", "A", "𐐅", "UTF8_LCASE", "𐐅"); + assertStringTranslate("A", "A", "𐐅", "UNICODE", "𐐅"); + assertStringTranslate("A", "A", "𐐅", "UNICODE_CI", "𐐅"); + assertStringTranslate("A", "a", "𐐅", "UTF8_BINARY", "A"); + assertStringTranslate("A", "a", "𐐅", "UTF8_LCASE", "𐐅"); + assertStringTranslate("A", "a", "𐐅", "UNICODE", "A"); + assertStringTranslate("A", "a", "𐐅", "UNICODE_CI", "𐐅"); + assertStringTranslate("a", "A", "𐐅", "UTF8_BINARY", "a"); + assertStringTranslate("a", "A", "𐐅", "UTF8_LCASE", "𐐅"); + assertStringTranslate("a", "A", "𐐅", "UNICODE", "a"); + assertStringTranslate("a", "A", "𐐅", "UNICODE_CI", "𐐅"); + assertStringTranslate("𝔸", "𝔸", "x", "UTF8_BINARY", "x"); + assertStringTranslate("𝔸", "𝔸", "x", "UTF8_LCASE", "x"); + assertStringTranslate("𝔸", "𝔸", "x", "UNICODE", "x"); + assertStringTranslate("𝔸", "𝔸", "x", "UNICODE_CI", "x"); + assertStringTranslate("𝔸", "𝕒", "x", "UTF8_BINARY", "𝔸"); + assertStringTranslate("𝔸", "𝕒", "x", "UTF8_LCASE", "𝔸"); + assertStringTranslate("𝔸", "𝕒", "x", "UNICODE", "𝔸"); + assertStringTranslate("𝔸", "𝕒", "x", "UNICODE_CI", "x"); } private Map buildDict(String matching, String replace) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 2428d40fe8016..c4a66fdffdd4d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -26,6 +26,8 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UTF8StringBuilder; + import org.junit.jupiter.api.Test; import static org.apache.spark.unsafe.types.UTF8String.fromString; @@ -1362,4 +1364,27 @@ public void toBinaryString() { UTF8String.fromString("111111111111111111111111111111111111111111111111111111111111111"), UTF8String.toBinaryString(Long.MAX_VALUE)); } + + /** + * This tests whether appending a codepoint to a 'UTF8StringBuilder' correctly appends every + * single codepoint. We test it against an already existing 'StringBuilder.appendCodePoint' and + * 'UTF8String.fromString'. 
We skip testing the surrogate code points because converting a lone surrogate code point + * to 'UTF8String' (via 'StringBuilder' and 'UTF8String') produces an ill-formed byte sequence: + * 'String' is UTF-16 based, and an unpaired surrogate has no valid UTF-8 encoding, so the + * conversion does not round-trip. + */ + @Test + public void testAppendCodepointToUTF8StringBuilder() { + int surrogateRangeLowerBound = 0xD800; + int surrogateRangeUpperBound = 0xDFFF; + for (int i = Character.MIN_CODE_POINT; i <= Character.MAX_CODE_POINT; ++i) { + if (surrogateRangeLowerBound <= i && i <= surrogateRangeUpperBound) continue; + UTF8StringBuilder usb = new UTF8StringBuilder(); + usb.appendCodePoint(i); + StringBuilder sb = new StringBuilder(); + sb.appendCodePoint(i); + assert(usb.build().equals(UTF8String.fromString(sb.toString()))); + } + } } diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 23dae47f6ff2c..1f64547da7415 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -32,28 +32,35 @@ import org.apache.spark.sql.catalyst.util.CollationFactory._ import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8} class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ignore funsuite + + val currentIcuVersion: String = "76.1" + test("collationId stability") { assert(INDETERMINATE_COLLATION_ID == -1) assert(UTF8_BINARY_COLLATION_ID == 0) val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") - assert(utf8Binary.supportsBinaryEquality) + assert(utf8Binary.isUtf8BinaryType) + assert(utf8Binary.version == currentIcuVersion) assert(UTF8_LCASE_COLLATION_ID == 1) - val utf8BinaryLcase = fetchCollation(UTF8_LCASE_COLLATION_ID) - assert(utf8BinaryLcase.collationName == "UTF8_LCASE") - assert(!utf8BinaryLcase.supportsBinaryEquality) + val utf8Lcase = fetchCollation(UTF8_LCASE_COLLATION_ID) + assert(utf8Lcase.collationName == "UTF8_LCASE") + assert(!utf8Lcase.isUtf8BinaryType) + assert(utf8Lcase.version == currentIcuVersion) assert(UNICODE_COLLATION_ID == (1 << 29)) val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(!unicode.supportsBinaryEquality) + assert(!unicode.isUtf8BinaryType) + assert(unicode.version == currentIcuVersion) assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") - assert(!unicodeCi.supportsBinaryEquality) + assert(!unicodeCi.isUtf8BinaryType) + assert(unicodeCi.version == currentIcuVersion) } test("UTF8_BINARY and ICU root locale collation names") { @@ -93,27 +100,33 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig Seq( ("UTF8_BINARY_CS", "UTF8_BINARY"), ("UTF8_BINARY_AS", "UTF8_BINARY"), // this should be UNICODE_AS - ("UTF8_BINARY_CS_AS","UTF8_BINARY"), // this should be UNICODE_CS_AS - ("UTF8_BINARY_AS_CS","UTF8_BINARY"), - ("UTF8_BINARY_CI","UTF8_BINARY"), - ("UTF8_BINARY_AI","UTF8_BINARY"), - ("UTF8_BINARY_CI_AI","UTF8_BINARY"), - ("UTF8_BINARY_AI_CI","UTF8_BINARY"), - ("UTF8_BS","UTF8_LCASE"), - ("BINARY_UTF8","ar_SAU"), -
("UTF8_BINARY_A","UTF8_BINARY"), - ("UNICODE_X","UNICODE"), - ("UNICODE_CI_X","UNICODE"), - ("UNICODE_LCASE_X","UNICODE"), - ("UTF8_UNICODE","UTF8_LCASE"), - ("UTF8_BINARY_UNICODE","UTF8_BINARY"), + ("UTF8_BINARY_CS_AS", "UTF8_BINARY"), // this should be UNICODE_CS_AS + ("UTF8_BINARY_AS_CS", "UTF8_BINARY"), + ("UTF8_BINARY_CI", "UTF8_BINARY"), + ("UTF8_BINARY_AI", "UTF8_BINARY"), + ("UTF8_BINARY_CI_AI", "UTF8_BINARY"), + ("UTF8_BINARY_AI_CI", "UTF8_BINARY"), + ("UTF8_BINARY_AI_RTRIM", "UTF8_BINARY_RTRIM"), + ("UTF8_BINARY_CI_RTRIM", "UTF8_BINARY_RTRIM"), + ("UTF8_BINARY_AI_CI_RTRIM", "UTF8_BINARY_RTRIM"), + ("UTF8_BS", "UTF8_LCASE"), + ("BINARY_UTF8", "ar_SAU"), + ("UTF8_BINARY_A", "UTF8_BINARY"), + ("UNICODE_X", "UNICODE"), + ("UNICODE_CI_X", "UNICODE"), + ("UNICODE_LCASE_X", "UNICODE"), + ("UNICODE_RTRIM_LCASE_X", "UNICODE"), + ("UTF8_UNICODE", "UTF8_LCASE"), + ("UTF8_BINARY_UNICODE", "UTF8_BINARY"), ("CI_UNICODE", "UNICODE"), ("LCASE_UNICODE", "UNICODE"), + ("RTRIM_UNICODE", "UNICODE"), ("UNICODE_UNSPECIFIED", "UNICODE"), ("UNICODE_CI_UNSPECIFIED", "UNICODE"), ("UNICODE_UNSPECIFIED_CI_UNSPECIFIED", "UNICODE"), ("UNICODE_INDETERMINATE", "UNICODE"), - ("UNICODE_CI_INDETERMINATE", "UNICODE") + ("UNICODE_CI_INDETERMINATE", "UNICODE"), + ("UNICODE_RTRIM_INDETERMINATE", "UNICODE") ).foreach{case (collationName, proposals) => checkCollationNameError(collationName, proposals) } @@ -127,6 +140,11 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -134,15 +152,36 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), - CollationTestCase("UNICODE_CI", "Å", "a\u030A", true) + CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), + 
CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), + CollationTestCase("SR_CI", "cČć", "CčĆ", true), + CollationTestCase("SR_CI", "cCc", "CčĆ", false), + CollationTestCase("SR_CI_AI", "cCc", "CčĆ", true), + CollationTestCase("sr_Cyrl_CI", "цЧћ", "ЦчЋ", true), + CollationTestCase("sr_Cyrl_CI", "цЦц", "ЦчЋ", false), + CollationTestCase("sr_Cyrl_CI_AI", "цЦц", "ЦчЋ", false) ) checks.foreach(testCase => { @@ -162,19 +201,50 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1), CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), + CollationTestCase("UNICODE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), - CollationTestCase("UNICODE_CI", "aaa", "bbb", -1)) + CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0), + CollationTestCase("SR_CI_AI", "cČć", "ČćC", 0), + CollationTestCase("SR_CI", "cČć", "ČćC", -1) + ) checks.foreach(testCase => { val collation = fetchCollation(testCase.collationName) @@ -192,7 +262,10 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig 
CollationTestCase("UNICODE_CI", "abcde", "abcde", 5), CollationTestCase("UNICODE_CI", "abcde", "ABCDE", 5), CollationTestCase("UNICODE_CI", "abcde", "fgh", 0), - CollationTestCase("UNICODE_CI", "abcde", "FGH", 0) + CollationTestCase("UNICODE_CI", "abcde", "FGH", 0), + CollationTestCase("SR_CI_AI", "abcčċ", "CCC", 3), + CollationTestCase("SR_CI", "abcčċ", "C", 1), + CollationTestCase("SR", "abcčċ", "CCC", 0) ) checks.foreach(testCase => { @@ -229,7 +302,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig "UNICODE_CI", "UNICODE_AI", "UNICODE_CI_AI", - "UNICODE_AI_CI" + "UNICODE_AI_CI", + "DE_CI_AI", + "MT_CI" ).foreach(collationId => { val col1 = fetchCollation(collationId) val col2 = fetchCollation(collationId) @@ -303,15 +378,23 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig ("CI_en", "ceb"), ("USA_CI_en", "UNICODE"), ("en_CI_USA", "en_USA"), + ("en_RTRIM_USA", "en_USA"), ("CI_sr_Cyrl_SRB", "sr_Cyrl_SRB"), + ("RTRIM_sr_Cyrl_SRB", "sr_Cyrl_SRB"), ("sr_CI_Cyrl_SRB", "sr_Cyrl_SRB"), + ("sr_RTRIM_Cyrl_SRB", "sr_Cyrl_SRB"), ("sr_Cyrl_CI_SRB", "sr_Cyrl_SRB"), + ("sr_Cyrl_RTRIM_SRB", "sr_Cyrl_SRB"), ("CI_Cyrl_sr", "sr_Cyrl_SRB"), + ("RTRIM_Cyrl_sr", "sr_Cyrl_SRB"), ("Cyrl_CI_sr", "he_ISR"), ("Cyrl_CI_sr_SRB", "sr_Cyrl_SRB"), + ("Cyrl_RTRIM_sr_SRB", "sr_Cyrl_SRB"), ("Cyrl_sr_CI_SRB", "sr_Cyrl_SRB"), + ("Cyrl_sr_RTRIM_SRB", "sr_Cyrl_SRB"), // no locale specified ("_CI_AI", "af_CI_AI, am_CI_AI, ar_CI_AI"), + ("_CI_AI_RTRIM", "af_CI_AI_RTRIM, am_CI_AI_RTRIM, ar_CI_AI_RTRIM"), ("", "af, am, ar") ).foreach { case (collationName, proposals) => checkCollationNameError(collationName, proposals) @@ -369,9 +452,9 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig 1 << 15, // UTF8_BINARY mandatory zero bit 15 breach. 1 << 16, // UTF8_BINARY mandatory zero bit 16 breach. 1 << 17, // UTF8_BINARY mandatory zero bit 17 breach. - 1 << 18, // UTF8_BINARY mandatory zero bit 18 breach. 1 << 19, // UTF8_BINARY mandatory zero bit 19 breach. 1 << 20, // UTF8_BINARY mandatory zero bit 20 breach. + 1 << 21, // UTF8_BINARY mandatory zero bit 21 breach. 1 << 23, // UTF8_BINARY mandatory zero bit 23 breach. 1 << 24, // UTF8_BINARY mandatory zero bit 24 breach. 1 << 25, // UTF8_BINARY mandatory zero bit 25 breach. @@ -382,7 +465,6 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach. (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach. (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach. - (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach. (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach. (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach. (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach. 
@@ -408,6 +490,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig ("UNICODE_CI_CI", "UNICODE_CI"), ("UNICODE_CI_CS", "UNICODE_CS"), ("UNICODE_CS_CI", "UNICODE_CS"), + ("UNICODE_RTRIM_RTRIM", "UNICODE_RTRIM"), ("UNICODE_AS_AS", "UNICODE_AS"), ("UNICODE_AI_AI", "UNICODE_AI"), ("UNICODE_AS_AI", "UNICODE_AS"), @@ -417,6 +500,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig ("UNICODE_CS_AS_CI_AI", "UNICODE_CS_AS"), ("UNICODE__CS__AS", "UNICODE_AS"), ("UNICODE-CS-AS", "UNICODE"), + ("UNICODE__CS__RTRIM", "UNICODE_RTRIM"), ("UNICODECSAS", "UNICODE"), ("_CS_AS_UNICODE", "UNICODE") ).foreach { case (collationName, proposals) => @@ -457,7 +541,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig val e = intercept[SparkException] { fetchCollation(collationName) } - assert(e.getErrorClass === "COLLATION_INVALID_NAME") + assert(e.getCondition === "COLLATION_INVALID_NAME") assert(e.getMessageParameters.asScala === Map( "collationName" -> collationName, "proposals" -> proposals)) } diff --git a/common/utils/src/main/java/org/apache/spark/SparkThrowable.java b/common/utils/src/main/java/org/apache/spark/SparkThrowable.java index e1235b2982ba0..39808f58b08ae 100644 --- a/common/utils/src/main/java/org/apache/spark/SparkThrowable.java +++ b/common/utils/src/main/java/org/apache/spark/SparkThrowable.java @@ -35,19 +35,29 @@ */ @Evolving public interface SparkThrowable { - // Succinct, human-readable, unique, and consistent representation of the error category - // If null, error class is not set - String getErrorClass(); + /** + * Succinct, human-readable, unique, and consistent representation of the error condition. + * If null, error condition is not set. + */ + String getCondition(); + + /** + * Succinct, human-readable, unique, and consistent representation of the error category. + * If null, error class is not set. + * @deprecated Use {@link #getCondition()} instead. + */ + @Deprecated + default String getErrorClass() { return getCondition(); } // Portable error identifier across SQL engines // If null, error class or SQLSTATE is not set default String getSqlState() { - return SparkThrowableHelper.getSqlState(this.getErrorClass()); + return SparkThrowableHelper.getSqlState(this.getCondition()); } // True if this error is an internal error. default boolean isInternalError() { - return SparkThrowableHelper.isInternalError(this.getErrorClass()); + return SparkThrowableHelper.isInternalError(this.getCondition()); } default Map getMessageParameters() { diff --git a/common/utils/src/main/scala/org/apache/spark/unsafe/array/ByteArrayUtils.java b/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java similarity index 100% rename from common/utils/src/main/scala/org/apache/spark/unsafe/array/ByteArrayUtils.java rename to common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index d8edc89ba83ea..77437f6c56179 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1,4 +1,10 @@ { + "ADD_DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." 
+ ], + "sqlState" : "42623" + }, "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION" : { "message" : [ "Non-deterministic expression should not appear in the arguments of an aggregate function." @@ -121,10 +127,16 @@ }, "BINARY_ARITHMETIC_OVERFLOW" : { "message" : [ - " caused overflow." + " caused overflow. Use to ignore overflow problem and return NULL." ], "sqlState" : "22003" }, + "BOOLEAN_STATEMENT_WITH_EMPTY_ROW" : { + "message" : [ + "Boolean statement is invalid. Expected single row with a value of the BOOLEAN type, but got an empty row." + ], + "sqlState" : "21000" + }, "CALL_ON_STREAMING_DATASET_UNSUPPORTED" : { "message" : [ "The method can not be called on streaming Dataset/DataFrame." @@ -256,6 +268,26 @@ "Error reading streaming state file of does not exist. If the stream job is restarted with a new or updated state operation, please create a new checkpoint location or clear the existing checkpoint location." ] }, + "HDFS_STORE_PROVIDER_OUT_OF_MEMORY" : { + "message" : [ + "Could not load HDFS state store with id because of an out of memory exception." + ] + }, + "INVALID_CHANGE_LOG_READER_VERSION" : { + "message" : [ + "The change log reader version cannot be ." + ] + }, + "INVALID_CHANGE_LOG_WRITER_VERSION" : { + "message" : [ + "The change log writer version cannot be ." + ] + }, + "ROCKSDB_STORE_PROVIDER_OUT_OF_MEMORY" : { + "message" : [ + "Could not load RocksDB state store with id because of an out of memory exception." + ] + }, "SNAPSHOT_PARTITION_ID_NOT_FOUND" : { "message" : [ "Partition id not found for state of operator at ." @@ -344,6 +376,12 @@ ], "sqlState" : "429BB" }, + "CANNOT_REMOVE_RESERVED_PROPERTY" : { + "message" : [ + "Cannot remove reserved property: ." + ], + "sqlState" : "42000" + }, "CANNOT_RENAME_ACROSS_SCHEMA" : { "message" : [ "Renaming a across schemas is not allowed." @@ -368,12 +406,6 @@ ], "sqlState" : "58030" }, - "CANNOT_SAVE_VARIANT" : { - "message" : [ - "Cannot save variant data type into external storage." - ], - "sqlState" : "0A000" - }, "CANNOT_UPDATE_FIELD" : { "message" : [ "Cannot update field type:" @@ -414,6 +446,12 @@ ], "sqlState" : "42846" }, + "CANNOT_USE_KRYO" : { + "message" : [ + "Cannot load Kryo serialization codec. Kryo serialization cannot be used in the Spark Connect client. Use Java serialization, provide a custom Codec, or use Spark Classic instead." + ], + "sqlState" : "22KD3" + }, "CANNOT_WRITE_STATE_STORE" : { "message" : [ "Error writing state store files for provider ." @@ -429,13 +467,13 @@ }, "CAST_INVALID_INPUT" : { "message" : [ - "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set to \"false\" to bypass this error." + "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead." ], "sqlState" : "22018" }, "CAST_OVERFLOW" : { "message" : [ - "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set to \"false\" to bypass this error." + "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead." 
], "sqlState" : "22003" }, @@ -459,6 +497,12 @@ ], "sqlState" : "56000" }, + "CIRCULAR_CLASS_REFERENCE" : { + "message" : [ + "Cannot have circular references in class, but got the circular reference of class ." + ], + "sqlState" : "42602" + }, "CLASS_NOT_OVERRIDE_EXPECTED_METHOD" : { "message" : [ " must override either or ." @@ -471,6 +515,20 @@ ], "sqlState" : "0A000" }, + "CLUSTERING_COLUMNS_MISMATCH" : { + "message" : [ + "Specified clustering does not match that of the existing table .", + "Specified clustering columns: [].", + "Existing clustering columns: []." + ], + "sqlState" : "42P10" + }, + "CLUSTERING_NOT_SUPPORTED" : { + "message" : [ + "'' does not support clustering." + ], + "sqlState" : "42000" + }, "CODEC_NOT_AVAILABLE" : { "message" : [ "The codec is not available." @@ -519,7 +577,7 @@ }, "IMPLICIT" : { "message" : [ - "Error occurred due to the mismatch between multiple implicit non-default collations. Use COLLATE function to set the collation explicitly." + "Error occurred due to the mismatch between implicit collations: []. Use COLLATE function to set the collation explicitly." ] } }, @@ -560,6 +618,12 @@ ], "sqlState" : "42711" }, + "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH" : { + "message" : [ + "Some values in field are incompatible with the column array type. Expected type ." + ], + "sqlState" : "0A000" + }, "COLUMN_NOT_DEFINED_IN_TABLE" : { "message" : [ " column is not defined in table , defined table columns are: ." @@ -572,6 +636,13 @@ ], "sqlState" : "42703" }, + "COLUMN_ORDINAL_OUT_OF_BOUNDS" : { + "message" : [ + "Column ordinal out of bounds. The number of columns in the table is , but the column ordinal is .", + "Attributes are the following: ." + ], + "sqlState" : "22003" + }, "COMPARATOR_RETURNS_NULL" : { "message" : [ "The comparator has returned a NULL for a comparison between and .", @@ -585,6 +656,11 @@ "Cannot process input data types for the expression: ." ], "subClass" : { + "BAD_INPUTS" : { + "message" : [ + "The input data types to must be valid, but found the input types ." + ] + }, "MISMATCHED_TYPES" : { "message" : [ "All input types must be the same except nullable, containsNull, valueContainsNull flags, but found the input types ." @@ -611,6 +687,27 @@ ], "sqlState" : "40000" }, + "CONFLICTING_DIRECTORY_STRUCTURES" : { + "message" : [ + "Conflicting directory structures detected.", + "Suspicious paths:", + "", + "If provided paths are partition directories, please set \"basePath\" in the options of the data source to specify the root directory of the table.", + "If there are multiple root directories, please load them separately and then union them." + ], + "sqlState" : "KD009" + }, + "CONFLICTING_PARTITION_COLUMN_NAMES" : { + "message" : [ + "Conflicting partition column names detected:", + "", + "For partitioned table directories, data files should only live in leaf directories.", + "And directories at the same level should have the same partition column name.", + "Please check the following directories for unexpected files or inconsistent partition column names:", + "" + ], + "sqlState" : "KD009" + }, "CONNECT" : { "message" : [ "Generic Spark Connect error." @@ -853,7 +950,7 @@ }, "NON_STRING_TYPE" : { "message" : [ - "all arguments must be strings." + "all arguments of the function must be strings." ] }, "NULL_TYPE" : { @@ -868,7 +965,7 @@ }, "RANGE_FRAME_INVALID_TYPE" : { "message" : [ - "The data type used in the order specification does not match the data type which is used in the range frame." 
+ "The data type used in the order specification does not support the data type which is used in the range frame." ] }, "RANGE_FRAME_MULTI_ORDER" : { @@ -954,16 +1051,6 @@ "The input of can't be type data." ] }, - "UNSUPPORTED_UDF_INPUT_TYPE" : { - "message" : [ - "UDFs do not support '' as an input data type." - ] - }, - "UNSUPPORTED_UDF_OUTPUT_TYPE" : { - "message" : [ - "UDFs do not support '' as an output data type." - ] - }, "VALUE_OUT_OF_RANGE" : { "message" : [ "The must be between (current value = )." @@ -994,6 +1081,12 @@ ], "sqlState" : "42710" }, + "DATA_SOURCE_EXTERNAL_ERROR" : { + "message" : [ + "Encountered error when saving to external data source." + ], + "sqlState" : "KD010" + }, "DATA_SOURCE_NOT_EXIST" : { "message" : [ "Data source '' not found. Please make sure the data source is registered." @@ -1014,6 +1107,12 @@ ], "sqlState" : "42K03" }, + "DATETIME_FIELD_OUT_OF_BOUNDS" : { + "message" : [ + ". If necessary set to \"false\" to bypass this error." + ], + "sqlState" : "22023" + }, "DATETIME_OVERFLOW" : { "message" : [ "Datetime operation overflow: ." @@ -1039,6 +1138,12 @@ ], "sqlState" : "42608" }, + "DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : { "message" : [ "Distinct window functions are not supported: ." @@ -1051,6 +1156,12 @@ ], "sqlState" : "22012" }, + "DUPLICATED_CTE_NAMES" : { + "message" : [ + "CTE definition can't have duplicate names: ." + ], + "sqlState" : "42602" + }, "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" : { "message" : [ "Duplicated field names in Arrow Struct are not allowed, got ." @@ -1121,6 +1232,12 @@ ], "sqlState" : "42604" }, + "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE" : { + "message" : [ + "The datasource does not support writing empty or nested empty schemas. Please make sure the data schema has at least one or more column(s)." + ], + "sqlState" : "0A000" + }, "ENCODER_NOT_FOUND" : { "message" : [ "Not found an encoder of the type to Spark SQL internal representation.", @@ -1387,6 +1504,12 @@ ], "sqlState" : "2203G" }, + "FAILED_TO_LOAD_ROUTINE" : { + "message" : [ + "Failed to load routine ." + ], + "sqlState" : "38000" + }, "FAILED_TO_PARSE_TOO_COMPLEX" : { "message" : [ "The statement, including potential SQL functions and referenced views, was too complex to parse.", @@ -1412,6 +1535,12 @@ ], "sqlState" : "42704" }, + "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in flatMapGroupsWithState. Reason: " + ], + "sqlState" : "39000" + }, "FORBIDDEN_OPERATION" : { "message" : [ "The operation is not allowed on the : ." @@ -1424,6 +1553,12 @@ ], "sqlState" : "39000" }, + "FOREACH_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in foreach sink. Reason: " + ], + "sqlState" : "39000" + }, "FOUND_MULTIPLE_DATA_SOURCES" : { "message" : [ "Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath." @@ -1520,6 +1655,36 @@ ], "sqlState" : "42601" }, + "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION" : { + "message" : [ + "Duplicated IDENTITY column sequence generator option: ." + ], + "sqlState" : "42601" + }, + "IDENTITY_COLUMNS_ILLEGAL_STEP" : { + "message" : [ + "IDENTITY column step cannot be 0." 
+ ], + "sqlState" : "42611" + }, + "IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE" : { + "message" : [ + "DataType is not supported for IDENTITY columns." + ], + "sqlState" : "428H2" + }, + "IDENTITY_COLUMN_WITH_DEFAULT_VALUE" : { + "message" : [ + "A column cannot have both a default value and an identity column specification but column has default value: () and identity column specification: ()." + ], + "sqlState" : "42623" + }, + "ILLEGAL_DAY_OF_WEEK" : { + "message" : [ + "Illegal input for day of week: ." + ], + "sqlState" : "22009" + }, "ILLEGAL_STATE_STORE_VALUE" : { "message" : [ "Illegal value provided to the State Store" @@ -1859,8 +2024,20 @@ }, "INTERVAL_ARITHMETIC_OVERFLOW" : { "message" : [ - "." + "Integer overflow while operating with intervals." ], + "subClass" : { + "WITHOUT_SUGGESTION" : { + "message" : [ + "Try devising appropriate values for the interval parameters." + ] + }, + "WITH_SUGGESTION" : { + "message" : [ + "Use to tolerate overflow and return NULL instead." + ] + } + }, "sqlState" : "22015" }, "INTERVAL_DIVIDED_BY_ZERO" : { @@ -1897,6 +2074,12 @@ }, "sqlState" : "42903" }, + "INVALID_AGNOSTIC_ENCODER" : { + "message" : [ + "Found an invalid agnostic encoder. Expects an instance of AgnosticEncoder but got . For more information consult '/api/java/index.html?org/apache/spark/sql/Encoder.html'." + ], + "sqlState" : "42001" + }, "INVALID_ARRAY_INDEX" : { "message" : [ "The index is out of bounds. The array has elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. If necessary set to \"false\" to bypass this error." @@ -1909,12 +2092,24 @@ ], "sqlState" : "22003" }, + "INVALID_ATTRIBUTE_NAME_SYNTAX" : { + "message" : [ + "Syntax error in the attribute name: . Check that backticks appear in pairs, a quoted string is a complete name part and use a backtick only inside quoted name parts." + ], + "sqlState" : "42601" + }, "INVALID_BITMAP_POSITION" : { "message" : [ "The 0-indexed bitmap position is out of bounds. The bitmap has bits ( bytes)." ], "sqlState" : "22003" }, + "INVALID_BOOLEAN_STATEMENT" : { + "message" : [ + "Boolean statement is expected in the condition, but was found." + ], + "sqlState" : "22546" + }, "INVALID_BOUNDARY" : { "message" : [ "The boundary is invalid: ." @@ -1981,6 +2176,12 @@ }, "sqlState" : "22022" }, + "INVALID_CORRUPT_RECORD_TYPE" : { + "message" : [ + "The column for corrupt records must have the nullable STRING type, but got ." + ], + "sqlState" : "42804" + }, "INVALID_CURSOR" : { "message" : [ "The cursor is invalid." @@ -2023,6 +2224,11 @@ "message" : [ "Too many letters in datetime pattern: . Please reduce pattern length." ] + }, + "SECONDS_FRACTION" : { + "message" : [ + "Cannot detect a seconds fraction pattern of variable length. Please make sure the pattern contains 'S', and does not contain illegal characters." + ] } }, "sqlState" : "22007" @@ -2126,6 +2332,12 @@ ], "sqlState" : "42001" }, + "INVALID_EXTERNAL_TYPE" : { + "message" : [ + "The external type is not valid for the type at the expression ." + ], + "sqlState" : "42K0N" + }, "INVALID_EXTRACT_BASE_FIELD_TYPE" : { "message" : [ "Can't extract a value from . Need a complex type [STRUCT, ARRAY, MAP] but got ." @@ -2215,7 +2427,8 @@ }, "INVALID_FRACTION_OF_SECOND" : { "message" : [ - "The fraction of sec must be zero. Valid range is [0, 60]. If necessary set to \"false\" to bypass this error." + "Valid range for seconds is [0, 60] (inclusive), but the provided value is . 
To avoid this error, use `try_make_timestamp`, which returns NULL on error.", + "If you do not want to use the session default timestamp version of this function, use `try_make_timestamp_ntz` or `try_make_timestamp_ltz`." ], "sqlState" : "22023" }, @@ -2315,6 +2528,11 @@ "Uncaught arithmetic exception while parsing ''." ] }, + "DAY_TIME_PARSING" : { + "message" : [ + "Error parsing interval day-time string: ." + ] + }, "INPUT_IS_EMPTY" : { "message" : [ "Interval string cannot be empty." @@ -2325,6 +2543,11 @@ "Interval string cannot be null." ] }, + "INTERVAL_PARSING" : { + "message" : [ + "Error parsing interval string." + ] + }, "INVALID_FRACTION" : { "message" : [ " cannot have fractional part." @@ -2360,19 +2583,50 @@ "Expect a unit name after but hit EOL." ] }, + "SECOND_NANO_FORMAT" : { + "message" : [ + "Interval string does not match second-nano format of ss.nnnnnnnnn." + ] + }, + "TIMEZONE_INTERVAL_OUT_OF_RANGE" : { + "message" : [ + "The interval value must be in the range of [-18, +18] hours with second precision." + ] + }, "UNKNOWN_PARSING_ERROR" : { "message" : [ "Unknown error when parsing ." ] }, + "UNMATCHED_FORMAT_STRING" : { + "message" : [ + "Interval string does not match format of when cast to : ." + ] + }, + "UNMATCHED_FORMAT_STRING_WITH_NOTICE" : { + "message" : [ + "Interval string does not match format of when cast to : . Set \"spark.sql.legacy.fromDayTimeString.enabled\" to \"true\" to restore the behavior before Spark 3.0." + ] + }, "UNRECOGNIZED_NUMBER" : { "message" : [ "Unrecognized number ." ] + }, + "UNSUPPORTED_FROM_TO_EXPRESSION" : { + "message" : [ + "Cannot support (interval '' to ) expression." + ] } }, "sqlState" : "22006" }, + "INVALID_INTERVAL_WITH_MICROSECONDS_ADDITION" : { + "message" : [ + "Cannot add an interval to a date because its microseconds part is not 0. If necessary set to \"false\" to bypass this error." + ], + "sqlState" : "22006" + }, "INVALID_INVERSE_DISTRIBUTION_FUNCTION" : { "message" : [ "Invalid inverse distribution function ." @@ -2396,6 +2650,13 @@ }, "sqlState" : "42K0K" }, + "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME" : { + "message" : [ + " is not a valid identifier of Java and cannot be used as field name", + "." + ], + "sqlState" : "46121" + }, "INVALID_JOIN_TYPE_FOR_JOINWITH" : { "message" : [ "Invalid join type in joinWith: ." @@ -2414,6 +2675,12 @@ ], "sqlState" : "2203G" }, + "INVALID_JSON_RECORD_TYPE" : { + "message" : [ + "Detected an invalid type of a JSON record while inferring a common schema in the mode . Expected a STRUCT type, but found ." + ], + "sqlState" : "22023" + }, "INVALID_JSON_ROOT_FIELD" : { "message" : [ "Cannot convert JSON root field to target Spark type." @@ -2432,6 +2699,24 @@ ], "sqlState" : "F0000" }, + "INVALID_LABEL_USAGE" : { + "message" : [ + "The usage of the label is invalid." + ], + "subClass" : { + "DOES_NOT_EXIST" : { + "message" : [ + "Label was used in the statement, but the label does not belong to any surrounding block." + ] + }, + "ITERATE_IN_COMPOUND" : { + "message" : [ + "ITERATE statement cannot be used with a label that belongs to a compound (BEGIN...END) body." + ] + } + }, + "sqlState" : "42K0L" + }, "INVALID_LAMBDA_FUNCTION_CALL" : { "message" : [ "Invalid lambda function call." @@ -2598,6 +2883,11 @@ "expects an integer value in [0, ), but got ." ] }, + "BOOLEAN" : { + "message" : [ + "expects a boolean literal, but got ." + ] + }, "CHARSET" : { "message" : [ "expects one of the , but got ." 
@@ -2608,11 +2898,31 @@ "expects one of the units without quotes YEAR, QUARTER, MONTH, WEEK, DAY, DAYOFYEAR, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, but got the string literal ." ] }, + "DOUBLE" : { + "message" : [ + "expects an double literal, but got ." + ] + }, + "DTYPE" : { + "message" : [ + "Unsupported dtype: . Valid values: float64, float32." + ] + }, + "INTEGER" : { + "message" : [ + "expects an integer literal, but got ." + ] + }, "LENGTH" : { "message" : [ "Expects `length` greater than or equal to 0, but got ." ] }, + "LONG" : { + "message" : [ + "expects a long literal, but got ." + ] + }, "NULL" : { "message" : [ "expects a non-NULL value." @@ -2633,6 +2943,11 @@ "Expects a positive or a negative value for `start`, but got 0." ] }, + "STRING" : { + "message" : [ + "expects a string literal, but got ." + ] + }, "ZERO_INDEX" : { "message" : [ "expects %1$, %2$ and so on, but got %0$." @@ -2665,6 +2980,12 @@ }, "sqlState" : "42601" }, + "INVALID_PARTITION_VALUE" : { + "message" : [ + "Failed to cast value to data type for partition column . Ensure the value matches the expected data type for this partition column." + ], + "sqlState" : "42846" + }, "INVALID_PROPERTY_KEY" : { "message" : [ " is an invalid property key, please use quotes, e.g. SET =." @@ -2683,6 +3004,18 @@ ], "sqlState" : "42613" }, + "INVALID_REGEXP_REPLACE" : { + "message" : [ + "Could not perform regexp_replace for source = \"\", pattern = \"\", replacement = \"\" and position = ." + ], + "sqlState" : "22023" + }, + "INVALID_RESET_COMMAND_FORMAT" : { + "message" : [ + "Expected format is 'RESET' or 'RESET key'. If you want to include special characters in key, please use quotes, e.g., RESET `key`." + ], + "sqlState" : "42000" + }, "INVALID_SAVE_MODE" : { "message" : [ "The specified save mode is invalid. Valid save modes include \"append\", \"overwrite\", \"ignore\", \"error\", \"errorifexists\", and \"default\"." @@ -2799,7 +3132,7 @@ }, "MULTI_PART_NAME" : { "message" : [ - " with multiple part function name() is not allowed." + " with multiple part name() is not allowed." ] }, "OPTION_IS_INVALID" : { @@ -2837,6 +3170,11 @@ "Unsupported function name ." ] }, + "UNSUPPORTED_SQL_STATEMENT" : { + "message" : [ + "Unsupported SQL statement: ." + ] + }, "VARIABLE_TYPE_OR_DEFAULT_REQUIRED" : { "message" : [ "The definition of a SQL variable requires either a datatype or a DEFAULT clause.", @@ -2877,6 +3215,12 @@ ], "sqlState" : "42K0F" }, + "INVALID_TIMEZONE" : { + "message" : [ + "The timezone: is invalid. The timezone must be either a region-based zone ID or a zone offset. Region IDs must have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH', '(+|-)HH:mm’ or '(+|-)HH:mm:ss', e.g '-08' , '+01:00' or '-13:33:33', and must be in the range from -18:00 to +18:00. 'Z' and 'UTC' are accepted as synonyms for '+00:00'." + ], + "sqlState" : "22009" + }, "INVALID_TIME_TRAVEL_SPEC" : { "message" : [ "Cannot specify both version and timestamp when time travelling the table." @@ -2925,7 +3269,7 @@ }, "INVALID_URL" : { "message" : [ - "The url is invalid: . If necessary set to \"false\" to bypass this error." + "The url is invalid: . Use `try_parse_url` to tolerate invalid URL and return NULL instead." ], "sqlState" : "22P02" }, @@ -2941,6 +3285,24 @@ ], "sqlState" : "22029" }, + "INVALID_VARIABLE_DECLARATION" : { + "message" : [ + "Invalid variable declaration." 
+ ], + "subClass" : { + "NOT_ALLOWED_IN_SCOPE" : { + "message" : [ + "Declaration of the variable is not allowed in this scope." + ] + }, + "ONLY_AT_BEGINNING" : { + "message" : [ + "Variable can only be declared at the beginning of the compound." + ] + } + }, + "sqlState" : "42K0M" + }, "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : { "message" : [ "Variable type must be string type but got ." @@ -2983,6 +3345,12 @@ ], "sqlState" : "22023" }, + "INVALID_VARIANT_SHREDDING_SCHEMA" : { + "message" : [ + "The schema `` is not a valid variant shredding schema." + ], + "sqlState" : "22023" + }, "INVALID_WHERE_CONDITION" : { "message" : [ "The WHERE condition contains invalid expressions: .", @@ -3043,6 +3411,12 @@ ], "sqlState" : "42K0L" }, + "LABEL_ALREADY_EXISTS" : { + "message" : [ + "The label ` not found in schema ." ] }, - "_LEGACY_ERROR_TEMP_1156" : { - "message" : [ - "Column not found in schema ." - ] - }, "_LEGACY_ERROR_TEMP_1158" : { "message" : [ "Saving data into a view is not allowed." @@ -5785,21 +6476,6 @@ "createTableColumnTypes option column not found in schema ." ] }, - "_LEGACY_ERROR_TEMP_1172" : { - "message" : [ - "Parquet type not yet supported: ." - ] - }, - "_LEGACY_ERROR_TEMP_1173" : { - "message" : [ - "Illegal Parquet type: ." - ] - }, - "_LEGACY_ERROR_TEMP_1174" : { - "message" : [ - "Unrecognized Parquet type: ." - ] - }, "_LEGACY_ERROR_TEMP_1181" : { "message" : [ "Stream-stream join without equality predicate is not supported." @@ -5915,7 +6591,7 @@ "Detected implicit cartesian product for join between logical plans", "", "and", - "rightPlan", + "", "Join condition is missing or trivial.", "Either: use the CROSS JOIN syntax to allow cartesian products between these relations, or: enable implicit cartesian products by setting the configuration variable spark.sql.crossJoin.enabled=true." ] @@ -6013,11 +6689,6 @@ "Operation not allowed: ALTER TABLE SET [SERDE | SERDEPROPERTIES] for a specific partition is not supported for tables created with the datasource API." ] }, - "_LEGACY_ERROR_TEMP_1248" : { - "message" : [ - "Operation not allowed: ALTER TABLE SET SERDE is not supported for tables created with the datasource API." - ] - }, "_LEGACY_ERROR_TEMP_1250" : { "message" : [ " is not allowed on since filesource partition management is disabled (spark.sql.hive.manageFilesourcePartitions = false)." @@ -6028,16 +6699,6 @@ " is not allowed on since its partition metadata is not stored in the Hive metastore. To import this information into the metastore, run `msck repair table `." ] }, - "_LEGACY_ERROR_TEMP_1252" : { - "message" : [ - "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead." - ] - }, - "_LEGACY_ERROR_TEMP_1253" : { - "message" : [ - "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead." - ] - }, "_LEGACY_ERROR_TEMP_1255" : { "message" : [ "Cannot drop built-in function ''." @@ -6098,41 +6759,6 @@ "Failed to truncate table when removing data of the path: ." ] }, - "_LEGACY_ERROR_TEMP_1270" : { - "message" : [ - "SHOW CREATE TABLE is not supported on a temporary view:
." - ] - }, - "_LEGACY_ERROR_TEMP_1271" : { - "message" : [ - "Failed to execute SHOW CREATE TABLE against table
, which is created by Hive and uses the following unsupported feature(s)", - "", - "Please use `SHOW CREATE TABLE
AS SERDE` to show Hive DDL instead." - ] - }, - "_LEGACY_ERROR_TEMP_1272" : { - "message" : [ - "SHOW CREATE TABLE doesn't support transactional Hive table. Please use `SHOW CREATE TABLE
AS SERDE` to show Hive DDL instead." - ] - }, - "_LEGACY_ERROR_TEMP_1273" : { - "message" : [ - "Failed to execute SHOW CREATE TABLE against table
, which is created by Hive and uses the following unsupported serde configuration", - "", - "Please use `SHOW CREATE TABLE
AS SERDE` to show Hive DDL instead." - ] - }, - "_LEGACY_ERROR_TEMP_1274" : { - "message" : [ - "
is a Spark data source table. Use `SHOW CREATE TABLE` without `AS SERDE` instead." - ] - }, - "_LEGACY_ERROR_TEMP_1275" : { - "message" : [ - "Failed to execute SHOW CREATE TABLE against table/view
, which is created by Hive and uses the following unsupported feature(s)", - "." - ] - }, "_LEGACY_ERROR_TEMP_1276" : { "message" : [ "The logical plan that represents the view is not analyzed." @@ -6143,18 +6769,6 @@ "It is not allowed to create a persisted view from the Dataset API." ] }, - "_LEGACY_ERROR_TEMP_1285" : { - "message" : [ - "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the", - "referenced columns only include the internal corrupt record column", - "(named _corrupt_record by default). For example:", - "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()", - "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().", - "Instead, you can cache or save the parsed results and then send the same query.", - "For example, val df = spark.read.schema(schema).csv(file).cache() and then", - "df.filter($\"_corrupt_record\".isNotNull).count()." - ] - }, "_LEGACY_ERROR_TEMP_1286" : { "message" : [ "User-defined partition column not found in the JDBC relation: ." @@ -6290,11 +6904,6 @@ "The pivot column has more than distinct values, this could indicate an error. If this was intended, set to at least the number of distinct values of the pivot column." ] }, - "_LEGACY_ERROR_TEMP_1325" : { - "message" : [ - "Cannot modify the value of a static config: ." - ] - }, "_LEGACY_ERROR_TEMP_1327" : { "message" : [ "Command execution is not supported in runner ." @@ -6315,12 +6924,6 @@ "Class doesn't implement interface UserDefinedAggregateFunction." ] }, - "_LEGACY_ERROR_TEMP_1331" : { - "message" : [ - "Missing field in table
with schema:", - "." - ] - }, "_LEGACY_ERROR_TEMP_1332" : { "message" : [ "" @@ -6331,26 +6934,6 @@ "Sinks cannot request distribution and ordering in continuous execution mode." ] }, - "_LEGACY_ERROR_TEMP_1344" : { - "message" : [ - "Invalid DEFAULT value for column : fails to parse as a valid literal value." - ] - }, - "_LEGACY_ERROR_TEMP_1345" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." - ] - }, - "_LEGACY_ERROR_TEMP_1346" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." - ] - }, - "_LEGACY_ERROR_TEMP_2000" : { - "message" : [ - ". If necessary set to false to bypass this error." - ] - }, "_LEGACY_ERROR_TEMP_2003" : { "message" : [ "Unsuccessful try to zip maps with unique keys due to exceeding the array size limit ." @@ -6361,16 +6944,6 @@ "Type does not support ordered operations." ] }, - "_LEGACY_ERROR_TEMP_2011" : { - "message" : [ - "Unexpected data type ." - ] - }, - "_LEGACY_ERROR_TEMP_2013" : { - "message" : [ - "Negative values found in " - ] - }, "_LEGACY_ERROR_TEMP_2017" : { "message" : [ "not resolved." @@ -6451,11 +7024,6 @@ " is not implemented." ] }, - "_LEGACY_ERROR_TEMP_2042" : { - "message" : [ - ". If necessary set to false to bypass this error." - ] - }, "_LEGACY_ERROR_TEMP_2045" : { "message" : [ "Unsupported table change: " @@ -6506,11 +7074,6 @@ "Unable to clear partition directory prior to writing to it." ] }, - "_LEGACY_ERROR_TEMP_2058" : { - "message" : [ - "Failed to cast value `` to `` for partition column ``." - ] - }, "_LEGACY_ERROR_TEMP_2059" : { "message" : [ "End of stream." @@ -6538,21 +7101,11 @@ "Invalid namespace name: ." ] }, - "_LEGACY_ERROR_TEMP_2067" : { - "message" : [ - "Unsupported partition transform: ." - ] - }, "_LEGACY_ERROR_TEMP_2068" : { "message" : [ "Missing database location." ] }, - "_LEGACY_ERROR_TEMP_2069" : { - "message" : [ - "Cannot remove reserved property: ." - ] - }, "_LEGACY_ERROR_TEMP_2070" : { "message" : [ "Writing job failed." @@ -6573,11 +7126,6 @@ "user-specified schema." ] }, - "_LEGACY_ERROR_TEMP_2075" : { - "message" : [ - "Write is not supported for binary file data source." - ] - }, "_LEGACY_ERROR_TEMP_2076" : { "message" : [ "The length of is , which exceeds the max length allowed: ." @@ -6638,11 +7186,6 @@ "Invalid value `` for parameter `` in table writing via JDBC. The minimum value is 1." ] }, - "_LEGACY_ERROR_TEMP_2088" : { - "message" : [ - " is not supported yet." - ] - }, "_LEGACY_ERROR_TEMP_2089" : { "message" : [ "DataType: ." @@ -6675,7 +7218,7 @@ }, "_LEGACY_ERROR_TEMP_2097" : { "message" : [ - "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1." + "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1 or remove the broadcast hint if it exists in your code." ] }, "_LEGACY_ERROR_TEMP_2098" : { @@ -6698,11 +7241,6 @@ "Dictionary encoding should not be used because of dictionary overflow." ] }, - "_LEGACY_ERROR_TEMP_2104" : { - "message" : [ - "End of the iterator." - ] - }, "_LEGACY_ERROR_TEMP_2105" : { "message" : [ "Could not allocate memory to grow BytesToBytesMap." @@ -6743,16 +7281,6 @@ "Unexpected window function frame ." ] }, - "_LEGACY_ERROR_TEMP_2113" : { - "message" : [ - "Unable to parse as a percentile." 
- ] - }, - "_LEGACY_ERROR_TEMP_2114" : { - "message" : [ - " is not a recognised statistic." - ] - }, "_LEGACY_ERROR_TEMP_2115" : { "message" : [ "Unknown column: ." @@ -6803,22 +7331,6 @@ "Exception when registering StreamingQueryListener." ] }, - "_LEGACY_ERROR_TEMP_2138" : { - "message" : [ - "Cannot have circular references in bean class, but got the circular reference of class ." - ] - }, - "_LEGACY_ERROR_TEMP_2139" : { - "message" : [ - "cannot have circular references in class, but got the circular reference of class ." - ] - }, - "_LEGACY_ERROR_TEMP_2140" : { - "message" : [ - "`` is not a valid identifier of Java and cannot be used as field name", - "." - ] - }, "_LEGACY_ERROR_TEMP_2144" : { "message" : [ "Unable to find constructor for . This could happen if is an interface, or a trait without companion object constructor." @@ -6844,11 +7356,6 @@ "null value found but field is not nullable." ] }, - "_LEGACY_ERROR_TEMP_2150" : { - "message" : [ - "Due to Scala's limited support of tuple, tuple with more than 22 elements are not supported." - ] - }, "_LEGACY_ERROR_TEMP_2154" : { "message" : [ "Failed to get outer pointer for ." @@ -6869,21 +7376,11 @@ "Initial type must be an , a or a ." ] }, - "_LEGACY_ERROR_TEMP_2165" : { - "message" : [ - "Malformed records are detected in schema inference. Parse Mode: ." - ] - }, "_LEGACY_ERROR_TEMP_2166" : { "message" : [ "Malformed JSON." ] }, - "_LEGACY_ERROR_TEMP_2167" : { - "message" : [ - "Malformed records are detected in schema inference. Parse Mode: . Reasons: Failed to infer a common schema. Struct types are expected, but `` was found." - ] - }, "_LEGACY_ERROR_TEMP_2168" : { "message" : [ "Decorrelate inner query through is not supported." @@ -7053,11 +7550,6 @@ "StreamingRelationExec cannot be executed." ] }, - "_LEGACY_ERROR_TEMP_2211" : { - "message" : [ - "Invalid output mode: ." - ] - }, "_LEGACY_ERROR_TEMP_2212" : { "message" : [ "Invalid catalog name: ." @@ -7143,11 +7635,6 @@ "Primitive types are not supported." ] }, - "_LEGACY_ERROR_TEMP_2232" : { - "message" : [ - "Value at index is null." - ] - }, "_LEGACY_ERROR_TEMP_2233" : { "message" : [ "Only Data Sources providing FileFormat are supported: ." @@ -7173,21 +7660,6 @@ ".getParentLogger is not yet implemented." ] }, - "_LEGACY_ERROR_TEMP_2238" : { - "message" : [ - "Unable to create Parquet converter for whose Parquet type is without decimal metadata. Please read this column/field as Spark BINARY type." - ] - }, - "_LEGACY_ERROR_TEMP_2239" : { - "message" : [ - "Unable to create Parquet converter for decimal type whose Parquet type is . Parquet DECIMAL type can only be backed by INT32, INT64, FIXED_LEN_BYTE_ARRAY, or BINARY." - ] - }, - "_LEGACY_ERROR_TEMP_2240" : { - "message" : [ - "Unable to create Parquet converter for data type whose Parquet type is ." - ] - }, "_LEGACY_ERROR_TEMP_2241" : { "message" : [ "Nonatomic partition table can not add multiple partitions." @@ -7332,11 +7804,6 @@ "comment on table is not supported." ] }, - "_LEGACY_ERROR_TEMP_2271" : { - "message" : [ - "UpdateColumnNullability is not supported." - ] - }, "_LEGACY_ERROR_TEMP_2272" : { "message" : [ "Rename column is only supported for MySQL version 8.0 and above." 
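The error entries in `error-conditions.json` above use `<param>` placeholders that `ErrorClassesJsonReader` rewrites to `${param}` and fills in with Apache Commons Text's `StringSubstitutor` (see the hunk for that file further below in this diff). A minimal, self-contained Scala sketch of that substitution contract follows; the placeholder regex, template text, and parameter name here are illustrative assumptions, not copied from the patch.

```scala
import scala.jdk.CollectionConverters._
import org.apache.commons.text.StringSubstitutor

object ErrorTemplateSketch {
  // Placeholder syntax assumed to match the reader's TEMPLATE_REGEX; illustrative only.
  private val TemplateRegex = "<([a-zA-Z0-9_-]+)>".r

  def render(template: String, parameters: Map[String, String]): String = {
    val sub = new StringSubstitutor(parameters.asJava)
    sub.setEnableUndefinedVariableException(true) // a missing parameter fails fast
    sub.setDisableSubstitutionInValues(true)      // parameter values are inserted verbatim
    // Rewrite <name> to ${name}, as ErrorClassesJsonReader.getErrorMessage does.
    sub.replace(TemplateRegex.replaceAllIn(template, "\\$\\{$1\\}"))
  }

  def main(args: Array[String]): Unit = {
    // Illustrative template and parameter; the real entries live in error-conditions.json.
    println(render("The timezone: <timeZone> is invalid.", Map("timeZone" -> "Mars/Olympus")))
  }
}
```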
@@ -7601,11 +8068,6 @@ " is not currently supported" ] }, - "_LEGACY_ERROR_TEMP_3055" : { - "message" : [ - "ScalarFunction '' neither implement magic method nor override 'produceResult'" - ] - }, "_LEGACY_ERROR_TEMP_3056" : { "message" : [ "Unexpected row-level read relations (allow multiple = ): " @@ -7739,11 +8201,6 @@ "No handler for UDF/UDAF/UDTF '': " ] }, - "_LEGACY_ERROR_TEMP_3085" : { - "message" : [ - "from_avro() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_3086" : { "message" : [ "Cannot persist into Hive metastore as table property keys may not start with 'spark.sql.': " @@ -7964,11 +8421,6 @@ "Partitions truncate is not supported" ] }, - "_LEGACY_ERROR_TEMP_3146" : { - "message" : [ - "Cannot find a compatible ScalarFunction#produceResult" - ] - }, "_LEGACY_ERROR_TEMP_3147" : { "message" : [ ": Batch scan are not supported" @@ -8034,11 +8486,6 @@ "continuous mode is not supported!" ] }, - "_LEGACY_ERROR_TEMP_3168" : { - "message" : [ - "hasTimedOut is true however there's no timeout configured" - ] - }, "_LEGACY_ERROR_TEMP_3169" : { "message" : [ "AcceptsLatestSeenOffset is not supported with DSv1 streaming source: " @@ -8209,36 +8656,6 @@ "The number of fields () in the partition identifier is not equal to the partition schema length (). The identifier might not refer to one partition." ] }, - "_LEGACY_ERROR_TEMP_3209" : { - "message" : [ - "Illegal input for day of week: " - ] - }, - "_LEGACY_ERROR_TEMP_3210" : { - "message" : [ - "Interval string does not match second-nano format of ss.nnnnnnnnn" - ] - }, - "_LEGACY_ERROR_TEMP_3211" : { - "message" : [ - "Error parsing interval day-time string: " - ] - }, - "_LEGACY_ERROR_TEMP_3212" : { - "message" : [ - "Cannot support (interval '' to ) expression" - ] - }, - "_LEGACY_ERROR_TEMP_3213" : { - "message" : [ - "Error parsing interval string: " - ] - }, - "_LEGACY_ERROR_TEMP_3214" : { - "message" : [ - "Interval string does not match format of when cast to : " - ] - }, "_LEGACY_ERROR_TEMP_3215" : { "message" : [ "Expected a Boolean type expression in replaceNullWithFalse, but got the type in ." @@ -8289,11 +8706,6 @@ "Cannot delete map key" ] }, - "_LEGACY_ERROR_TEMP_3227" : { - "message" : [ - "Cannot find field: " - ] - }, "_LEGACY_ERROR_TEMP_3228" : { "message" : [ "AFTER column not found: " @@ -8379,14 +8791,24 @@ "'' is an invalid timestamp" ] }, - "_LEGACY_ERROR_TEMP_3261" : { + "_LEGACY_ERROR_TEMP_3262" : { + "message" : [ + "Doesn't support month or year interval: " + ] + }, + "_LEGACY_ERROR_TEMP_3300" : { "message" : [ - "Unknown output mode . 
Accepted output modes are 'append', 'complete', 'update'" + "error while calling spill() on : " ] }, - "_LEGACY_ERROR_TEMP_3262" : { + "_LEGACY_ERROR_TEMP_3301" : { "message" : [ - "Doesn't support month or year interval: " + "Not enough memory to grow pointer array" + ] + }, + "_LEGACY_ERROR_TEMP_3302" : { + "message" : [ + "No enough memory for aggregation" ] }, "_LEGACY_ERROR_USER_RAISED_EXCEPTION" : { diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index 0cd55bda7ba35..fb899e4eb207e 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4619,6 +4619,30 @@ "standard": "N", "usedBy": ["Spark"] }, + "42K0M": { + "description": "Invalid variable declaration.", + "origin": "Spark,", + "standard": "N", + "usedBy": ["Spark"] + }, + "42K0N": { + "description": "Invalid external type.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, + "42K0O": { + "description": "ScalarFunction not overrides method 'produceResult(InternalRow)' with custom implementation.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, + "42K0P": { + "description": "ScalarFunction not implements or overrides method 'produceResult(InternalRow)'.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", @@ -4889,6 +4913,12 @@ "standard": "N", "usedBy": ["SQL Server"] }, + "42S23": { + "description": "Partition transform expression not in 'partitionedBy'", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "44000": { "description": "with check option violation", "origin": "SQL/Foundation", @@ -7405,6 +7435,12 @@ "standard": "N", "usedBy": ["Databricks"] }, + "KD010": { + "description": "external data source failure", + "origin": "Databricks", + "standard": "N", + "usedBy": ["Databricks"] + }, "P0000": { "description": "procedural logic error", "origin": "PostgreSQL", diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala index 99b993f1127af..e2dd0da1aac85 100644 --- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala +++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.net.URL -import scala.collection.immutable.Map import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.annotation.JsonIgnore @@ -43,12 +42,16 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { private[spark] val errorInfoMap = jsonFileURLs.map(ErrorClassesJsonReader.readAsMap).reduce(_ ++ _) - def getErrorMessage(errorClass: String, messageParameters: Map[String, String]): String = { + def getErrorMessage(errorClass: String, messageParameters: Map[String, Any]): String = { val messageTemplate = getMessageTemplate(errorClass) - val sub = new StringSubstitutor(messageParameters.asJava) + val sanitizedParameters = messageParameters.map { + case (key, null) => key -> "null" + case (key, value) => key -> value + } + val sub = new StringSubstitutor(sanitizedParameters.asJava) sub.setEnableUndefinedVariableException(true) sub.setDisableSubstitutionInValues(true) - try { + val errorMessage = try { sub.replace(ErrorClassesJsonReader.TEMPLATE_REGEX.replaceAllIn( messageTemplate, "\\$\\{$1\\}")) } catch { @@ -57,6 +60,17 @@ class 
ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { s"MessageTemplate: $messageTemplate, " + s"Parameters: $messageParameters", i) } + if (util.SparkEnvUtils.isTesting) { + val placeHoldersNum = ErrorClassesJsonReader.TEMPLATE_REGEX.findAllIn(messageTemplate).length + if (placeHoldersNum < sanitizedParameters.size) { + throw SparkException.internalError( + s"Found unused message parameters of the error class '$errorClass'. " + + s"Its error message format has $placeHoldersNum placeholders, " + + s"but the passed message parameters map has ${sanitizedParameters.size} items. " + + "Consider to add placeholders to the error format or remove unused message parameters.") + } + } + errorMessage } def getMessageParameters(errorClass: String): Seq[String] = { diff --git a/common/utils/src/main/scala/org/apache/spark/SparkException.scala b/common/utils/src/main/scala/org/apache/spark/SparkException.scala index 398cb1fad6726..0c0a1902ee2a1 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkException.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkException.scala @@ -69,7 +69,7 @@ class SparkException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -179,7 +179,7 @@ private[spark] class SparkUpgradeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull } /** @@ -212,7 +212,7 @@ private[spark] class SparkArithmeticException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -250,7 +250,7 @@ private[spark] class SparkUnsupportedOperationException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull } private[spark] object SparkUnsupportedOperationException { @@ -280,7 +280,7 @@ private[spark] class SparkClassNotFoundException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -296,7 +296,7 @@ private[spark] class SparkConcurrentModificationException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -306,8 +306,9 @@ private[spark] class SparkDateTimeException private( message: String, errorClass: Option[String], messageParameters: Map[String, String], - context: Array[QueryContext]) - extends DateTimeException(message) with SparkThrowable { + context: Array[QueryContext], + cause: Option[Throwable]) + extends DateTimeException(message, cause.orNull) with SparkThrowable { def this( errorClass: String, @@ -318,7 +319,23 @@ private[spark] class SparkDateTimeException private( SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), Option(errorClass), messageParameters, - context + context, + cause = None + 
) + } + + def this( + errorClass: String, + messageParameters: Map[String, String], + context: Array[QueryContext], + summary: String, + cause: Option[Throwable]) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(errorClass), + messageParameters, + context, + cause.orElse(None) ) } @@ -329,7 +346,7 @@ private[spark] class SparkDateTimeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -345,7 +362,7 @@ private[spark] class SparkFileNotFoundException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -379,7 +396,7 @@ private[spark] class SparkNumberFormatException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -431,7 +448,7 @@ private[spark] class SparkIllegalArgumentException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -460,7 +477,7 @@ private[spark] class SparkRuntimeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -489,7 +506,7 @@ private[spark] class SparkPythonException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -507,7 +524,7 @@ private[spark] class SparkNoSuchElementException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getQueryContext: Array[QueryContext] = context } @@ -524,7 +541,7 @@ private[spark] class SparkSecurityException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -558,7 +575,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -574,7 +591,7 @@ private[spark] class SparkSQLException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -589,5 +606,5 @@ private[spark] class 
SparkSQLFeatureNotSupportedException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index db5eff72e124a..b6c2b176de62b 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -74,14 +74,14 @@ private[spark] object SparkThrowableHelper { } def isInternalError(errorClass: String): Boolean = { - errorClass.startsWith("INTERNAL_ERROR") + errorClass != null && errorClass.startsWith("INTERNAL_ERROR") } def getMessage(e: SparkThrowable with Throwable, format: ErrorMessageFormat.Value): String = { import ErrorMessageFormat._ format match { case PRETTY => e.getMessage - case MINIMAL | STANDARD if e.getErrorClass == null => + case MINIMAL | STANDARD if e.getCondition == null => toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() @@ -92,7 +92,7 @@ private[spark] object SparkThrowableHelper { g.writeEndObject() } case MINIMAL | STANDARD => - val errorClass = e.getErrorClass + val errorClass = e.getCondition toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index a6184038b5230..c365797cec690 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -266,6 +266,7 @@ private[spark] object LogKeys { case object FEATURE_NAME extends LogKey case object FETCH_SIZE extends LogKey case object FIELD_NAME extends LogKey + case object FIELD_TYPE extends LogKey case object FILES extends LogKey case object FILE_ABSOLUTE_PATH extends LogKey case object FILE_END_OFFSET extends LogKey @@ -396,6 +397,7 @@ private[spark] object LogKeys { case object MAX_NUM_PARTITIONS extends LogKey case object MAX_NUM_POSSIBLE_BINS extends LogKey case object MAX_NUM_ROWS_IN_MEMORY_BUFFER extends LogKey + case object MAX_SEEN_VERSION extends LogKey case object MAX_SERVICE_NAME_LENGTH extends LogKey case object MAX_SIZE extends LogKey case object MAX_SLOTS extends LogKey @@ -420,9 +422,11 @@ private[spark] object LogKeys { case object MIN_NUM_FREQUENT_PATTERN extends LogKey case object MIN_POINT_PER_CLUSTER extends LogKey case object MIN_RATE extends LogKey + case object MIN_SEEN_VERSION extends LogKey case object MIN_SHARE extends LogKey case object MIN_SIZE extends LogKey case object MIN_TIME extends LogKey + case object MIN_VERSIONS_TO_DELETE extends LogKey case object MIN_VERSION_NUM extends LogKey case object MISSING_PARENT_STAGES extends LogKey case object MODEL_WEIGHTS extends LogKey @@ -649,6 +653,7 @@ private[spark] object LogKeys { case object RECEIVER_IDS extends LogKey case object RECORDS extends LogKey case object RECOVERY_STATE extends LogKey + case object RECURSIVE_DEPTH extends LogKey case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGEX extends LogKey @@ -687,6 +692,7 @@ private[spark] object LogKeys { case object RPC_ENDPOINT_REF extends LogKey case object RPC_MESSAGE_CAPACITY extends LogKey case object RPC_SSL_ENABLED extends LogKey + case 
object RULE_EXECUTOR_NAME extends LogKey case object RULE_NAME extends LogKey case object RUN_ID extends LogKey case object SCALA_VERSION extends LogKey @@ -746,6 +752,7 @@ private[spark] object LogKeys { case object START_INDEX extends LogKey case object START_TIME extends LogKey case object STATEMENT_ID extends LogKey + case object STATE_NAME extends LogKey case object STATE_STORE_ID extends LogKey case object STATE_STORE_PROVIDER extends LogKey case object STATE_STORE_VERSION extends LogKey @@ -850,6 +857,7 @@ private[spark] object LogKeys { case object USER_NAME extends LogKey case object UUID extends LogKey case object VALUE extends LogKey + case object VERSIONS_TO_DELETE extends LogKey case object VERSION_NUM extends LogKey case object VIEW_ACLS extends LogKey case object VIEW_ACLS_GROUPS extends LogKey diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 8eea9b44da26d..7471b764bd2b3 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -165,11 +165,17 @@ trait Logging { } protected def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = { - val threadContext = CloseableThreadContext.putAll(context) + // put into thread context only when structured logging is enabled + val closeableThreadContextOpt = if (Logging.isStructuredLoggingEnabled) { + Some(CloseableThreadContext.putAll(context)) + } else { + None + } + try { body } finally { - threadContext.close() + closeableThreadContextOpt.foreach(_.close()) } } @@ -328,7 +334,7 @@ trait Logging { // If Log4j 2 is used but is initialized by default configuration, // load a default properties file // scalastyle:off println - if (Logging.islog4j2DefaultConfigured()) { + if (Logging.defaultSparkLog4jConfig || Logging.islog4j2DefaultConfigured()) { Logging.defaultSparkLog4jConfig = true val defaultLogProps = if (Logging.isStructuredLoggingEnabled) { "org/apache/spark/log4j2-defaults.properties" @@ -418,7 +424,6 @@ private[spark] object Logging { def uninitialize(): Unit = initLock.synchronized { if (isLog4j2()) { if (defaultSparkLog4jConfig) { - defaultSparkLog4jConfig = false val context = LogManager.getContext(false).asInstanceOf[LoggerContext] context.reconfigure() } else { diff --git a/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala new file mode 100644 index 0000000000000..1f1d3492d6ac5 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler + +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.annotation.DeveloperApi + +@DeveloperApi +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent { + /* Whether output this event to the event log */ + protected[spark] def logEvent: Boolean = true +} diff --git a/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 259f4330224c9..1972ef05d8759 100644 --- a/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -84,7 +84,7 @@ class StreamingQueryException private[sql]( s"""${classOf[StreamingQueryException].getName}: ${cause.getMessage} |$queryDebugString""".stripMargin - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava } diff --git a/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala index 4d729adfbb7eb..f88f267727c11 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala @@ -24,6 +24,7 @@ import com.fasterxml.jackson.core.{JsonEncoding, JsonGenerator} import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.apache.spark.util.SparkErrorUtils.tryWithResource private[spark] trait JsonUtils { @@ -31,12 +32,12 @@ private[spark] trait JsonUtils { .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) def toJsonString(block: JsonGenerator => Unit): String = { - val baos = new ByteArrayOutputStream() - val generator = mapper.createGenerator(baos, JsonEncoding.UTF8) - block(generator) - generator.close() - baos.close() - new String(baos.toByteArray, StandardCharsets.UTF_8) + tryWithResource(new ByteArrayOutputStream()) { baos => + tryWithResource(mapper.createGenerator(baos, JsonEncoding.UTF8)) { generator => + block(generator) + } + new String(baos.toByteArray, StandardCharsets.UTF_8) + } } } diff --git a/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala index 5a798ffad3a92..8b41f10339271 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/LogUtils.scala @@ -29,9 +29,9 @@ object LogUtils { /** * Schema for structured Spark logs. 
* Example usage: - * val logDf = spark.read.schema(LOG_SCHEMA).json("path/to/logs") + * val logDf = spark.read.schema(SPARK_LOG_SCHEMA).json("path/to/logs") */ - val LOG_SCHEMA: String = """ + val SPARK_LOG_SCHEMA: String = """ |ts TIMESTAMP, |level STRING, |msg STRING, diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala index 42a1d1612aeeb..d54a2f2ed9cea 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala @@ -342,7 +342,7 @@ private[spark] object MavenUtils extends Logging { } /* Set ivy settings for location of cache, if option is supplied */ - private def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { + private[util] def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { val alternateIvyDir = ivyPath.filterNot(_.trim.isEmpty).getOrElse { // To protect old Ivy-based systems like old Spark from Apache Ivy 2.5.2's incompatibility. System.getProperty("ivy.home", diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala index 7a4ef4a5ce81f..307006315a3c4 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala @@ -69,6 +69,73 @@ private[spark] trait SparkClassUtils { targetClass == null || targetClass.isAssignableFrom(cls) }.getOrElse(false) } + + /** Return the class name of the given object, removing all dollar signs */ + def getFormattedClassName(obj: AnyRef): String = { + getSimpleName(obj.getClass).replace("$", "") + } + + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimics scalatest's getSimpleNameOfAnObjectsClass. 
+ */ + def getSimpleName(cls: Class[_]): String = { + try { + cls.getSimpleName + } catch { + // TODO: the value returned here isn't even quite right; it returns simple names + // like UtilsSuite$MalformedClassObject$MalformedClass instead of MalformedClass + // The exact value may not matter much as it's used in log statements + case _: InternalError => + stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + @scala.annotation.tailrec + final def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an interpreter + // generated class, so we should return the full string + s + } else { + // The class name is interpreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.findLast(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[spark] object SparkClassUtils extends SparkClassUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala index 9f604e4bf47f2..872c89e5a29a2 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala @@ -48,6 +48,22 @@ private[spark] trait SparkErrorUtils extends Logging { try f.apply(resource) finally resource.close() } + /** + * Try to initialize a resource. If an exception is throw during initialization, closes the + * resource before propagating the error. Otherwise, the caller is responsible for closing + * the resource. This means that [[T]] should provide some way to close the resource. + */ + def tryInitializeResource[R <: Closeable, T](createResource: => R)(initialize: R => T): T = { + val resource = createResource + try { + initialize(resource) + } catch { + case e: Throwable => + resource.close() + throw e + } + } + /** * Execute a block of code, then a finally block, but if exceptions happen in * the finally block, do not suppress the original exception. 
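The new `SparkErrorUtils.tryInitializeResource` above complements the existing `tryWithResource`: it closes the freshly created resource only when the initializer throws, and otherwise hands ownership of the initialized object back to the caller. A hypothetical usage sketch follows; the file name, the reader wrapping, and the placement under `org.apache.spark.util` (needed because these helpers are `private[spark]`) are assumptions for illustration, not code from the patch.

```scala
package org.apache.spark.util
// Packaged under org.apache.spark.util only so the private[spark] helpers are visible.

import java.io.{BufferedReader, FileInputStream, InputStreamReader}
import java.nio.charset.StandardCharsets

object TryInitializeResourceSketch {
  // If constructing the reader threw, tryInitializeResource would close the stream for us;
  // on success the caller owns the returned reader and is responsible for closing it.
  def openUtf8Reader(path: String): BufferedReader =
    SparkErrorUtils.tryInitializeResource(new FileInputStream(path)) { in =>
      new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))
    }

  def main(args: Array[String]): Unit = {
    // tryWithResource closes the reader (and, through it, the underlying stream) when done.
    SparkErrorUtils.tryWithResource(openUtf8Reader("README.md")) { reader =>
      println(reader.readLine())
    }
  }
}
```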
diff --git a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala index 76062074edcaf..140de836622f4 100644 --- a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala +++ b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala @@ -365,7 +365,7 @@ private[spark] object IvyTestUtils { useIvyLayout: Boolean = false, withPython: Boolean = false, withR: Boolean = false, - ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = { + ivySettings: IvySettings = defaultIvySettings())(f: String => Unit): Unit = { val deps = dependencies.map(MavenUtils.extractMavenCoordinates) purgeLocalIvyCache(artifact, deps, ivySettings) val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, @@ -401,4 +401,16 @@ private[spark] object IvyTestUtils { } } } + + /** + * Creates and initializes a new instance of IvySettings with default configurations. + * The method processes the Ivy path argument using MavenUtils to ensure proper setup. + * + * @return A newly created and configured instance of IvySettings. + */ + private def defaultIvySettings(): IvySettings = { + val settings = new IvySettings + MavenUtils.processIvyPathArg(ivySettings = settings, ivyPath = None) + settings + } } diff --git a/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala b/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala index b3e103f46337c..48951c2084f17 100644 --- a/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala +++ b/common/utils/src/test/scala/org/apache/spark/util/StructuredLoggingSuite.scala @@ -40,7 +40,7 @@ trait LoggingSuiteBase } // Return the newly added log contents in the log file after executing the function `f` - private def captureLogOutput(f: () => Unit): String = { + protected def captureLogOutput(f: () => Unit): String = { val content = if (logFile.exists()) { Files.readString(logFile.toPath) } else { @@ -438,6 +438,44 @@ class StructuredLoggingSuite extends LoggingSuiteBase { assert((log"\r" + log"\n" + log"\t" + log"\b").message == "\r\n\t\b") assert((log"\r${MDC(LogKeys.EXECUTOR_ID, 1)}\n".message == "\r1\n")) } + + test("disabled structured logging won't log context") { + Logging.disableStructuredLogging() + val expectedPatternWithoutContext = compactAndToRegexPattern( + s""" + { + "ts": "", + "level": "INFO", + "msg": "Lost executor 1.", + "logger": "$className" + }""") + + Seq( + () => logInfo(log"Lost executor ${MDC(LogKeys.EXECUTOR_ID, "1")}."), + () => logInfo( // blocked when explicitly constructing the MessageWithContext + MessageWithContext( + "Lost executor 1.", + new java.util.HashMap[String, String] { put(LogKeys.EXECUTOR_ID.name, "1") } + ) + ) + ).foreach { f => + val logOutput = captureLogOutput(f) + assert(expectedPatternWithoutContext.r.matches(logOutput)) + } + Logging.enableStructuredLogging() + } + + test("setting to MDC gets logged") { + val mdcPattern = s""""${LogKeys.DATA.name}":"some-data"""" + + org.slf4j.MDC.put(LogKeys.DATA.name, "some-data") + val logOutputWithMDCSet = captureLogOutput(() => logInfo(msgWithMDC)) + assert(mdcPattern.r.findFirstIn(logOutputWithMDCSet).isDefined) + + org.slf4j.MDC.remove(LogKeys.DATA.name) + val logOutputWithoutMDCSet = captureLogOutput(() => logInfo(msgWithMDC)) + assert(mdcPattern.r.findFirstIn(logOutputWithoutMDCSet).isEmpty) + } } object CustomLogKeys { diff --git a/common/variant/README.md 
b/common/variant/README.md index 3e1b00c494755..58ebab7bd2651 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -322,8 +322,6 @@ Each `array_val` and `object_val` must contain exactly `num_elements + 1` values The "short string" basic type may be used as an optimization to fold string length into the type byte for strings less than 64 bytes. It is semantically identical to the "string" primitive type. -String and binary values may also be represented as an index into the metadata dictionary. (See “string from metadata” and “binary from metadata” in the “Primitive Types” table) Writers may choose to use this mechanism to avoid repeating identical string values in a Variant object. - The Decimal type contains a scale, but no precision. The implied precision of a decimal value is `floor(log_10(val)) + 1`. # Encoding types @@ -335,27 +333,25 @@ The Decimal type contains a scale, but no precision. The implied precision of a | Object | `2` | A collection of (string-key, variant-value) pairs | | Array | `3` | An ordered sequence of variant values | -| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | -|-----------------------------|---------|---------------------------|-----------------------------------------------------------------------------------------------------------| -| null | `0` | any | none | -| boolean (True) | `1` | BOOLEAN | none | -| boolean (False) | `2` | BOOLEAN | none | -| int8 | `3` | INT(8, signed) | 1 byte | -| int16 | `4` | INT(16, signed) | 2 byte little-endian | -| int32 | `5` | INT(32, signed) | 4 byte little-endian | -| int64 | `6` | INT(64, signed) | 8 byte little-endian | -| double | `7` | DOUBLE | IEEE little-endian | -| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| date | `11` | DATE | 4 byte little-endian | -| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | -| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | -| float | `14` | FLOAT | IEEE little-endian | -| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | -| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | -| binary from metadata | `17` | BINARY | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. | -| string from metadata | `18` | STRING | Little-endian index into the metadata dictionary. Number of bytes is equal to the metadata `offset_size`. 
| +| Logical Type | Physical Type | Type ID | Equivalent Parquet Type | Binary format | +|----------------------|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------| +| NullType | null | `0` | any | none | +| Boolean | boolean (True) | `1` | BOOLEAN | none | +| Boolean | boolean (False) | `2` | BOOLEAN | none | +| Exact Numeric | int8 | `3` | INT(8, signed) | 1 byte | +| Exact Numeric | int16 | `4` | INT(16, signed) | 2 byte little-endian | +| Exact Numeric | int32 | `5` | INT(32, signed) | 4 byte little-endian | +| Exact Numeric | int64 | `6` | INT(64, signed) | 8 byte little-endian | +| Double | double | `7` | DOUBLE | IEEE little-endian | +| Exact Numeric | decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Exact Numeric | decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Exact Numeric | decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Date | date | `11` | DATE | 4 byte little-endian | +| Timestamp | timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | +| TimestampNTZ | timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | +| Float | float | `14` | FLOAT | IEEE little-endian | +| Binary | binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | +| String | string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | | Decimal Precision | Decimal value type | |-----------------------|--------------------| @@ -364,6 +360,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a | 18 <= precision <= 38 | int128 | | > 38 | Not supported | +The *Logical Type* column indicates logical equivalence of physically encoded types. For example, a user expression operating on a string value containing "hello" should behave the same, whether it is encoded with the short string optimization, or long string encoding. Similarly, user expressions operating on an *int8* value of 1 should behave the same as a decimal16 with scale 2 and unscaled value 100. + # Field ID order and uniqueness For objects, field IDs and offsets must be listed in the order of the corresponding field names, sorted lexicographically. Note that the fields themselves are not required to follow this order. As a result, offsets will not necessarily be listed in ascending order. diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index a705daaf323b2..a6fc6b534ee02 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -113,6 +113,11 @@ public String getString() { return VariantUtil.getString(value, pos); } + // Get the type info bits from a variant value. + public int getTypeInfo() { + return VariantUtil.getTypeInfo(value, pos); + } + // Get the value type of the variant. public Type getType() { return VariantUtil.getType(value, pos); @@ -188,6 +193,18 @@ public ObjectField getFieldAtIndex(int index) { }); } + // Get the dictionary ID for the object field at the `index` slot. 
Throws malformedVariant if + // `index` is out of the bound of `[0, objectSize())`. + // It is only legal to call it when `getType()` is `Type.OBJECT`. + public int getDictionaryIdAtIndex(int index) { + return handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> { + if (index < 0 || index >= size) { + throw malformedVariant(); + } + return readUnsigned(value, idStart + idSize * index, idSize); + }); + } + // Get the number of array elements in the variant. // It is only legal to call it when `getType()` is `Type.ARRAY`. public int arraySize() { diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index 2afba81d192e9..32595baf6a4f2 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -26,10 +26,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; +import java.util.*; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonParser; @@ -43,24 +40,29 @@ * Build variant value and metadata by parsing JSON values. */ public class VariantBuilder { + public VariantBuilder(boolean allowDuplicateKeys) { + this.allowDuplicateKeys = allowDuplicateKeys; + } + /** * Parse a JSON string as a Variant value. * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed * the SIZE_LIMIT (for example, this could be a maximum of 16 MiB). * @throws IOException if any JSON parsing error happens. */ - public static Variant parseJson(String json) throws IOException { + public static Variant parseJson(String json, boolean allowDuplicateKeys) throws IOException { try (JsonParser parser = new JsonFactory().createParser(json)) { parser.nextToken(); - return parseJson(parser); + return parseJson(parser, allowDuplicateKeys); } } /** - * Similar {@link #parseJson(String)}, but takes a JSON parser instead of string input. + * Similar {@link #parseJson(String, boolean)}, but takes a JSON parser instead of string input. */ - public static Variant parseJson(JsonParser parser) throws IOException { - VariantBuilder builder = new VariantBuilder(); + public static Variant parseJson(JsonParser parser, boolean allowDuplicateKeys) + throws IOException { + VariantBuilder builder = new VariantBuilder(allowDuplicateKeys); builder.buildJson(parser); return builder.result(); } @@ -105,6 +107,14 @@ public Variant result() { return new Variant(Arrays.copyOfRange(writeBuffer, 0, writePos), metadata); } + // Return the variant value only, without metadata. + // Used in shredding to produce a final value, where all shredded values refer to a common + // metadata. It is expected to be called instead of `result()`, although it is valid to call both + // methods, in any order. + public byte[] valueWithoutMetadata() { + return Arrays.copyOfRange(writeBuffer, 0, writePos); + } + public void appendString(String str) { byte[] text = str.getBytes(StandardCharsets.UTF_8); boolean longStr = text.length > MAX_SHORT_STR_SIZE; @@ -258,23 +268,63 @@ public int getWritePos() { // record the offset of the field. The offset is computed as `getWritePos() - start`. // 3. The caller calls `finishWritingObject` to finish writing a variant object. 
// - // This function is responsible to sort the fields by key and check for any duplicate field keys. + // This function is responsible to sort the fields by key. If there are duplicate field keys: + // - when `allowDuplicateKeys` is true, the field with the greatest offset value (the last + // appended one) is kept. + // - otherwise, throw an exception. public void finishWritingObject(int start, ArrayList fields) { - int dataSize = writePos - start; int size = fields.size(); Collections.sort(fields); int maxId = size == 0 ? 0 : fields.get(0).id; - // Check for duplicate field keys. Only need to check adjacent key because they are sorted. - for (int i = 1; i < size; ++i) { - maxId = Math.max(maxId, fields.get(i).id); - String key = fields.get(i).key; - if (key.equals(fields.get(i - 1).key)) { - @SuppressWarnings("unchecked") - Map parameters = Map$.MODULE$.empty().updated("key", key); - throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters, - null, new QueryContext[]{}, ""); + if (allowDuplicateKeys) { + int distinctPos = 0; + // Maintain a list of distinct keys in-place. + for (int i = 1; i < size; ++i) { + maxId = Math.max(maxId, fields.get(i).id); + if (fields.get(i).id == fields.get(i - 1).id) { + // Found a duplicate key. Keep the field with a greater offset. + if (fields.get(distinctPos).offset < fields.get(i).offset) { + fields.set(distinctPos, fields.get(distinctPos).withNewOffset(fields.get(i).offset)); + } + } else { + // Found a distinct key. Add the field to the list. + ++distinctPos; + fields.set(distinctPos, fields.get(i)); + } + } + if (distinctPos + 1 < fields.size()) { + size = distinctPos + 1; + // Resize `fields` to `size`. + fields.subList(size, fields.size()).clear(); + // Sort the fields by offsets so that we can move the value data of each field to the new + // offset without overwriting the fields after it. + fields.sort(Comparator.comparingInt(f -> f.offset)); + int currentOffset = 0; + for (int i = 0; i < size; ++i) { + int oldOffset = fields.get(i).offset; + int fieldSize = VariantUtil.valueSize(writeBuffer, start + oldOffset); + System.arraycopy(writeBuffer, start + oldOffset, + writeBuffer, start + currentOffset, fieldSize); + fields.set(i, fields.get(i).withNewOffset(currentOffset)); + currentOffset += fieldSize; + } + writePos = start + currentOffset; + // Change back to the sort order by field keys to meet the variant spec. + Collections.sort(fields); + } + } else { + for (int i = 1; i < size; ++i) { + maxId = Math.max(maxId, fields.get(i).id); + String key = fields.get(i).key; + if (key.equals(fields.get(i - 1).key)) { + @SuppressWarnings("unchecked") + Map parameters = Map$.MODULE$.empty().updated("key", key); + throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters, + null, new QueryContext[]{}, ""); + } } } + int dataSize = writePos - start; boolean largeSize = size > U8_MAX; int sizeBytes = largeSize ? U32_SIZE : 1; int idSize = getIntegerSize(maxId); @@ -362,15 +412,26 @@ private void appendVariantImpl(byte[] value, byte[] metadata, int pos) { }); break; default: - int size = valueSize(value, pos); - checkIndex(pos + size - 1, value.length); - checkCapacity(size); - System.arraycopy(value, pos, writeBuffer, writePos, size); - writePos += size; + shallowAppendVariantImpl(value, pos); break; } } + // Append the variant value without rewriting or creating any metadata. This is used when + // building an object during shredding, where there is a fixed pre-existing metadata that + // all shredded values will refer to. 
+ public void shallowAppendVariant(Variant v) { + shallowAppendVariantImpl(v.value, v.pos); + } + + private void shallowAppendVariantImpl(byte[] value, int pos) { + int size = valueSize(value, pos); + checkIndex(pos + size - 1, value.length); + checkCapacity(size); + System.arraycopy(value, pos, writeBuffer, writePos, size); + writePos += size; + } + private void checkCapacity(int additional) { int required = writePos + additional; if (required > writeBuffer.length) { @@ -399,6 +460,10 @@ public FieldEntry(String key, int id, int offset) { this.offset = offset; } + FieldEntry withNewOffset(int newOffset) { + return new FieldEntry(key, id, newOffset); + } + @Override public int compareTo(FieldEntry other) { return key.compareTo(other.key); @@ -502,4 +567,5 @@ private boolean tryParseDecimal(String input) { private final HashMap dictionary = new HashMap<>(); // Store all keys in `dictionary` in the order of id. private final ArrayList dictionaryKeys = new ArrayList<>(); + private final boolean allowDuplicateKeys; } diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java new file mode 100644 index 0000000000000..551e46214859a --- /dev/null +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.types.variant; + +import java.util.HashMap; +import java.util.Map; + +/** + * Defines a valid shredding schema, as described in + * https://github.com/apache/parquet-format/blob/master/VariantShredding.md. + * A shredding schema contains a value and optional typed_value field. + * If a typed_value is an array or struct, it recursively contain its own shredding schema for + * elements and fields, respectively. + * The schema also contains a metadata field at the top level, but not in recursively shredded + * fields. + */ +public class VariantSchema { + + // Represents one field of an object in the shredding schema. 
+ public static final class ObjectField { + public final String fieldName; + public final VariantSchema schema; + + public ObjectField(String fieldName, VariantSchema schema) { + this.fieldName = fieldName; + this.schema = schema; + } + + @Override + public String toString() { + return "ObjectField{" + + "fieldName=" + fieldName + + ", schema=" + schema + + '}'; + } + } + + public abstract static class ScalarType { + } + + public static final class StringType extends ScalarType { + } + + public enum IntegralSize { + BYTE, SHORT, INT, LONG + } + + public static final class IntegralType extends ScalarType { + public final IntegralSize size; + + public IntegralType(IntegralSize size) { + this.size = size; + } + } + + public static final class FloatType extends ScalarType { + } + + public static final class DoubleType extends ScalarType { + } + + public static final class BooleanType extends ScalarType { + } + + public static final class BinaryType extends ScalarType { + } + + public static final class DecimalType extends ScalarType { + public final int precision; + public final int scale; + + public DecimalType(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + } + + public static final class DateType extends ScalarType { + } + + public static final class TimestampType extends ScalarType { + } + + public static final class TimestampNTZType extends ScalarType { + } + + // The index of the typed_value, value, and metadata fields in the schema, respectively. If a + // given field is not in the schema, its value must be set to -1 to indicate that it is invalid. + // The indices of valid fields should be contiguous and start from 0. + public final int typedIdx; + public final int variantIdx; + // topLevelMetadataIdx must be non-negative in the top-level schema, and -1 at all other nesting + // levels. + public final int topLevelMetadataIdx; + // The number of fields in the schema. I.e. a value between 1 and 3, depending on which of value, + // typed_value and metadata are present. + public final int numFields; + + public final ScalarType scalarSchema; + public final ObjectField[] objectSchema; + // Map for fast lookup of object fields by name. The values are an index into `objectSchema`. 
+ public final Map<String, Integer> objectSchemaMap; + public final VariantSchema arraySchema; + + public VariantSchema(int typedIdx, int variantIdx, int topLevelMetadataIdx, int numFields, + ScalarType scalarSchema, ObjectField[] objectSchema, + VariantSchema arraySchema) { + this.typedIdx = typedIdx; + this.numFields = numFields; + this.variantIdx = variantIdx; + this.topLevelMetadataIdx = topLevelMetadataIdx; + this.scalarSchema = scalarSchema; + this.objectSchema = objectSchema; + if (objectSchema != null) { + objectSchemaMap = new HashMap<>(); + for (int i = 0; i < objectSchema.length; i++) { + objectSchemaMap.put(objectSchema[i].fieldName, i); + } + } else { + objectSchemaMap = null; + } + + this.arraySchema = arraySchema; + } + + @Override + public String toString() { + return "VariantSchema{" + + "typedIdx=" + typedIdx + + ", variantIdx=" + variantIdx + + ", topLevelMetadataIdx=" + topLevelMetadataIdx + + ", numFields=" + numFields + + ", scalarSchema=" + scalarSchema + + ", objectSchema=" + objectSchema + + ", arraySchema=" + arraySchema + + '}'; + } +} diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java new file mode 100644 index 0000000000000..b5f8ea0a1484b --- /dev/null +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.types.variant; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; + +/** + * Class to implement shredding a Variant value. + */ +public class VariantShreddingWriter { + + // Interface to build up a shredded result. Callers should implement a ShreddedResultBuilder to + // create an empty result with a given schema. The castShredded method will call one or more of + // the add* methods to populate it. + public interface ShreddedResult { + // Create an array. The elements are the result of shredding each element. + void addArray(ShreddedResult[] array); + // Create an object. The values are the result of shredding each field, ordered by the index in + // objectSchema. Missing fields are populated with an empty result. + void addObject(ShreddedResult[] values); + void addVariantValue(byte[] result); + // Add a scalar to typed_value. The type of Object depends on the scalarSchema in the shredding + // schema. + void addScalar(Object result); + void addMetadata(byte[] result); + } + + public interface ShreddedResultBuilder { + ShreddedResult createEmpty(VariantSchema schema); + + // If true, we will shred decimals to a different scale or to integers, as long as they are + // numerically equivalent.
Similarly, integers will be allowed to shred to decimals. + boolean allowNumericScaleChanges(); + } + + /** + * Converts an input variant into shredded components. Returns the shredded result, as well + * as the original Variant with shredded fields removed. + * `schema` must be a valid shredding schema, as described in + * https://github.com/apache/parquet-format/blob/master/VariantShredding.md. + */ + public static ShreddedResult castShredded( + Variant v, + VariantSchema schema, + ShreddedResultBuilder builder) { + VariantUtil.Type variantType = v.getType(); + ShreddedResult result = builder.createEmpty(schema); + + if (schema.topLevelMetadataIdx >= 0) { + result.addMetadata(v.getMetadata()); + } + + if (schema.arraySchema != null && variantType == VariantUtil.Type.ARRAY) { + // The array element is always a struct containing untyped and typed fields. + VariantSchema elementSchema = schema.arraySchema; + int size = v.arraySize(); + ShreddedResult[] array = new ShreddedResult[size]; + for (int i = 0; i < size; ++i) { + ShreddedResult shreddedArray = castShredded(v.getElementAtIndex(i), elementSchema, builder); + array[i] = shreddedArray; + } + result.addArray(array); + } else if (schema.objectSchema != null && variantType == VariantUtil.Type.OBJECT) { + VariantSchema.ObjectField[] objectSchema = schema.objectSchema; + ShreddedResult[] shreddedValues = new ShreddedResult[objectSchema.length]; + + // Create a variantBuilder for any field that exists in `v` but not in the shredding schema. + VariantBuilder variantBuilder = new VariantBuilder(false); + ArrayList<VariantBuilder.FieldEntry> fieldEntries = new ArrayList<>(); + // Keep track of which schema fields we actually found in the Variant value. + int numFieldsMatched = 0; + int start = variantBuilder.getWritePos(); + for (int i = 0; i < v.objectSize(); ++i) { + Variant.ObjectField field = v.getFieldAtIndex(i); + Integer fieldIdx = schema.objectSchemaMap.get(field.key); + if (fieldIdx != null) { + // The field exists in the shredding schema. Recursively shred, and write the result. + ShreddedResult shreddedField = castShredded( + field.value, objectSchema[fieldIdx].schema, builder); + shreddedValues[fieldIdx] = shreddedField; + numFieldsMatched++; + } else { + // The field is not shredded. Put it in the untyped_value column. + int id = v.getDictionaryIdAtIndex(i); + fieldEntries.add(new VariantBuilder.FieldEntry( + field.key, id, variantBuilder.getWritePos() - start)); + variantBuilder.appendVariant(field.value); + } + } + if (numFieldsMatched < objectSchema.length) { + // Set missing fields to non-null with all fields set to null. + for (int i = 0; i < objectSchema.length; ++i) { + if (shreddedValues[i] == null) { + VariantSchema.ObjectField fieldSchema = objectSchema[i]; + ShreddedResult emptyChild = builder.createEmpty(fieldSchema.schema); + shreddedValues[i] = emptyChild; + numFieldsMatched += 1; + } + } + } + if (numFieldsMatched != objectSchema.length) { + // Since we just filled in all the null entries, this can only happen if we tried to write + // to the same field twice; i.e. the Variant contained duplicate fields, which is invalid. + throw VariantUtil.malformedVariant(); + } + result.addObject(shreddedValues); + if (variantBuilder.getWritePos() != start) { + // We added something to the untyped value.
+ variantBuilder.finishWritingObject(start, fieldEntries); + result.addVariantValue(variantBuilder.valueWithoutMetadata()); + } + } else if (schema.scalarSchema != null) { + VariantSchema.ScalarType scalarType = schema.scalarSchema; + Object typedValue = tryTypedShred(v, variantType, scalarType, builder); + if (typedValue != null) { + // Store the typed value. + result.addScalar(typedValue); + } else { + VariantBuilder variantBuilder = new VariantBuilder(false); + variantBuilder.appendVariant(v); + result.addVariantValue(v.getValue()); + } + } else { + // Store in untyped. + result.addVariantValue(v.getValue()); + } + return result; + } + + /** + * Tries to cast a Variant into a typed value. If the cast fails, returns null. + * + * @param v + * @param variantType The Variant Type of v + * @param targetType The target type + * @return The scalar value, or null if the cast is not valid. + */ + private static Object tryTypedShred( + Variant v, + VariantUtil.Type variantType, + VariantSchema.ScalarType targetType, + ShreddedResultBuilder builder) { + switch (variantType) { + case LONG: + if (targetType instanceof VariantSchema.IntegralType integralType) { + // Check that the target type can hold the actual value. + VariantSchema.IntegralSize size = integralType.size; + long value = v.getLong(); + switch (size) { + case BYTE: + if (value == (byte) value) { + return (byte) value; + } + break; + case SHORT: + if (value == (short) value) { + return (short) value; + } + break; + case INT: + if (value == (int) value) { + return (int) value; + } + break; + case LONG: + return value; + } + } else if (targetType instanceof VariantSchema.DecimalType decimalType && + builder.allowNumericScaleChanges()) { + // If the integer can fit in the given decimal precision, allow it. + long value = v.getLong(); + // Set to the requested scale, and check if the precision is large enough. + BigDecimal decimalValue = BigDecimal.valueOf(value); + BigDecimal scaledValue = decimalValue.setScale(decimalType.scale); + // The initial value should have scale 0, so rescaling shouldn't lose information. + assert(decimalValue.compareTo(scaledValue) == 0); + if (scaledValue.precision() <= decimalType.precision) { + return scaledValue; + } + } + break; + case DECIMAL: + if (targetType instanceof VariantSchema.DecimalType decimalType) { + // Use getDecimalWithOriginalScale so that we retain scale information if + // allowNumericScaleChanges() is false. + BigDecimal value = VariantUtil.getDecimalWithOriginalScale(v.value, v.pos); + if (value.precision() <= decimalType.precision && + value.scale() == decimalType.scale) { + return value; + } + if (builder.allowNumericScaleChanges()) { + // Convert to the target scale, and see if it fits. Rounding mode doesn't matter, + // since we'll reject it if it turned out to require rounding. + BigDecimal scaledValue = value.setScale(decimalType.scale, RoundingMode.FLOOR); + if (scaledValue.compareTo(value) == 0 && + scaledValue.precision() <= decimalType.precision) { + return scaledValue; + } + } + } else if (targetType instanceof VariantSchema.IntegralType integralType && + builder.allowNumericScaleChanges()) { + // Check if the decimal happens to be an integer. + BigDecimal value = v.getDecimal(); + VariantSchema.IntegralSize size = integralType.size; + // Try to cast to the appropriate type, and check if any information is lost. 
+ switch (size) { + case BYTE: + if (value.compareTo(BigDecimal.valueOf(value.byteValue())) == 0) { + return value.byteValue(); + } + break; + case SHORT: + if (value.compareTo(BigDecimal.valueOf(value.shortValue())) == 0) { + return value.shortValue(); + } + break; + case INT: + if (value.compareTo(BigDecimal.valueOf(value.intValue())) == 0) { + return value.intValue(); + } + break; + case LONG: + if (value.compareTo(BigDecimal.valueOf(value.longValue())) == 0) { + return value.longValue(); + } + } + } + break; + case BOOLEAN: + if (targetType instanceof VariantSchema.BooleanType) { + return v.getBoolean(); + } + break; + case STRING: + if (targetType instanceof VariantSchema.StringType) { + return v.getString(); + } + break; + case DOUBLE: + if (targetType instanceof VariantSchema.DoubleType) { + return v.getDouble(); + } + break; + case DATE: + if (targetType instanceof VariantSchema.DateType) { + return (int) v.getLong(); + } + break; + case TIMESTAMP: + if (targetType instanceof VariantSchema.TimestampType) { + return v.getLong(); + } + break; + case TIMESTAMP_NTZ: + if (targetType instanceof VariantSchema.TimestampNTZType) { + return v.getLong(); + } + break; + case FLOAT: + if (targetType instanceof VariantSchema.FloatType) { + return v.getFloat(); + } + break; + case BINARY: + if (targetType instanceof VariantSchema.BinaryType) { + return v.getBinary(); + } + break; + } + // The stored type does not match the requested shredding type. Return null, and the caller + // will store the result in untyped_value. + return null; + } + + // Add the result to the shredding result. + private static void addVariantValueVariant(Variant variantResult, + VariantSchema schema, ShreddedResult result) { + result.addVariantValue(variantResult.getValue()); + } + +} diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java index 84e3a45e4b0ee..86609eef5d908 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java @@ -171,6 +171,12 @@ static SparkRuntimeException malformedVariant() { Map$.MODULE$.empty(), null, new QueryContext[]{}, ""); } + static SparkRuntimeException unknownPrimitiveTypeInVariant(int id) { + return new SparkRuntimeException("UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT", + new scala.collection.immutable.Map.Map1<>("id", Integer.toString(id)), null, + new QueryContext[]{}, ""); + } + // An exception indicating that an external caller tried to call the Variant constructor with // value or metadata exceeding the 16MiB size limit. We will never construct a Variant this large, // so it should only be possible to encounter this exception when reading a Variant produced by @@ -235,6 +241,11 @@ public enum Type { BINARY, } + public static int getTypeInfo(byte[] value, int pos) { + checkIndex(pos, value.length); + return (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + } + // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if // `getType` returns this type (for example, it is only legal to call `getLong` if `getType` // returns `Type.Long`). 
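The narrowing rules in `tryTypedShred` reduce to two lossless-conversion checks: a round-trip cast for integers and a rescale-and-compare for decimals. A minimal standalone sketch of those two checks follows; the class and method names are illustrative and not part of the patch.

import java.math.BigDecimal;
import java.math.RoundingMode;

final class LosslessNarrowingSketch {
  // An integral value may be shredded to a narrower type only when the round-trip cast
  // preserves it exactly.
  static boolean fitsInByte(long value) {
    return value == (byte) value;
  }

  // A decimal may be shredded to DECIMAL(precision, scale) only when rescaling does not round
  // and the rescaled value still fits in the target precision.
  static BigDecimal rescaleOrNull(BigDecimal d, int precision, int scale) {
    BigDecimal rescaled = d.setScale(scale, RoundingMode.FLOOR);
    if (rescaled.compareTo(d) != 0 || rescaled.precision() > precision) return null;
    return rescaled;
  }

  public static void main(String[] args) {
    System.out.println(fitsInByte(127));                               // true
    System.out.println(fitsInByte(128));                               // false: (byte) 128 == -128
    System.out.println(rescaleOrNull(new BigDecimal("12.30"), 4, 1));  // 12.3
    System.out.println(rescaleOrNull(new BigDecimal("12.34"), 4, 1));  // null (would round)
  }
}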
@@ -281,7 +292,7 @@ public static Type getType(byte[] value, int pos) { case LONG_STR: return Type.STRING; default: - throw malformedVariant(); + throw unknownPrimitiveTypeInVariant(typeInfo); } } } @@ -332,7 +343,7 @@ public static int valueSize(byte[] value, int pos) { case LONG_STR: return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE); default: - throw malformedVariant(); + throw unknownPrimitiveTypeInVariant(typeInfo); } } } @@ -356,8 +367,9 @@ public static boolean getBoolean(byte[] value, int pos) { // Get a long value from variant value `value[pos...]`. // It is only legal to call it if `getType` returns one of `Type.LONG/DATE/TIMESTAMP/ // TIMESTAMP_NTZ`. If the type is `DATE`, the return value is guaranteed to fit into an int and - // represents the number of days from the Unix epoch. If the type is `TIMESTAMP/TIMESTAMP_NTZ`, - // the return value represents the number of microseconds from the Unix epoch. + // represents the number of days from the Unix epoch. + // If the type is `TIMESTAMP/TIMESTAMP_NTZ`, the return value represents the number of + // microseconds from the Unix epoch. // Throw `MALFORMED_VARIANT` if the variant is malformed. public static long getLong(byte[] value, int pos) { checkIndex(pos, value.length); @@ -401,7 +413,7 @@ private static void checkDecimal(BigDecimal d, int maxPrecision) { // Get a decimal value from variant value `value[pos...]`. // Throw `MALFORMED_VARIANT` if the variant is malformed. - public static BigDecimal getDecimal(byte[] value, int pos) { + public static BigDecimal getDecimalWithOriginalScale(byte[] value, int pos) { checkIndex(pos, value.length); int basicType = value[pos] & BASIC_TYPE_MASK; int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; @@ -433,7 +445,11 @@ public static BigDecimal getDecimal(byte[] value, int pos) { default: throw unexpectedType(Type.DECIMAL); } - return result.stripTrailingZeros(); + return result; + } + + public static BigDecimal getDecimal(byte[] value, int pos) { + return getDecimalWithOriginalScale(value, pos).stripTrailingZeros(); } // Get a float value from variant value `value[pos...]`. 
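Taken together, the new classes are wired up as follows: a caller supplies a VariantSchema plus a ShreddedResultBuilder, and castShredded walks the Variant, filling one ShreddedResult per schema node. Below is a minimal sketch of such a caller, assuming a hypothetical in-memory result type; the InMemory* names and the field-index assignments are illustrative and not defined by this patch.

import org.apache.spark.types.variant.Variant;
import org.apache.spark.types.variant.VariantSchema;
import org.apache.spark.types.variant.VariantShreddingWriter;

// Illustrative sink that simply records whatever castShredded hands back.
class InMemoryShreddedResult implements VariantShreddingWriter.ShreddedResult {
  final VariantSchema schema;
  byte[] metadata;                                        // top-level metadata, if requested
  byte[] variantValue;                                    // residual (unshredded) value
  Object scalar;                                          // typed_value for scalar schemas
  VariantShreddingWriter.ShreddedResult[] objectFields;   // typed_value for object schemas
  VariantShreddingWriter.ShreddedResult[] arrayElements;  // typed_value for array schemas

  InMemoryShreddedResult(VariantSchema schema) { this.schema = schema; }

  @Override public void addArray(VariantShreddingWriter.ShreddedResult[] array) { arrayElements = array; }
  @Override public void addObject(VariantShreddingWriter.ShreddedResult[] values) { objectFields = values; }
  @Override public void addVariantValue(byte[] result) { variantValue = result; }
  @Override public void addScalar(Object result) { scalar = result; }
  @Override public void addMetadata(byte[] result) { metadata = result; }
}

class InMemoryShreddedResultBuilder implements VariantShreddingWriter.ShreddedResultBuilder {
  @Override public VariantShreddingWriter.ShreddedResult createEmpty(VariantSchema schema) {
    return new InMemoryShreddedResult(schema);
  }
  // Conservative choice: only shred numerics whose type and scale already match the schema.
  @Override public boolean allowNumericScaleChanges() { return false; }
}

class ShreddingSketch {
  // Shreds `v` against a schema with a single typed field "a" stored as a long.
  static InMemoryShreddedResult shred(Variant v) {
    VariantSchema fieldA = new VariantSchema(
        /* typedIdx */ 0, /* variantIdx */ 1, /* topLevelMetadataIdx */ -1, /* numFields */ 2,
        new VariantSchema.IntegralType(VariantSchema.IntegralSize.LONG), null, null);
    VariantSchema topLevel = new VariantSchema(
        /* typedIdx */ 1, /* variantIdx */ 2, /* topLevelMetadataIdx */ 0, /* numFields */ 3,
        null, new VariantSchema.ObjectField[] { new VariantSchema.ObjectField("a", fieldA) }, null);
    // Safe cast: the builder above only ever creates InMemoryShreddedResult instances.
    return (InMemoryShreddedResult) VariantShreddingWriter.castShredded(
        v, topLevel, new InMemoryShreddedResultBuilder());
  }
}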
diff --git a/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain b/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain deleted file mode 100644 index 1ef91ef8c36ac..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain b/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain deleted file mode 100644 index 6eb4805b4fcc4..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [from_protobuf(bytes#0, StorageLevel, Some([B)) AS from_protobuf(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain b/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain deleted file mode 100644 index c4a47b1aef07b..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [from_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_aggregate.explain b/connect/common/src/test/resources/query-tests/explain-results/function_aggregate.explain deleted file mode 100644 index 31fe84066f8c7..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_aggregate.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [aggregate(e#0, 0, lambdafunction((lambda x#0 + lambda y#0), lambda x#0, lambda y#0, false), lambdafunction(lambda x#0, lambda x#0, false)) AS aggregate(e, 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain b/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain deleted file mode 100644 index a78195c4ae295..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false)) AS array_compact(e)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_array_sort_with_comparator.explain b/connect/common/src/test/resources/query-tests/explain-results/function_array_sort_with_comparator.explain deleted file mode 100644 index cd86bcc5ffdf5..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_array_sort_with_comparator.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [array_sort(e#0, 
lambdafunction((lambda x#0 - lambda y#0), lambda x#0, lambda y#0, false), false) AS array_sort(e, lambdafunction((namedlambdavariable() - namedlambdavariable()), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_exists.explain b/connect/common/src/test/resources/query-tests/explain-results/function_exists.explain deleted file mode 100644 index 1fab4ccb3a86a..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_exists.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [exists(e#0, lambdafunction((lambda x#0 > 10), lambda x#0, false)) AS exists(e, lambdafunction((namedlambdavariable() > 10), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_filter.explain b/connect/common/src/test/resources/query-tests/explain-results/function_filter.explain deleted file mode 100644 index a92b212666c05..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_filter.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [filter(e#0, lambdafunction((lambda x#0 > 10), lambda x#0, false)) AS filter(e, lambdafunction((namedlambdavariable() > 10), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_filter_with_pair_input.explain b/connect/common/src/test/resources/query-tests/explain-results/function_filter_with_pair_input.explain deleted file mode 100644 index 63ab17bd1e55e..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_filter_with_pair_input.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [filter(e#0, lambdafunction(((lambda x#0 > 10) AND (lambda y#0 > 2)), lambda x#0, lambda y#0, false)) AS filter(e, lambdafunction(((namedlambdavariable() > 10) AND (namedlambdavariable() > 2)), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_forall.explain b/connect/common/src/test/resources/query-tests/explain-results/function_forall.explain deleted file mode 100644 index e69389808a457..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_forall.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [forall(e#0, lambdafunction((lambda x#0 > 10), lambda x#0, false)) AS forall(e, lambdafunction((namedlambdavariable() > 10), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain b/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain deleted file mode 100644 index 50ab91560e64a..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [json_array_length(g#0) AS json_array_length(g)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_map_filter.explain b/connect/common/src/test/resources/query-tests/explain-results/function_map_filter.explain deleted file mode 100644 index 4e2502d0c988d..0000000000000 --- 
a/connect/common/src/test/resources/query-tests/explain-results/function_map_filter.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [map_filter(f#0, lambdafunction(Contains(lambda x#0, baz), lambda x#0, lambda y#0, false)) AS map_filter(f, lambdafunction(contains(namedlambdavariable(), baz), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_map_from_entries.explain b/connect/common/src/test/resources/query-tests/explain-results/function_map_from_entries.explain deleted file mode 100644 index 737900bef096d..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_map_from_entries.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [map_from_entries(transform(e#0, lambdafunction(struct(y, lambda y#0, x, lambda x#0), lambda x#0, lambda y#0, false))) AS map_from_entries(transform(e, lambdafunction(struct(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable())))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_map_zip_with.explain b/connect/common/src/test/resources/query-tests/explain-results/function_map_zip_with.explain deleted file mode 100644 index 2c053fa655853..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_map_zip_with.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [map_zip_with(f#0, f#0, lambdafunction((lambda y#0.id + lambda z#0.id), lambda x#0, lambda y#0, lambda z#0, false)) AS map_zip_with(f, f, lambdafunction((namedlambdavariable().id + namedlambdavariable().id), namedlambdavariable(), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain b/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain deleted file mode 100644 index 3c874b5c8b6a4..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [parse_url(g#0, g#0, false) AS parse_url(g, g)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain b/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain deleted file mode 100644 index eba1c5c814fe3..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [parse_url(g#0, g#0, g#0, false) AS parse_url(g, g, g)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_reduce.explain b/connect/common/src/test/resources/query-tests/explain-results/function_reduce.explain deleted file mode 100644 index d25e5a306e217..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_reduce.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [reduce(e#0, 0, lambdafunction((lambda x#0 + lambda y#0), lambda x#0, lambda y#0, false), lambdafunction(lambda x#0, lambda x#0, false)) AS reduce(e, 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), 
lambdafunction(namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain b/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain deleted file mode 100644 index ecd181a4292de..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [schema_of_csv(1|abc, (sep,|)) AS schema_of_csv(1|abc)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain b/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain deleted file mode 100644 index 8ec799bc58084..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [schema_of_json([{"col":01}]) AS schema_of_json([{"col":01}])#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain b/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain deleted file mode 100644 index 13867949177a4..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [schema_of_json([{"col":01}], (allowNumericLeadingZeros,true)) AS schema_of_json([{"col":01}])#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_locale.explain b/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_locale.explain deleted file mode 100644 index 7819f9b542340..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_locale.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [sentences(g#0, en, US) AS sentences(g, en, US)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain b/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain deleted file mode 100644 index cd72b12ee19b6..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_to_json.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [to_json((timestampFormat,dd/MM/yyyy), d#0, Some(America/Los_Angeles)) AS to_json(d)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_transform.explain b/connect/common/src/test/resources/query-tests/explain-results/function_transform.explain deleted file mode 100644 index 1eb446551f130..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_transform.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [transform(e#0, lambdafunction((lambda x#0 + 1), lambda x#0, false)) AS transform(e, lambdafunction((namedlambdavariable() + 1), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_transform_keys.explain 
b/connect/common/src/test/resources/query-tests/explain-results/function_transform_keys.explain deleted file mode 100644 index aae92957bcd0d..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_transform_keys.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [transform_keys(f#0, lambdafunction(concat(lambda x#0, cast(lambda y#0.id as string)), lambda x#0, lambda y#0, false)) AS transform_keys(f, lambdafunction(concat(namedlambdavariable(), namedlambdavariable().id), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_transform_values.explain b/connect/common/src/test/resources/query-tests/explain-results/function_transform_values.explain deleted file mode 100644 index 3837ff0b78f02..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_transform_values.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [transform_values(f#0, lambdafunction(update_fields(lambda y#0, WithField(key, lambda x#0)), lambda x#0, lambda y#0, false)) AS transform_values(f, lambdafunction(update_fields(namedlambdavariable(), WithField(namedlambdavariable())), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_transform_with_index.explain b/connect/common/src/test/resources/query-tests/explain-results/function_transform_with_index.explain deleted file mode 100644 index 99c7733b1f734..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_transform_with_index.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [transform(e#0, lambdafunction((lambda x#0 + lambda y#0), lambda x#0, lambda y#0, false)) AS transform(e, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_try_url_decode.explain b/connect/common/src/test/resources/query-tests/explain-results/function_try_url_decode.explain deleted file mode 100644 index 74b360a6b5f38..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_try_url_decode.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [tryeval(static_invoke(UrlCodec.decode(g#0, UTF-8))) AS try_url_decode(g)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain b/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain deleted file mode 100644 index 6111cc1374fb6..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [static_invoke(UrlCodec.decode(g#0, UTF-8)) AS url_decode(g)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain b/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain deleted file mode 100644 index 871842d41ba4f..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [static_invoke(UrlCodec.encode(g#0, UTF-8)) AS 
url_encode(g)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain deleted file mode 100644 index d9e2e55d9b12e..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath(s#0, a/b/text()) AS xpath(s, a/b/text())#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain deleted file mode 100644 index 9b75f81802467..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_boolean(s#0, a/b) AS xpath_boolean(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain deleted file mode 100644 index 9ce47136df242..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_double(s#0, a/b) AS xpath_double(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain deleted file mode 100644 index 02b29ec4afa9c..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_float(s#0, a/b) AS xpath_float(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain deleted file mode 100644 index cdd56eaa73199..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_int(s#0, a/b) AS xpath_int(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain deleted file mode 100644 index 3acefb13d0f8c..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_long(s#0, a/b) AS xpath_long(s, a/b)#0L] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain deleted file mode 100644 index 0a30685f0c6d2..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_number(s#0, a/b) AS xpath_number(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git 
a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain deleted file mode 100644 index ed440972bf490..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_short(s#0, a/b) AS xpath_short(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain b/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain deleted file mode 100644 index f4103f68c3bc3..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [xpath_string(s#0, a/b) AS xpath_string(s, a/b)#0] -+- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_zip_with.explain b/connect/common/src/test/resources/query-tests/explain-results/function_zip_with.explain deleted file mode 100644 index 53c9298360735..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/function_zip_with.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [zip_with(e#0, e#0, lambdafunction((lambda x#0 + lambda y#0), lambda x#0, lambda y#0, false)) AS zip_with(e, e, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain b/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain deleted file mode 100644 index f61fc30a3a529..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain +++ /dev/null @@ -1,2 +0,0 @@ -Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, #0, value#0] -+- LocalRelation , [id#0L, a#0, b#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain b/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain deleted file mode 100644 index b5742d976dee9..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain +++ /dev/null @@ -1,2 +0,0 @@ -Expand [[a#0, id, id#0L]], [a#0, #0, value#0L] -+- LocalRelation , [id#0L, a#0, b#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain b/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain deleted file mode 100644 index 1698c562732e8..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/toJSON.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [to_json(struct(id, id#0L, a, a#0, b, b#0, d, d#0, e, e#0, f, f#0, g, g#0), Some(America/Los_Angeles)) AS to_json(struct(id, a, b, d, e, f, g))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain b/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain deleted file mode 100644 index a5371c70ac78a..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [to_avro(id#0L, None) AS to_avro(id)#0] -+- 
LocalRelation , [id#0L, a#0, b#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain b/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain deleted file mode 100644 index 7c688cc446947..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [to_protobuf(bytes#0, StorageLevel, Some([B)) AS to_protobuf(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain b/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain deleted file mode 100644 index 9f05bb03c9c6d..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [to_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS to_protobuf(bytes)#0] -+- LocalRelation , [id#0L, bytes#0] diff --git a/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain b/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain deleted file mode 100644 index f61fc30a3a529..0000000000000 --- a/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain +++ /dev/null @@ -1,2 +0,0 @@ -Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, #0, value#0] -+- LocalRelation , [id#0L, a#0, b#0] diff --git a/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin b/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin deleted file mode 100644 index d46e40b39dcfe..0000000000000 Binary files a/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin and /dev/null differ diff --git a/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin b/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin deleted file mode 100644 index 651274b1afcac..0000000000000 Binary files a/connect/common/src/test/resources/query-tests/queries/groupby_max.proto.bin and /dev/null differ diff --git a/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin b/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin deleted file mode 100644 index 64dbb597c3650..0000000000000 Binary files a/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin and /dev/null differ diff --git a/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala deleted file mode 100644 index 4ef4f632204b3..0000000000000 --- a/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ /dev/null @@ -1,302 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connect.execution - -import scala.concurrent.{ExecutionContext, Promise} -import scala.jdk.CollectionConverters._ -import scala.util.Try -import scala.util.control.NonFatal - -import com.google.protobuf.Message -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.SparkSQLException -import org.apache.spark.connect.proto -import org.apache.spark.internal.Logging -import org.apache.spark.sql.connect.common.ProtoUtils -import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} -import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.util.{ThreadUtils, Utils} - -/** - * This class launches the actual execution in an execution thread. The execution pushes the - * responses to a ExecuteResponseObserver in executeHolder. - */ -private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { - - private val promise: Promise[Unit] = Promise[Unit]() - - // The newly created thread will inherit all InheritableThreadLocals used by Spark, - // e.g. SparkContext.localProperties. If considering implementing a thread-pool, - // forwarding of thread locals needs to be taken into account. - private val executionThread: ExecutionThread = new ExecutionThread(promise) - - private var started: Boolean = false - - private var interrupted: Boolean = false - - private var completed: Boolean = false - - private val lock = new Object - - /** Launches the execution in a background thread, returns immediately. */ - private[connect] def start(): Unit = { - lock.synchronized { - assert(!started) - // Do not start if already interrupted. - if (!interrupted) { - executionThread.start() - started = true - } - } - } - - /** - * Register a callback that gets executed after completion/interruption of the execution thread. - */ - private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = { - promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext) - } - - /** - * Interrupt the executing thread. - * @return - * true if it was not interrupted before, false if it was already interrupted or completed. - */ - private[connect] def interrupt(): Boolean = { - lock.synchronized { - if (!started && !interrupted) { - // execution thread hasn't started yet, and will not be started. - // handle the interrupted error here directly. - interrupted = true - ErrorUtils.handleError( - "execute", - executeHolder.responseObserver, - executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId, - Some(executeHolder.eventsManager), - interrupted)(new SparkSQLException("OPERATION_CANCELED", Map.empty)) - true - } else if (!interrupted && !completed) { - // checking completed prevents sending interrupt onError after onCompleted - interrupted = true - executionThread.interrupt() - true - } else { - false - } - } - } - - private def execute(): Unit = { - // Outer execute handles errors. - // Separate it from executeInternal to save on indent and improve readability. 
- try { - try { - executeInternal() - } catch { - // Need to catch throwable instead of NonFatal, because e.g. InterruptedException is fatal. - case e: Throwable => - logDebug(s"Exception in execute: $e") - // Always cancel all remaining execution after error. - executeHolder.sessionHolder.session.sparkContext.cancelJobsWithTag(executeHolder.jobTag) - // Rely on an internal interrupted flag, because Thread.interrupted() could be cleared, - // and different exceptions like InterruptedException, ClosedByInterruptException etc. - // could be thrown. - if (interrupted) { - throw new SparkSQLException("OPERATION_CANCELED", Map.empty) - } else { - // Rethrown the original error. - throw e - } - } finally { - executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag) - SparkConnectService.executionListener.foreach(_.removeJobTag(executeHolder.jobTag)) - executeHolder.sparkSessionTags.foreach { tag => - executeHolder.sessionHolder.session.sparkContext.removeJobTag( - ExecuteSessionTag( - executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId, - tag)) - } - } - } catch { - ErrorUtils.handleError( - "execute", - executeHolder.responseObserver, - executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId, - Some(executeHolder.eventsManager), - interrupted) - } - } - - // Inner executeInternal is wrapped by execute() for error handling. - private def executeInternal() = { - // synchronized - check if already got interrupted while starting. - lock.synchronized { - if (interrupted) { - throw new InterruptedException() - } - } - - // `withSession` ensures that session-specific artifacts (such as JARs and class files) are - // available during processing. - executeHolder.sessionHolder.withSession { session => - val debugString = requestString(executeHolder.request) - - // Set tag for query cancellation - session.sparkContext.addJobTag(executeHolder.jobTag) - // Register the job for progress reports. - SparkConnectService.executionListener.foreach(_.registerJobTag(executeHolder.jobTag)) - // Also set all user defined tags as Spark Job tags. - executeHolder.sparkSessionTags.foreach { tag => - session.sparkContext.addJobTag( - ExecuteSessionTag( - executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId, - tag)) - } - session.sparkContext.setJobDescription( - s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") - session.sparkContext.setInterruptOnCancel(true) - - // Add debug information to the query execution so that the jobs are traceable. 
- session.sparkContext.setLocalProperty( - "callSite.short", - s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") - session.sparkContext.setLocalProperty( - "callSite.long", - StringUtils.abbreviate(debugString, 2048)) - - executeHolder.request.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.COMMAND => handleCommand(executeHolder.request) - case proto.Plan.OpTypeCase.ROOT => handlePlan(executeHolder.request) - case _ => - throw new UnsupportedOperationException( - s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") - } - - val observedMetrics: Map[String, Seq[(Option[String], Any)]] = { - executeHolder.observations.map { case (name, observation) => - val values = observation.getOrEmpty.map { case (key, value) => - (Some(key), value) - }.toSeq - name -> values - }.toMap - } - val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = { - executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator => - accumulator.synchronized { - val value = accumulator.value.asScala.toSeq - if (value.nonEmpty) { - accumulator.reset() - Some("__python_accumulator__" -> value.map(value => (None, value))) - } else { - None - } - } - }.toMap - } - if (observedMetrics.nonEmpty || accumulatedInPython.nonEmpty) { - executeHolder.responseObserver.onNext( - SparkConnectPlanExecution - .createObservedMetricsResponse( - executeHolder.sessionHolder.sessionId, - executeHolder.sessionHolder.serverSessionId, - executeHolder.request.getPlan.getRoot.getCommon.getPlanId, - observedMetrics ++ accumulatedInPython)) - } - - lock.synchronized { - // Synchronized before sending ResultComplete, and up until completing the result stream - // to prevent a situation in which a client of reattachable execution receives - // ResultComplete, and proceeds to send ReleaseExecute, and that triggers an interrupt - // before it finishes. - - if (interrupted) { - // check if it got interrupted at the very last moment - throw new InterruptedException() - } - completed = true // no longer interruptible - - if (executeHolder.reattachable) { - // Reattachable execution sends a ResultComplete at the end of the stream - // to signal that there isn't more coming. 
- executeHolder.responseObserver.onNextComplete(createResultComplete()) - } else { - executeHolder.responseObserver.onCompleted() - } - } - } - } - - private def handlePlan(request: proto.ExecutePlanRequest): Unit = { - val responseObserver = executeHolder.responseObserver - - val execution = new SparkConnectPlanExecution(executeHolder) - execution.handlePlan(responseObserver) - } - - private def handleCommand(request: proto.ExecutePlanRequest): Unit = { - val responseObserver = executeHolder.responseObserver - - val command = request.getPlan.getCommand - val planner = new SparkConnectPlanner(executeHolder) - planner.process(command = command, responseObserver = responseObserver) - } - - private def requestString(request: Message) = { - try { - Utils.redact( - executeHolder.sessionHolder.session.sessionState.conf.stringRedactionPattern, - ProtoUtils.abbreviate(request).toString) - } catch { - case NonFatal(e) => - logWarning("Fail to extract debug information", e) - "UNKNOWN" - } - } - - private def createResultComplete(): proto.ExecutePlanResponse = { - // Send the Spark data type - proto.ExecutePlanResponse - .newBuilder() - .setResultComplete(proto.ExecutePlanResponse.ResultComplete.newBuilder().build()) - .build() - } - - private class ExecutionThread(onCompletionPromise: Promise[Unit]) - extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { - override def run(): Unit = { - try { - execute() - onCompletionPromise.success(()) - } catch { - case NonFatal(e) => - onCompletionPromise.failure(e) - } - } - } -} - -private[connect] object ExecuteThreadRunner { - private implicit val namedExecutionContext: ExecutionContext = ExecutionContext - .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback")) -} diff --git a/connector/avro/benchmarks/AvroReadBenchmark-jdk21-results.txt b/connector/avro/benchmarks/AvroReadBenchmark-jdk21-results.txt index e0d9f9b90121f..0f4579f5da24f 100644 --- a/connector/avro/benchmarks/AvroReadBenchmark-jdk21-results.txt +++ b/connector/avro/benchmarks/AvroReadBenchmark-jdk21-results.txt @@ -2,140 +2,140 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 2124 2129 8 7.4 135.0 1.0X +Sum 2061 2066 7 7.6 131.0 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 2124 2129 7 7.4 135.1 1.0X +Sum 2032 2033 2 7.7 129.2 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 2082 2096 19 7.6 
132.4 1.0X +Sum 1992 2005 19 7.9 126.6 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 2079 2091 17 7.6 132.2 1.0X +Sum 1992 2017 35 7.9 126.6 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 2070 2078 11 7.6 131.6 1.0X +Sum 1981 1981 0 7.9 125.9 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 2005 2023 25 7.8 127.5 1.0X +Sum 1981 1984 4 7.9 126.0 1.0X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of columns 3598 3606 12 2.9 343.1 1.0X +Sum of columns 3863 3867 5 2.7 368.4 1.0X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column 2081 2094 19 7.6 132.3 1.0X -Partition column 1913 1917 5 8.2 121.7 1.1X -Both columns 2141 2171 43 7.3 136.1 1.0X +Data column 2129 2143 20 7.4 135.4 1.0X +Partition column 1984 1986 2 7.9 126.1 1.1X +Both columns 2209 2231 31 7.1 140.4 1.0X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 2299 2337 53 4.6 219.3 1.0X +Sum of string 
length 2511 2564 75 4.2 239.5 1.0X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 3094 3118 33 3.4 295.1 1.0X +Sum of string length 3303 3317 20 3.2 315.0 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 2162 2213 72 4.9 206.2 1.0X +Sum of string length 2221 2252 44 4.7 211.8 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 1205 1210 7 8.7 114.9 1.0X +Sum of string length 1170 1174 6 9.0 111.6 1.0X ================================================================================================ Select All From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Wide Column Scan from 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select of all columns 17970 18066 135 0.0 35940.5 1.0X +Select of all columns 19095 19150 78 0.0 38190.4 1.0X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of single column 3222 3242 29 0.3 3072.7 1.0X +Sum of single column 3188 3192 4 0.3 3040.7 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of single column 6336 
6343 9 0.2 6043.0 1.0X +Sum of single column 6229 6254 35 0.2 5940.6 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of single column 9410 9463 75 0.1 8974.0 1.0X +Sum of single column 9279 9318 56 0.1 8848.8 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 5535 5544 10 0.2 5535.0 1.0X -pushdown disabled 5450 5479 29 0.2 5450.1 1.0X -w/ filters 2335 2340 9 0.4 2334.5 2.4X +w/o filters 5538 5544 7 0.2 5537.5 1.0X +pushdown disabled 5546 5571 24 0.2 5546.5 1.0X +w/ filters 2312 2324 18 0.4 2312.4 2.4X diff --git a/connector/avro/benchmarks/AvroReadBenchmark-results.txt b/connector/avro/benchmarks/AvroReadBenchmark-results.txt index f1065f98b81a2..db6193e67ac39 100644 --- a/connector/avro/benchmarks/AvroReadBenchmark-results.txt +++ b/connector/avro/benchmarks/AvroReadBenchmark-results.txt @@ -2,140 +2,140 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 1986 2030 63 7.9 126.3 1.0X +Sum 1959 1993 47 8.0 124.6 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 1983 2021 54 7.9 126.1 1.0X +Sum 1929 1949 28 8.2 122.6 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 1955 1977 30 8.0 124.3 1.0X +Sum 1965 1974 13 8.0 124.9 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 1939 1958 26 8.1 123.3 1.0X +Sum 1904 1918 20 8.3 121.0 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 
6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 1961 1963 3 8.0 124.7 1.0X +Sum 1916 1934 26 8.2 121.8 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum 1944 1946 3 8.1 123.6 1.0X +Sum 1911 1917 8 8.2 121.5 1.0X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of columns 3345 3376 44 3.1 319.0 1.0X +Sum of columns 3426 3450 34 3.1 326.8 1.0X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column 2006 2022 22 7.8 127.5 1.0X -Partition column 1761 1765 5 8.9 112.0 1.1X -Both columns 2054 2068 20 7.7 130.6 1.0X +Data column 1871 1888 24 8.4 118.9 1.0X +Partition column 1713 1720 9 9.2 108.9 1.1X +Both columns 1962 1970 12 8.0 124.7 1.0X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 2002 2024 31 5.2 191.0 1.0X +Sum of string length 2042 2055 18 5.1 194.7 1.0X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -Sum of string length 3103 3141 54 3.4 295.9 1.0X +Sum of string length 3089 3109 28 3.4 294.6 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 2056 2064 11 5.1 196.1 1.0X +Sum of string length 2103 2104 2 5.0 200.5 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of string length 1084 1086 3 9.7 103.3 1.0X +Sum of string length 1073 1079 10 9.8 102.3 1.0X ================================================================================================ Select All From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Wide Column Scan from 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select of all columns 19331 19457 177 0.0 38662.8 1.0X +Select of all columns 18300 18346 64 0.0 36600.5 1.0X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of single column 3178 3191 18 0.3 3030.7 1.0X +Sum of single column 3144 3148 5 0.3 2998.8 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of single column 6288 6406 167 0.2 5996.4 1.0X +Sum of single column 6216 6229 19 0.2 5927.8 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Sum of single column 9478 9487 12 
0.1 9039.2 1.0X +Sum of single column 9286 9585 423 0.1 8855.6 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 5488 5511 22 0.2 5488.5 1.0X -pushdown disabled 5495 5510 15 0.2 5494.9 1.0X -w/ filters 2218 2232 12 0.5 2218.1 2.5X +w/o filters 5308 5326 23 0.2 5307.5 1.0X +pushdown disabled 5253 5288 33 0.2 5252.7 1.0X +w/ filters 2036 2061 24 0.5 2036.3 2.6X diff --git a/connector/avro/benchmarks/AvroWriteBenchmark-jdk21-results.txt b/connector/avro/benchmarks/AvroWriteBenchmark-jdk21-results.txt index f49e7db17093e..a071bc767cfaa 100644 --- a/connector/avro/benchmarks/AvroWriteBenchmark-jdk21-results.txt +++ b/connector/avro/benchmarks/AvroWriteBenchmark-jdk21-results.txt @@ -1,56 +1,56 @@ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 1589 1616 39 9.9 101.0 1.0X -Output Single Double Column 1522 1536 20 10.3 96.8 1.0X -Output Int and String Column 3264 3266 2 4.8 207.5 0.5X -Output Partitions 3054 3094 57 5.1 194.2 0.5X -Output Buckets 4024 4078 76 3.9 255.9 0.4X +Output Single Int Column 1566 1588 30 10.0 99.6 1.0X +Output Single Double Column 1655 1668 18 9.5 105.3 0.9X +Output Int and String Column 3493 3496 5 4.5 222.1 0.4X +Output Partitions 3062 3112 71 5.1 194.7 0.5X +Output Buckets 3937 3952 20 4.0 250.3 0.4X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro compression with different codec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -BZIP2: 115682 115702 29 0.0 1156821.8 1.0X -DEFLATE: 6294 6309 20 0.0 62944.7 18.4X -UNCOMPRESSED: 5130 5148 26 0.0 51301.8 22.5X -SNAPPY: 4611 4643 45 0.0 46106.1 25.1X -XZ: 64308 64406 139 0.0 643084.7 1.8X -ZSTANDARD: 4651 4687 51 0.0 46509.7 24.9X +BZIP2: 115765 115975 297 0.0 1157649.1 1.0X +DEFLATE: 6345 6370 35 0.0 63448.5 18.2X +UNCOMPRESSED: 5183 5184 1 0.0 51827.4 22.3X +SNAPPY: 4611 4614 3 0.0 46112.5 25.1X +XZ: 54096 57854 5315 0.0 540956.3 2.1X +ZSTANDARD: 4877 4888 15 0.0 48770.9 23.7X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro deflate with different levels: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -DEFLATE: deflate.level=1 4666 4669 5 0.0 46656.5 1.0X -DEFLATE: deflate.level=3 4646 4648 3 0.0 46463.0 1.0X -DEFLATE: deflate.level=5 6223 6230 11 0.0 62226.1 0.7X -DEFLATE: deflate.level=7 6272 6282 15 0.0 62715.3 0.7X -DEFLATE: deflate.level=9 6628 6635 9 0.0 66283.6 0.7X +DEFLATE: deflate.level=1 4807 4847 57 0.0 48065.8 1.0X 
+DEFLATE: deflate.level=3 4803 4809 7 0.0 48033.5 1.0X +DEFLATE: deflate.level=5 6373 6389 22 0.0 63728.7 0.8X +DEFLATE: deflate.level=7 6427 6460 47 0.0 64266.6 0.7X +DEFLATE: deflate.level=9 6628 6634 10 0.0 66277.2 0.7X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro xz with different levels: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -XZ: xz.level=1 12512 12545 46 0.0 125121.6 1.0X -XZ: xz.level=3 23744 23832 124 0.0 237441.6 0.5X -XZ: xz.level=5 48209 50241 2874 0.0 482091.5 0.3X -XZ: xz.level=7 69424 69655 327 0.0 694240.2 0.2X -XZ: xz.level=9 142278 142354 108 0.0 1422778.3 0.1X +XZ: xz.level=1 11974 12000 37 0.0 119736.9 1.0X +XZ: xz.level=3 21671 21813 201 0.0 216709.0 0.6X +XZ: xz.level=5 47055 47335 397 0.0 470545.5 0.3X +XZ: xz.level=7 74766 75069 428 0.0 747658.3 0.2X +XZ: xz.level=9 146478 146490 16 0.0 1464783.7 0.1X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro zstandard with different levels: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------- -ZSTANDARD: zstandard.level=1 4669 4670 2 0.0 46688.6 1.0X -ZSTANDARD: zstandard.level=1, zstandard.bufferPool.enabled=true 4689 4701 16 0.0 46893.1 1.0X -ZSTANDARD: zstandard.level=3 4805 4819 20 0.0 48048.3 1.0X -ZSTANDARD: zstandard.level=3, zstandard.bufferPool.enabled=true 4667 4670 4 0.0 46666.8 1.0X -ZSTANDARD: zstandard.level=5 4985 5014 41 0.0 49852.2 0.9X -ZSTANDARD: zstandard.level=5, zstandard.bufferPool.enabled=true 4950 4991 59 0.0 49499.4 0.9X -ZSTANDARD: zstandard.level=7 5282 5291 13 0.0 52820.2 0.9X -ZSTANDARD: zstandard.level=7, zstandard.bufferPool.enabled=true 5221 5260 55 0.0 52208.0 0.9X -ZSTANDARD: zstandard.level=9 5997 6034 52 0.0 59974.4 0.8X -ZSTANDARD: zstandard.level=9, zstandard.bufferPool.enabled=true 5888 5949 85 0.0 58885.0 0.8X +ZSTANDARD: zstandard.level=1 4760 4797 53 0.0 47598.3 1.0X +ZSTANDARD: zstandard.level=1, zstandard.bufferPool.enabled=true 4666 4696 43 0.0 46656.0 1.0X +ZSTANDARD: zstandard.level=3 4845 4869 33 0.0 48452.8 1.0X +ZSTANDARD: zstandard.level=3, zstandard.bufferPool.enabled=true 4790 4801 16 0.0 47896.5 1.0X +ZSTANDARD: zstandard.level=5 5125 5164 55 0.0 51248.6 0.9X +ZSTANDARD: zstandard.level=5, zstandard.bufferPool.enabled=true 4912 4928 22 0.0 49122.5 1.0X +ZSTANDARD: zstandard.level=7 5319 5333 19 0.0 53192.1 0.9X +ZSTANDARD: zstandard.level=7, zstandard.bufferPool.enabled=true 5250 5284 48 0.0 52501.2 0.9X +ZSTANDARD: zstandard.level=9 6087 6087 0 0.0 60869.7 0.8X +ZSTANDARD: zstandard.level=9, zstandard.bufferPool.enabled=true 6219 6234 21 0.0 62191.3 0.8X diff --git a/connector/avro/benchmarks/AvroWriteBenchmark-results.txt b/connector/avro/benchmarks/AvroWriteBenchmark-results.txt index 658b9ad7851d4..1a605c0ea0e90 100644 --- a/connector/avro/benchmarks/AvroWriteBenchmark-results.txt +++ b/connector/avro/benchmarks/AvroWriteBenchmark-results.txt @@ -1,56 +1,56 @@ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro 
writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Output Single Int Column 1566 1615 69 10.0 99.6 1.0X -Output Single Double Column 1718 1720 3 9.2 109.2 0.9X -Output Int and String Column 3250 3250 0 4.8 206.6 0.5X -Output Partitions 2869 2870 0 5.5 182.4 0.5X -Output Buckets 3655 3660 7 4.3 232.4 0.4X +Output Single Int Column 1544 1567 34 10.2 98.1 1.0X +Output Single Double Column 1635 1647 17 9.6 104.0 0.9X +Output Int and String Column 3324 3334 15 4.7 211.3 0.5X +Output Partitions 2961 3047 122 5.3 188.2 0.5X +Output Buckets 3776 3778 3 4.2 240.1 0.4X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro compression with different codec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -BZIP2: 131005 132600 2255 0.0 1310049.5 1.0X -DEFLATE: 6673 6696 34 0.0 66725.6 19.6X -UNCOMPRESSED: 5469 5506 51 0.0 54692.2 24.0X -SNAPPY: 4970 5003 47 0.0 49696.0 26.4X -XZ: 55374 55620 347 0.0 553743.6 2.4X -ZSTANDARD: 4998 5044 64 0.0 49984.1 26.2X +BZIP2: 130388 131379 1402 0.0 1303881.3 1.0X +DEFLATE: 6523 6538 21 0.0 65227.6 20.0X +UNCOMPRESSED: 5394 5425 43 0.0 53944.9 24.2X +SNAPPY: 4813 4816 3 0.0 48134.6 27.1X +XZ: 54364 54382 26 0.0 543640.7 2.4X +ZSTANDARD: 4864 4873 13 0.0 48635.9 26.8X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro deflate with different levels: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -DEFLATE: deflate.level=1 4996 5017 30 0.0 49961.8 1.0X -DEFLATE: deflate.level=3 5013 5026 18 0.0 50129.7 1.0X -DEFLATE: deflate.level=5 6557 6574 23 0.0 65574.0 0.8X -DEFLATE: deflate.level=7 6593 6624 44 0.0 65929.1 0.8X -DEFLATE: deflate.level=9 6973 6983 14 0.0 69725.4 0.7X +DEFLATE: deflate.level=1 4909 4916 9 0.0 49091.1 1.0X +DEFLATE: deflate.level=3 4874 4903 42 0.0 48735.8 1.0X +DEFLATE: deflate.level=5 6460 6473 19 0.0 64601.7 0.8X +DEFLATE: deflate.level=7 6450 6482 46 0.0 64497.5 0.8X +DEFLATE: deflate.level=9 6875 6878 5 0.0 68745.4 0.7X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro xz with different levels: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -XZ: xz.level=1 12335 12414 113 0.0 123346.1 1.0X -XZ: xz.level=3 22830 22901 101 0.0 228298.8 0.5X -XZ: xz.level=5 47861 48099 336 0.0 478610.6 0.3X -XZ: xz.level=7 71299 71967 944 0.0 712993.0 0.2X -XZ: xz.level=9 159311 159585 388 0.0 1593106.7 0.1X +XZ: xz.level=1 12210 12226 22 0.0 122101.7 1.0X +XZ: xz.level=3 22235 22235 0 0.0 222346.3 0.5X +XZ: xz.level=5 47597 47659 88 0.0 475969.7 0.3X +XZ: xz.level=7 69231 69482 356 0.0 692308.3 0.2X +XZ: xz.level=9 147042 148998 2766 0.0 1470415.9 0.1X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 
17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Avro zstandard with different levels: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------- -ZSTANDARD: zstandard.level=1 4917 4951 48 0.0 49169.8 1.0X -ZSTANDARD: zstandard.level=1, zstandard.bufferPool.enabled=true 4885 4904 26 0.0 48848.3 1.0X -ZSTANDARD: zstandard.level=3 5045 5051 9 0.0 50448.8 1.0X -ZSTANDARD: zstandard.level=3, zstandard.bufferPool.enabled=true 4926 4931 7 0.0 49258.9 1.0X -ZSTANDARD: zstandard.level=5 5366 5437 101 0.0 53656.6 0.9X -ZSTANDARD: zstandard.level=5, zstandard.bufferPool.enabled=true 5261 5305 62 0.0 52610.6 0.9X -ZSTANDARD: zstandard.level=7 5673 5680 9 0.0 56731.6 0.9X -ZSTANDARD: zstandard.level=7, zstandard.bufferPool.enabled=true 5592 5615 33 0.0 55917.0 0.9X -ZSTANDARD: zstandard.level=9 6662 6663 2 0.0 66620.2 0.7X -ZSTANDARD: zstandard.level=9, zstandard.bufferPool.enabled=true 6759 6760 1 0.0 67591.8 0.7X +ZSTANDARD: zstandard.level=1 4750 4817 94 0.0 47504.2 1.0X +ZSTANDARD: zstandard.level=1, zstandard.bufferPool.enabled=true 4753 4802 69 0.0 47532.8 1.0X +ZSTANDARD: zstandard.level=3 4920 4924 6 0.0 49198.5 1.0X +ZSTANDARD: zstandard.level=3, zstandard.bufferPool.enabled=true 4792 4799 9 0.0 47921.8 1.0X +ZSTANDARD: zstandard.level=5 5240 5276 51 0.0 52404.0 0.9X +ZSTANDARD: zstandard.level=5, zstandard.bufferPool.enabled=true 5072 5101 41 0.0 50722.5 0.9X +ZSTANDARD: zstandard.level=7 5542 5591 69 0.0 55416.5 0.9X +ZSTANDARD: zstandard.level=7, zstandard.bufferPool.enabled=true 5605 5617 17 0.0 56050.4 0.8X +ZSTANDARD: zstandard.level=9 6311 6403 130 0.0 63109.5 0.8X +ZSTANDARD: zstandard.level=9, zstandard.bufferPool.enabled=true 6324 6331 10 0.0 63236.4 0.8X diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 7d80998d96eb1..62c185a91d85c 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -24,10 +24,10 @@ import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ private[sql] case class AvroDataToCatalyst( @@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst( val dt = SchemaConverters.toSqlType( expectedSchema, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType).dataType + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. 
@@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst( dataType, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) @transient private var decoder: BinaryDecoder = _ @@ -78,12 +80,9 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val parseMode: ParseMode = { val mode = avroOptions.parseMode if (mode != PermissiveMode && mode != FailFastMode) { - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, mode + ) } mode } @@ -91,7 +90,7 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val nullResultRow: Any = dataType match { case st: StructType => val resultRow = new SpecificInternalRow(st.map(_.dataType)) - for(i <- 0 until st.length) { + for (i <- 0 until st.length) { resultRow.setNullAt(i) } resultRow @@ -121,12 +120,9 @@ private[sql] case class AvroDataToCatalyst( s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) case _ => - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> parseMode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, parseMode + ) } } } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala deleted file mode 100755 index 5830b2ec42383..0000000000000 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.Column - - -// scalastyle:off: object.name -object functions { -// scalastyle:on: object.name - - /** - * Converts a binary column of avro format into its corresponding catalyst value. The specified - * schema must match the read data, otherwise the behavior is undefined: it may fail or return - * arbitrary result. - * - * @param data the binary column. - * @param jsonFormatSchema the avro schema in JSON string format. 
- * - * @since 3.0.0 - */ - @Experimental - def from_avro( - data: Column, - jsonFormatSchema: String): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) - } - - /** - * Converts a binary column of Avro format into its corresponding catalyst value. - * The specified schema must match actual schema of the read data, otherwise the behavior - * is undefined: it may fail or return arbitrary result. - * To deserialize the data with a compatible and evolved schema, the expected Avro schema can be - * set via the option avroSchema. - * - * @param data the binary column. - * @param jsonFormatSchema the avro schema in JSON string format. - * @param options options to control how the Avro record is parsed. - * - * @since 3.0.0 - */ - @Experimental - def from_avro( - data: Column, - jsonFormatSchema: String, - options: java.util.Map[String, String]): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * - * @since 3.0.0 - */ - @Experimental - def to_avro(data: Column): Column = { - new Column(CatalystDataToAvro(data.expr, None)) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * @param jsonFormatSchema user-specified output avro schema in JSON string format. - * - * @since 3.0.0 - */ - @Experimental - def to_avro(data: Column, jsonFormatSchema: String): Column = { - new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema))) - } -} diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 1083c99160724..a13faf3b51560 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory( datetimeRebaseMode, avroFilters, options.useStableIdForUnionType, - options.stableIdPrefixForUnionType) + options.stableIdPrefixForUnionType, + options.recursiveFieldMaxDepth) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index fe61fe3db8786..e898253be1168 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -37,15 +37,17 @@ case class AvroTable( fallbackFileFormat: Class[_ <: FileFormat]) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): AvroScanBuilder = - new AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files) - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { new WriteBuilder { - override def build(): Write = AvroWrite(paths, formatName, supportsDataType, info) + override def build(): Write 
= + AvroWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) } + } override def supportsDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 388347537a4d6..311eda3a1b6ae 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, false, - "") + "", + -1) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala index 256b608feaa1f..0db9d284c4512 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala @@ -54,7 +54,7 @@ class AvroCodecSuite extends FileSourceCodecSuite { s"""CREATE TABLE avro_t |USING $format OPTIONS('compression'='unsupported') |AS SELECT 1 as id""".stripMargin)), - errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + condition = "CODEC_SHORT_NAME_NOT_FOUND", sqlState = Some("42704"), parameters = Map("codecName" -> "unsupported") ) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index c807685db0f0c..096cdfe0b9ee4 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -22,16 +22,18 @@ import java.io.ByteArrayOutputStream import scala.jdk.CollectionConverters._ import org.apache.avro.{Schema, SchemaBuilder} -import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder} -import org.apache.avro.io.EncoderFactory +import org.apache.avro.file.SeekableByteArrayInput +import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericRecord, GenericRecordBuilder} +import org.apache.avro.io.{DecoderFactory, EncoderFactory} import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.avro.{functions => Fns} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} class AvroFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -104,6 +106,17 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { functions.from_avro( $"avro", avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)), expected) + + checkError( + exception = intercept[AnalysisException] { + avroStructDF.select( + functions.from_avro( + $"avro", avroTypeStruct, Map("mode" -> "DROPMALFORMED").asJava)).collect() + }, + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> 
"`from_avro`", + "mode" -> "DROPMALFORMED")) } test("roundtrip in to_avro and from_avro - array with null") { @@ -316,6 +329,10 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { spark.sql(s"select from_avro(result, '$jsonFormatSchema', map()).u from ($toAvroSql)"), Seq(Row(Row(1, null)), Row(Row(null, "a")))) + // The 'jsonFormatSchema' argument of the 'to_avro' function is optional. + checkAnswer( + spark.sql(s"select length(to_avro(s)) > 0 from t"), + Seq(Row(true), Row(true))) // Negative tests. checkError( @@ -323,8 +340,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select to_avro(s, 42) as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = Map("sqlExpr" -> "\"toavro(s, 42)\"", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map("sqlExpr" -> "\"to_avro(s, 42)\"", "msg" -> ("The second argument of the TO_AVRO SQL function must be a constant string " + "containing the JSON representation of the schema to use for converting the value to " + "AVRO format"), @@ -338,8 +355,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select from_avro(s, 42, '') as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = Map("sqlExpr" -> "\"fromavro(s, 42, )\"", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map("sqlExpr" -> "\"from_avro(s, 42, )\"", "msg" -> ("The second argument of the FROM_AVRO SQL function must be a constant string " + "containing the JSON representation of the schema to use for converting the value " + "from AVRO format"), @@ -353,10 +370,10 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select from_avro(s, '$jsonFormatSchema', 42) as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> - s"\"fromavro(s, $jsonFormatSchema, 42)\"".stripMargin, + s"\"from_avro(s, $jsonFormatSchema, 42)\"".stripMargin, "msg" -> ("The third argument of the FROM_AVRO SQL function must be a constant map of " + "strings to strings containing the options to use for converting the value " + "from AVRO format"), @@ -367,4 +384,249 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { stop = 138))) } } + + + test("roundtrip in to_avro and from_avro - recursive schema") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))), + catalystSchema).select(struct("Id", "Name").as("struct")) + + val avroStructDF = df.select(functions.to_avro($"struct", avroSchema).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroSchema, Map( + "recursiveFieldMaxDepth" -> "3").asJava)), df) + } + + private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = { + val schema = new Schema.Parser().parse(avroSchema) + val datumWriter = new 
GenericDatumWriter[GenericRecord](schema) + var outputStream: ByteArrayOutputStream = null + var bytes: Array[Byte] = null + try { + outputStream = new ByteArrayOutputStream() + val encoder = EncoderFactory.get.binaryEncoder(outputStream, null) + datumWriter.write(record, encoder) + encoder.flush() + bytes = outputStream.toByteArray + } finally { + if (outputStream != null) { + outputStream.close() + } + } + bytes + } + + private def deserialize(bytes: Array[Byte], avroSchema: String): GenericRecord = { + val schema = new Schema.Parser().parse(avroSchema) + val datumReader = new GenericDatumReader[GenericRecord](schema) + var inputStream: SeekableByteArrayInput = null + var record: GenericRecord = null + try { + inputStream = new SeekableByteArrayInput(bytes) + val decoder = DecoderFactory.get.binaryDecoder(inputStream, null) + record = datumReader.read(null, decoder) + } finally { + if (inputStream != null) { + inputStream.close() + } + } + record + } + + // write: `GenericRecord` -> binary (by `serialize`) -> dataframe + // read: dataframe -> binary -> `GenericRecord` (by `deserialize`) + test("roundtrip in serialize and deserialize - GenericRecord") { + val avroSchema = + """ + |{ + | "type": "record", + | "name": "person", + | "fields": [ + | {"name": "name", "type": "string"}, + | {"name": "age", "type": "int"}, + | {"name": "country", "type": "string"} + | ] + |} + |""".stripMargin + val testTable = "test_avro" + withTable(testTable) { + val schema = new Schema.Parser().parse(avroSchema) + val person1 = new GenericRecordBuilder(schema) + .set("name", "sparkA") + .set("age", 18) + .set("country", "usa") + .build() + val person2 = new GenericRecordBuilder(schema) + .set("name", "sparkB") + .set("age", 19) + .set("country", "usb") + .build() + Seq(person1, person2) + .map(p => serialize(p, avroSchema)) + .toDF("data") + .repartition(1) + .writeTo(testTable) + .create() + + val expectedSchema = new StructType().add("data", BinaryType) + assert(spark.table(testTable).schema === expectedSchema) + + // Note that what is returned here is `Row[Array[Byte]]` + val avroDF = sql(s"SELECT data FROM $testTable") + val readbacks = avroDF + .collect() + .map(row => deserialize(row.get(0).asInstanceOf[Array[Byte]], avroSchema)) + + val readbackPerson1 = readbacks.head + assert(readbackPerson1.get(0).toString === person1.get(0)) + assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int]) + assert(readbackPerson1.get(2).toString === person1.get(2)) + + val readbackPerson2 = readbacks(1) + assert(readbackPerson2.get(0).toString === person2.get(0)) + assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int]) + assert(readbackPerson2.get(2).toString === person2.get(2)) + } + } + + // write: `GenericRecord` -> binary (by `serialize`) -> dataframe + // read: dataframe -> binary -> struct (by `from_avro`) -> `GenericRecord` + test("use `serialize` to write GenericRecord and `from_avro` to read GenericRecord") { + val avroSchema = + """ + |{ + | "type": "record", + | "name": "person", + | "fields": [ + | {"name": "name", "type": "string"}, + | {"name": "age", "type": "int"}, + | {"name": "country", "type": "string"} + | ] + |} + |""".stripMargin + val testTable = "test_avro" + withTable(testTable) { + val schema = new Schema.Parser().parse(avroSchema) + val person1 = new GenericRecordBuilder(schema) + .set("name", "sparkA") + .set("age", 18) + .set("country", "usa") + .build() + val person2 = new GenericRecordBuilder(schema) + .set("name", "sparkB") + 
.set("age", 19) + .set("country", "usb") + .build() + Seq(person1, person2) + .map(p => serialize(p, avroSchema)) + .toDF("data") + .repartition(1) + .writeTo(testTable) + .create() + + val expectedSchema = new StructType().add("data", BinaryType) + assert(spark.table(testTable).schema === expectedSchema) + + // Note that what is returned here is `Row[Struct]` + val avroDF = sql(s"SELECT from_avro(data, '$avroSchema', map()) FROM $testTable") + val readbacks = avroDF + .collect() + .map(row => + new GenericRecordBuilder(schema) + .set("name", row.getStruct(0).getString(0)) + .set("age", row.getStruct(0).getInt(1)) + .set("country", row.getStruct(0).getString(2)) + .build()) + + val readbackPerson1 = readbacks.head + assert(readbackPerson1.get(0) === person1.get(0)) + assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int]) + assert(readbackPerson1.get(2) === person1.get(2)) + + val readbackPerson2 = readbacks(1) + assert(readbackPerson2.get(0) === person2.get(0)) + assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int]) + assert(readbackPerson2.get(2) === person2.get(2)) + } + } + + // write: `GenericRecord` (to `struct`) -> binary (by `to_avro`) -> dataframe + // read: dataframe -> binary -> `GenericRecord` (by `deserialize`) + test("use `to_avro` to write GenericRecord and `deserialize` to read GenericRecord") { + val avroSchema = + """ + |{ + | "type": "record", + | "name": "person", + | "fields": [ + | {"name": "name", "type": "string"}, + | {"name": "age", "type": "int"}, + | {"name": "country", "type": "string"} + | ] + |} + |""".stripMargin + val testTable = "test_avro" + withTable(testTable) { + val schema = new Schema.Parser().parse(avroSchema) + val person1 = new GenericRecordBuilder(schema) + .set("name", "sparkA") + .set("age", 18) + .set("country", "usa") + .build() + val person2 = new GenericRecordBuilder(schema) + .set("name", "sparkB") + .set("age", 19) + .set("country", "usb") + .build() + Seq(person1, person2) + .map(p => ( + p.get(0).asInstanceOf[String], + p.get(1).asInstanceOf[Int], + p.get(2).asInstanceOf[String])) + .toDF("name", "age", "country") + .select(Fns.to_avro(struct($"name", $"age", $"country"), avroSchema).as("data")) + .repartition(1) + .writeTo(testTable) + .create() + + val expectedSchema = new StructType().add("data", BinaryType) + assert(spark.table(testTable).schema === expectedSchema) + + // Note that what is returned here is `Row[Array[Byte]]` + val avroDF = sql(s"select data from $testTable") + val readbacks = avroDF + .collect() + .map(row => row.get(0).asInstanceOf[Array[Byte]]) + .map(bytes => deserialize(bytes, avroSchema)) + + val readbackPerson1 = readbacks.head + assert(readbackPerson1.get(0).toString === person1.get(0)) + assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int]) + assert(readbackPerson1.get(2).toString === person1.get(2)) + + val readbackPerson2 = readbacks(1) + assert(readbackPerson2.get(0).toString === person2.get(0)) + assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int]) + assert(readbackPerson2.get(2).toString === person2.get(2)) + } + } } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 429f3c0deca6a..bb0858decdf8f 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ 
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -436,10 +436,10 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession { val ex = intercept[SparkException] { spark.read.format("avro").load(s"$dir.avro").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkArithmeticException], - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", parameters = Map( "value" -> "0", "precision" -> "4", diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 9b3bb929a700d..c1ab96a63eb26 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -77,7 +77,8 @@ class AvroRowReaderSuite RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index cbcbc2e7e76a6..3643a95abe19c 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -228,7 +228,8 @@ object AvroSerdeSuite { RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index ce38ada7c9e40..0df6a7c4bc90e 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.avro import java.io._ -import java.net.URL +import java.net.URI import java.nio.file.{Files, Paths, StandardCopyOption} import java.sql.{Date, Timestamp} import java.util.UUID @@ -648,7 +648,7 @@ abstract class AvroSuite assert(message.contains("No Avro files found.")) Files.copy( - Paths.get(new URL(episodesAvro).toURI), + Paths.get(new URI(episodesAvro)), Paths.get(dir.getCanonicalPath, "episodes.avro")) val result = spark.read.format("avro").load(episodesAvro).collect() @@ -760,7 +760,7 @@ abstract class AvroSuite assert(uncompressSize > deflateSize) assert(snappySize > deflateSize) assert(snappySize > bzip2Size) - assert(bzip2Size > xzSize) + assert(xzSize > bzip2Size) assert(uncompressSize > zstandardSize) } } @@ -891,10 +891,10 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema("a DECIMAL(4, 3)").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "decimal\\(12,10\\)", @@ -921,6 +921,39 @@ abstract class AvroSuite } } + test("SPARK-49082: Widening type promotions in AvroDeserializer") { + withTempPath { tempPath => + // 
Int -> Long + val intPath = s"$tempPath/int_data" + val intDf = Seq(1, Int.MinValue, Int.MaxValue).toDF("col") + intDf.write.format("avro").save(intPath) + checkAnswer( + spark.read.schema("col Long").format("avro").load(intPath), + Seq(Row(1L), Row(-2147483648L), Row(2147483647L)) + ) + + // Int -> Double + checkAnswer( + spark.read.schema("col Double").format("avro").load(intPath), + Seq(Row(1D), Row(-2147483648D), Row(2147483647D)) + ) + + // Float -> Double + val floatPath = s"$tempPath/float_data" + val floatDf = Seq(1F, + Float.MinValue, Float.MinPositiveValue, Float.MaxValue, + Float.NaN, Float.NegativeInfinity, Float.PositiveInfinity + ).toDF("col") + floatDf.write.format("avro").save(floatPath) + checkAnswer( + spark.read.schema("col Double").format("avro").load(floatPath), + Seq(Row(1D), + Row(-3.4028234663852886E38D), Row(1.401298464324817E-45D), Row(3.4028234663852886E38D), + Row(Double.NaN), Row(Double.NegativeInfinity), Row(Double.PositiveInfinity)) + ) + } + } + test("SPARK-43380: Fix Avro data type conversion" + " of DayTimeIntervalType to avoid producing incorrect results") { withTempPath { path => @@ -936,10 +969,10 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval day to second", @@ -973,10 +1006,10 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval year to month", @@ -1482,7 +1515,7 @@ abstract class AvroSuite .write.format("avro").option("avroSchema", avroSchema) .save(s"$tempDir/${UUID.randomUUID()}") } - assert(ex.getErrorClass == "TASK_WRITE_FAILED") + assert(ex.getCondition == "TASK_WRITE_FAILED") assert(ex.getCause.isInstanceOf[java.lang.NullPointerException]) assert(ex.getCause.getMessage.contains( "null value for (non-nullable) string at test_schema.Name")) @@ -1640,15 +1673,19 @@ abstract class AvroSuite exception = intercept[AnalysisException] { sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir) }, - errorClass = "_LEGACY_ERROR_TEMP_1136", - parameters = Map.empty + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + parameters = Map( + "format" -> "Avro", + "columnName" -> "`INTERVAL '1 days'`", + "columnType" -> "\"INTERVAL\"" + ) ) checkError( exception = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`testType()`", "columnType" -> "UDT(\"INTERVAL\")", @@ -2106,7 +2143,7 @@ abstract class AvroSuite test("SPARK-24805: do not ignore files without .avro 
extension by default") { withTempDir { dir => Files.copy( - Paths.get(new URL(episodesAvro).toURI), + Paths.get(new URI(episodesAvro)), Paths.get(dir.getCanonicalPath, "episodes")) val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" @@ -2145,7 +2182,7 @@ abstract class AvroSuite test("SPARK-24836: ignoreExtension must override hadoop's config") { withTempDir { dir => Files.copy( - Paths.get(new URL(episodesAvro).toURI), + Paths.get(new URI(episodesAvro)), Paths.get(dir.getCanonicalPath, "episodes")) val hadoopConf = spark.sessionState.newHadoopConf() @@ -2187,7 +2224,8 @@ abstract class AvroSuite } } - private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int): + Unit = { val message = intercept[IncompatibleSchemaException] { SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage @@ -2196,7 +2234,79 @@ abstract class AvroSuite } test("Detect recursive loop") { - checkSchemaWithRecursiveLoop(""" + for (recursiveFieldMaxDepth <- Seq(-1, 0)) { + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + } + } + + private def checkSparkSchemaEquals( + avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = { + val sparkSchema = + SchemaConverters.toSqlType( + new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth).dataType + + assert(sparkSchema === expectedSchema) + } + + test("Translate recursive schema - union") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2205,9 +2315,57 @@ abstract class AvroSuite | {"name": "next", "type": ["null", "LongList"]} // optional next element | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("next", expectedSchema) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - union - 2 non-null fields") { + val avroSchema = """ + |{ + | "type": "record", + | "name": "TreeNode", + | "fields": [ + | { + | "name": "name", + | "type": "string" + | }, + | { + | "name": "value", + | "type": [ + | "long" + | ] + | }, + | { + | "name": "children", + | "type": [ + | "null", + | { + | "type": "array", 
+ | "items": "TreeNode" + | } + | ], + | "default": null + | } + | ] + |} + """.stripMargin + val nonRecursiveFields = new StructType().add("name", StringType, nullable = false) + .add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("children", + new ArrayType(expectedSchema, false), nullable = true) + } + } + + test("Translate recursive schema - record") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2227,9 +2385,18 @@ abstract class AvroSuite | } | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = new StructType().add("value", + new StructType().add("parent", expectedSchema, nullable = false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - array") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2238,9 +2405,18 @@ abstract class AvroSuite | {"name": "array", "type": {"type": "array", "items": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - map") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2249,7 +2425,70 @@ abstract class AvroSuite | {"name": "map", "type": {"type": "map", "values": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("map", + new MapType(StringType, expectedSchema, false), nullable = false) + } + } + + test("recursive schema integration test") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", NullType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").save(tempDir.getPath) + + val exc = intercept[AnalysisException] { + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 16) + .load(tempDir.getPath) + } + assert(exc.getMessage.contains("Should not be greater than 15.")) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 10) + .load(tempDir.getPath), + df) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + 
.option("recursiveFieldMaxDepth", 1) + .load(tempDir.getPath), + df.select("Id")) + } } test("log a warning of ignoreExtension deprecation") { @@ -2394,7 +2633,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(path3_x) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") assert(e.getCause.isInstanceOf[SparkUpgradeException]) } checkDefaultLegacyRead(oldPath) @@ -2649,7 +2888,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -2660,7 +2899,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").save(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -2693,7 +2932,7 @@ abstract class AvroSuite |LOCATION '${dir}' |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) @@ -2744,7 +2983,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 11) + assert(AvroOptions.getAllOptions.size == 12) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2757,6 +2996,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) + assert(AvroOptions.isValidOption("recursiveFieldMaxDepth")) } test("SPARK-46633: read file with empty blocks") { @@ -2798,7 +3038,7 @@ class AvroV1Suite extends AvroSuite { sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite) .format("avro").save(dir.getCanonicalPath) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) @@ -2811,7 +3051,7 @@ class AvroV1Suite extends AvroSuite { .write.mode(SaveMode.Overwrite) .format("avro").save(dir.getCanonicalPath) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 81ffb140226e8..2fdb2d4bafe01 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -45,6 +45,11 @@ spark-sql-api_${scala.binary.version} ${project.version} + + org.apache.spark + spark-connect-shims_${scala.binary.version} + ${project.version} + org.apache.spark spark-sketch_${scala.binary.version} @@ -88,6 +93,13 @@ scalacheck_${scala.binary.version} test + + org.apache.spark + spark-sql-api_${scala.binary.version} + 
${project.version} + tests + test + org.apache.spark spark-common-utils_${scala.binary.version} @@ -194,6 +206,28 @@ + + + org.apache.maven.plugins + maven-dependency-plugin + + + get-ammonite-jar + package + + copy-dependencies + + + ${basedir}/target/connect-repl + provided + + + + org.apache.maven.plugins maven-jar-plugin diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 04970f67e99f5..3777f82594aae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -17,117 +17,27 @@ package org.apache.spark.sql -import java.util.Locale - import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement -import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.sql.connect.ConnectConversions._ /** * Functionality for working with missing data in `DataFrame`s. * * @since 3.4.0 */ -final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) { - - /** - * Returns a new `DataFrame` that drops rows containing any null or NaN values. - * - * @since 3.4.0 - */ - def drop(): DataFrame = buildDropDataFrame(None, None) - - /** - * Returns a new `DataFrame` that drops rows containing null or NaN values. - * - * If `how` is "any", then drop rows containing any null or NaN values. If `how` is "all", then - * drop rows only if every column is null or NaN for that row. - * - * @since 3.4.0 - */ - def drop(how: String): DataFrame = { - buildDropDataFrame(None, buildMinNonNulls(how)) - } - - /** - * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified - * columns. - * - * @since 3.4.0 - */ - def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values - * in the specified columns. - * - * @since 3.4.0 - */ - def drop(cols: Seq[String]): DataFrame = buildDropDataFrame(Some(cols), None) - - /** - * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified - * columns. - * - * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. - * - * @since 3.4.0 - */ - def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in - * the specified columns. - * - * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. - * - * @since 3.4.0 - */ - def drop(how: String, cols: Seq[String]): DataFrame = { - buildDropDataFrame(Some(cols), buildMinNonNulls(how)) - } +final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) + extends api.DataFrameNaFunctions { + import sparkSession.RichColumn - /** - * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and - * non-NaN values. 
- * - * @since 3.4.0 - */ - def drop(minNonNulls: Int): DataFrame = { - buildDropDataFrame(None, Some(minNonNulls)) - } - - /** - * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and - * non-NaN values in the specified columns. - * - * @since 3.4.0 - */ - def drop(minNonNulls: Int, cols: Array[String]): DataFrame = - drop(minNonNulls, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than `minNonNulls` - * non-null and non-NaN values in the specified columns. - * - * @since 3.4.0 - */ - def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { - buildDropDataFrame(Some(cols), Some(minNonNulls)) - } + override protected def drop(minNonNulls: Option[Int]): Dataset[Row] = + buildDropDataFrame(None, minNonNulls) - private def buildMinNonNulls(how: String): Option[Int] = { - how.toLowerCase(Locale.ROOT) match { - case "any" => None // No-Op. Do nothing. - case "all" => Some(1) - case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") - } - } + override protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] = + buildDropDataFrame(Option(cols), minNonNulls) private def buildDropDataFrame( cols: Option[Seq[String]], @@ -139,110 +49,42 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: } } - /** - * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Long): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setLong(value).build()) } - /** - * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a - * specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric columns. If a specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Long, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setLong(value).build()) } - /** - * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Double): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setDouble(value).build()) } - /** - * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a - * specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric columns. If a specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Double, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setDouble(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in string columns with `value`. 
- * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: String): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setString(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in specified string columns. If a - * specified column is not a string column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string - * columns. If a specified column is not a string column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: String, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setString(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Boolean): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setBoolean(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a - * specified column is not a boolean column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean - * columns. If a specified column is not a boolean column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Boolean, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setBoolean(value).build()) } @@ -255,43 +97,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: } } - /** - * Returns a new `DataFrame` that replaces null values. - * - * The key of the map is the column name, and the value of the map is the replacement value. The - * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`, - * `Boolean`. Replacement values are cast to the column data type. - * - * For example, the following replaces null values in column "A" with string "unknown", and null - * values in column "B" with numeric value 1.0. - * {{{ - * import com.google.common.collect.ImmutableMap; - * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); - * }}} - * - * @since 3.4.0 - */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq) - - /** - * Returns a new `DataFrame` that replaces null values. - * - * The key of the map is the column name, and the value of the map is the replacement value. The - * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`, - * `Boolean`. Replacement values are cast to the column data type. - * - * For example, the following replaces null values in column "A" with string "unknown", and null - * values in column "B" with numeric value 1.0. 
- * {{{ - * import com.google.common.collect.ImmutableMap; - * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); - * }}} - * - * @since 3.4.0 - */ - def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq) - - private def fillMap(values: Seq[(String, Any)]): DataFrame = { + protected def fillMap(values: Seq[(String, Any)]): DataFrame = { sparkSession.newDataFrame { builder => val fillNaBuilder = builder.getFillNaBuilder.setInput(root) values.map { case (colName, replaceValue) => @@ -300,104 +106,13 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: } } - /** - * Replaces values matching keys in `replacement` map with the corresponding values. - * - * {{{ - * import com.google.common.collect.ImmutableMap; - * - * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.na.replace("height", ImmutableMap.of(1.0, 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); - * }}} - * - * @param col - * name of the column to apply the value replacement. If `col` is "*", replacement is applied - * on all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. - * @since 3.4.0 - */ - def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = - replace(col, replacement.asScala.toMap) - - /** - * (Scala-specific) Replaces values matching keys in `replacement` map. - * - * {{{ - * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.na.replace("height", Map(1.0 -> 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.na.replace("name", Map("UNKNOWN" -> "unnamed")); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); - * }}} - * - * @param col - * name of the column to apply the value replacement. If `col` is "*", replacement is applied - * on all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. - * @since 3.4.0 - */ + /** @inheritdoc */ def replace[T](col: String, replacement: Map[T, T]): DataFrame = { val cols = if (col != "*") Some(Seq(col)) else None buildReplaceDataFrame(cols, buildReplacement(replacement)) } - /** - * Replaces values matching keys in `replacement` map with the corresponding values. - * - * {{{ - * import com.google.common.collect.ImmutableMap; - * - * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); - * }}} - * - * @param cols - * list of columns to apply the value replacement. If `col` is "*", replacement is applied on - * all string, numeric or boolean columns. - * @param replacement - * value replacement map. 
Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. - * @since 3.4.0 - */ - def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { - replace(cols.toImmutableArraySeq, replacement.asScala.toMap) - } - - /** - * (Scala-specific) Replaces values matching keys in `replacement` map. - * - * {{{ - * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); - * }}} - * - * @param cols - * list of columns to apply the value replacement. If `col` is "*", replacement is applied on - * all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. - * @since 3.4.0 - */ + /** @inheritdoc */ def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = { buildReplaceDataFrame(Some(cols), buildReplacement(replacement)) } @@ -440,4 +155,59 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: case v => throw new IllegalArgumentException(s"Unsupported value type ${v.getClass.getName} ($v).") } + + /** @inheritdoc */ + override def drop(): DataFrame = super.drop() + + /** @inheritdoc */ + override def drop(cols: Array[String]): DataFrame = super.drop(cols) + + /** @inheritdoc */ + override def drop(cols: Seq[String]): DataFrame = super.drop(cols) + + /** @inheritdoc */ + override def drop(how: String, cols: Array[String]): DataFrame = super.drop(how, cols) + + /** @inheritdoc */ + override def drop(minNonNulls: Int, cols: Array[String]): DataFrame = + super.drop(minNonNulls, cols) + + /** @inheritdoc */ + override def drop(how: String): DataFrame = super.drop(how) + + /** @inheritdoc */ + override def drop(how: String, cols: Seq[String]): DataFrame = super.drop(how, cols) + + /** @inheritdoc */ + override def drop(minNonNulls: Int): DataFrame = super.drop(minNonNulls) + + /** @inheritdoc */ + override def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = + super.drop(minNonNulls, cols) + + /** @inheritdoc */ + override def fill(value: Long, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(value: Double, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(value: String, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(value: Boolean, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(valueMap: java.util.Map[String, Any]): DataFrame = super.fill(valueMap) + + /** @inheritdoc */ + override def fill(valueMap: Map[String, Any]): DataFrame = super.fill(valueMap) + + /** @inheritdoc */ + override def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = + super.replace[T](col, replacement) + + /** @inheritdoc */ + override def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = + super.replace(cols, replacement) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1ad98dc91b216..1fbc887901ecc 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,12 +22,12 @@ import java.util.Properties import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable +import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto.Parse.ParseFormat -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter -import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -37,144 +37,44 @@ import org.apache.spark.sql.types.StructType * @since 3.4.0 */ @Stable -class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging { - - /** - * Specifies the input data source format. - * - * @since 3.4.0 - */ - def format(source: String): DataFrameReader = { - this.source = source - this - } +class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader { + type DS[U] = Dataset[U] - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can skip - * the schema inference step, and thus speed up data loading. - * - * @since 3.4.0 - */ - def schema(schema: StructType): DataFrameReader = { - if (schema != null) { - val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(replaced) - } - this - } + /** @inheritdoc */ + override def format(source: String): this.type = super.format(source) - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) - * can infer the input schema automatically from data. By specifying the schema here, the - * underlying data source can skip the schema inference step, and thus speed up data loading. - * - * {{{ - * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") - * }}} - * - * @since 3.4.0 - */ - def schema(schemaString: String): DataFrameReader = { - schema(StructType.fromDDL(schemaString)) - } + /** @inheritdoc */ + override def schema(schema: StructType): this.type = super.schema(schema) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: String): DataFrameReader = { - this.extraOptions = this.extraOptions + (key -> value) - this - } + /** @inheritdoc */ + override def schema(schemaString: String): this.type = super.schema(schemaString) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. 
- * - * @since 3.4.0 - */ - def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Long): DataFrameReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Double): DataFrameReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def options(options: scala.collection.Map[String, String]): DataFrameReader = { - this.extraOptions ++= options - this - } + /** @inheritdoc */ + override def option(key: String, value: String): this.type = super.option(key, value) - /** - * Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(options.asScala) - this - } + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) - /** - * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external - * key-value stores). - * - * @since 3.4.0 - */ - def load(): DataFrame = { - load(Seq.empty: _*) // force invocation of `load(...varargs...)` - } + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) - /** - * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a - * local or distributed file system). - * - * @since 3.4.0 - */ - def load(path: String): DataFrame = { - // force invocation of `load(...varargs...)` - load(Seq(path): _*) - } + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = + super.options(options) + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = super.options(options) - /** - * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if - * the source is a HadoopFsRelationProvider. - * - * @since 3.4.0 - */ + /** @inheritdoc */ + override def load(): DataFrame = load(Nil: _*) + + /** @inheritdoc */ + def load(path: String): DataFrame = load(Seq(path): _*) + + /** @inheritdoc */ @scala.annotation.varargs def load(paths: String*): DataFrame = { sparkSession.newDataFrame { builder => @@ -190,93 +90,29 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table and connection properties. 
- * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - // properties should override settings in extraOptions. - this.extraOptions ++= properties.asScala - // explicit url and dbtable should override all - this.extraOptions ++= Seq("url" -> url, "dbtable" -> table) - format("jdbc").load() - } + /** @inheritdoc */ + override def jdbc(url: String, table: String, properties: Properties): DataFrame = + super.jdbc(url, table, properties) - // scalastyle:off line.size.limit - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table. Partitions of the table will be retrieved in parallel based on the parameters passed - * to this function. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @param table - * Name of the table in the external database. - * @param columnName - * Alias of `partitionColumn` option. Refer to `partitionColumn` in - * Data Source Option in the version you use. - * @param connectionProperties - * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least - * a "user" and "password" property should be included. "fetchsize" can be used to control the - * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to - * execute to the given number of seconds. - * @since 3.4.0 - */ - // scalastyle:on line.size.limit - def jdbc( + /** @inheritdoc */ + override def jdbc( url: String, table: String, columnName: String, lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: Properties): DataFrame = { - // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. - this.extraOptions ++= Map( - "partitionColumn" -> columnName, - "lowerBound" -> lowerBound.toString, - "upperBound" -> upperBound.toString, - "numPartitions" -> numPartitions.toString) - jdbc(url, table, connectionProperties) - } - - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table using connection properties. The `predicates` parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @param table - * Name of the table in the external database. - * @param predicates - * Condition in the where clause for each partition. - * @param connectionProperties - * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least - * a "user" and "password" property should be included. "fetchsize" can be used to control the - * number of rows per fetch. 
- * @since 3.4.0 - */ + connectionProperties: Properties): DataFrame = + super.jdbc( + url, + table, + columnName, + lowerBound, + upperBound, + numPartitions, + connectionProperties) + + /** @inheritdoc */ def jdbc( url: String, table: String, @@ -296,207 +132,64 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Loads a JSON file and returns the results as a `DataFrame`. - * - * See the documentation on the overloaded `json()` method with varargs for more details. - * - * @since 3.4.0 - */ - def json(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - json(Seq(path): _*) - } + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) - /** - * Loads JSON files and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can find the JSON-specific options for reading JSON files in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def json(paths: String*): DataFrame = { - format("json").load(paths: _*) - } + override def json(paths: String*): DataFrame = super.json(paths: _*) - /** - * Loads a `Dataset[String]` storing JSON objects (JSON Lines - * text format or newline-delimited JSON) and returns the result as a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the input - * once to determine the input schema. - * - * @param jsonDataset - * input Dataset with one JSON object per record - * @since 3.4.0 - */ + /** @inheritdoc */ def json(jsonDataset: Dataset[String]): DataFrame = parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON) - /** - * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other - * overloaded `csv()` method for more details. - * - * @since 3.4.0 - */ - def csv(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - csv(Seq(path): _*) - } + /** @inheritdoc */ + override def json(jsonRDD: JavaRDD[String]): Dataset[Row] = + throw ConnectClientUnsupportedErrors.rdd() - /** - * Loads CSV files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the CSV-specific options for reading CSV files in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ + override def json(jsonRDD: RDD[String]): Dataset[Row] = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) + + /** @inheritdoc */ @scala.annotation.varargs - def csv(paths: String*): DataFrame = format("csv").load(paths: _*) - - /** - * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. 
- * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * If the schema is not specified using `schema` function and `inferSchema` option is disabled, - * it determines the columns as string types and it reads only the first line to determine the - * names and the number of fields. - * - * If the enforceSchema is set to `false`, only the CSV header in the first line is checked to - * conform specified or inferred schema. - * - * @note - * if `header` option is set to `true` when calling this API, all lines same with the header - * will be removed if exists. - * @param csvDataset - * input Dataset with one CSV row per record - * @since 3.4.0 - */ + override def csv(paths: String*): DataFrame = super.csv(paths: _*) + + /** @inheritdoc */ def csv(csvDataset: Dataset[String]): DataFrame = parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV) - /** - * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the other - * overloaded `xml()` method for more details. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - xml(Seq(path): _*) - } + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) - /** - * Loads XML files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the XML-specific options for reading XML files in - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def xml(paths: String*): DataFrame = format("xml").load(paths: _*) - - /** - * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * @param xmlDataset - * input Dataset with one XML object per record - * @since 4.0.0 - */ + override def xml(paths: String*): DataFrame = super.xml(paths: _*) + + /** @inheritdoc */ def xml(xmlDataset: Dataset[String]): DataFrame = parse(xmlDataset, ParseFormat.PARSE_FORMAT_UNSPECIFIED) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the - * other overloaded `parquet()` method for more details. - * - * @since 3.4.0 - */ - def parquet(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - parquet(Seq(path): _*) - } + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. - * - * Parquet-specific option(s) for reading Parquet files can be found in Data - * Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def parquet(paths: String*): DataFrame = { - format("parquet").load(paths: _*) - } + override def parquet(paths: String*): DataFrame = super.parquet(paths: _*) - /** - * Loads an ORC file and returns the result as a `DataFrame`. 
- * - * @param path - * input path - * @since 3.4.0 - */ - def orc(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - orc(Seq(path): _*) - } + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) - /** - * Loads ORC files and returns the result as a `DataFrame`. - * - * ORC-specific option(s) for reading ORC files can be found in Data - * Source Option in the version you use. - * - * @param paths - * input paths - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def orc(paths: String*): DataFrame = format("orc").load(paths: _*) - - /** - * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch - * reading and the returned DataFrame is the batch scan query plan of this table. If it's a - * view, the returned DataFrame is simply the query plan of the view, which can either be a - * batch or streaming query plan. - * - * @param tableName - * is either a qualified or unqualified name that designates a table or view. If a database is - * specified, it identifies the table/view from the database. Otherwise, it first attempts to - * find a temporary view with the given name and then match the table/view from the current - * database. Note that, the global temporary view database is also valid here. - * @since 3.4.0 - */ + override def orc(paths: String*): DataFrame = super.orc(paths: _*) + + /** @inheritdoc */ def table(tableName: String): DataFrame = { + assertNoSpecifiedSchema("table") sparkSession.newDataFrame { builder => builder.getReadBuilder.getNamedTableBuilder .setUnparsedIdentifier(tableName) @@ -504,80 +197,19 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. See the documentation on the - * other overloaded `text()` method for more details. - * - * @since 3.4.0 - */ - def text(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - text(Seq(path): _*) - } + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. The text files must be encoded - * as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.text("/path/to/spark/README.md") - * - * // Java: - * spark.read().text("/path/to/spark/README.md") - * }}} - * - * You can find the text-specific options for reading text files in - * Data Source Option in the version you use. - * - * @param paths - * input paths - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def text(paths: String*): DataFrame = format("text").load(paths: _*) - - /** - * Loads text files and returns a [[Dataset]] of String. See the documentation on the other - * overloaded `textFile()` method for more details. - * @since 3.4.0 - */ - def textFile(path: String): Dataset[String] = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - textFile(Seq(path): _*) - } + override def text(paths: String*): DataFrame = super.text(paths: _*) - /** - * Loads text files and returns a [[Dataset]] of String. 
The underlying schema of the Dataset - * contains a single string column named "value". The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.read().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataFrameReader.text`. - * - * @param paths - * input path - * @since 3.4.0 - */ + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + + /** @inheritdoc */ @scala.annotation.varargs - def textFile(paths: String*): Dataset[String] = { - assertNoSpecifiedSchema("textFile") - text(paths: _*).select("value").as(StringEncoder) - } + override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*) private def assertSourceFormatSpecified(): Unit = { if (source == null) { @@ -597,24 +229,4 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } } - - /** - * A convenient function for schema validation in APIs. - */ - private def assertNoSpecifiedSchema(operation: String): Unit = { - if (userSpecifiedSchema.nonEmpty) { - throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = _ - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = CaseInsensitiveMap[String](Map.empty) - } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 4eef26da706f0..bb7cfa75a9ab9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,88 +18,24 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} -import java.io.ByteArrayInputStream - -import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. * * @since 3.4.0 */ -final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) { - - /** - * Calculates the approximate quantiles of a numerical column of a DataFrame. 
- * - * The result of this algorithm has the following deterministic bound: If the DataFrame has N - * elements and if we request the quantile at probability `p` up to error `err`, then the - * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is - * close to (p * N). More precisely, - * - * {{{ - * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) - * }}} - * - * This method implements a variation of the Greenwald-Khanna algorithm (with some speed - * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile - * Summaries by Greenwald and Khanna. - * - * @param col - * the name of the numerical column - * @param probabilities - * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the - * minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError - * The relative target precision to achieve (greater than or equal to 0). If set to zero, the - * exact quantiles are computed, which could be very expensive. Note that values greater than - * 1 are accepted but give the same result as 1. - * @return - * the approximate quantiles at the given probabilities - * - * @note - * null and NaN values will be removed from the numerical column before calculation. If the - * dataframe is empty or the column only contains null or NaN, an empty array is returned. - * - * @since 3.4.0 - */ - def approxQuantile( - col: String, - probabilities: Array[Double], - relativeError: Double): Array[Double] = { - approxQuantile(Array(col), probabilities, relativeError).head - } +final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) + extends api.DataFrameStatFunctions { + private def root: Relation = df.plan.getRoot + private val sparkSession: SparkSession = df.sparkSession - /** - * Calculates the approximate quantiles of numerical columns of a DataFrame. - * @see - * `approxQuantile(col:Str* approxQuantile)` for detailed description. - * - * @param cols - * the names of the numerical columns - * @param probabilities - * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the - * minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError - * The relative target precision to achieve (greater than or equal to 0). If set to zero, the - * exact quantiles are computed, which could be very expensive. Note that values greater than - * 1 are accepted but give the same result as 1. - * @return - * the approximate quantiles at the given probabilities of each column - * - * @note - * null and NaN values will be ignored in numerical columns before calculation. For columns - * only containing null or NaN values, an empty array is returned. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def approxQuantile( cols: Array[String], probabilities: Array[Double], @@ -119,24 +55,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculate the sample covariance of two numerical columns of a DataFrame. - * @param col1 - * the name of the first column - * @param col2 - * the name of the second column - * @return - * the covariance of the two columns. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.cov("rand1", "rand2") - * res1: Double = 0.065... 
- * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def cov(col1: String, col2: String): Double = { sparkSession .newDataset(PrimitiveDoubleEncoder) { builder => @@ -145,27 +64,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in - * MLlib's Statistics. - * - * @param col1 - * the name of the column - * @param col2 - * the name of the column to calculate the correlation against - * @return - * The Pearson Correlation Coefficient as a Double. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2") - * res1: Double = 0.613... - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def corr(col1: String, col2: String, method: String): Double = { require( method == "pearson", @@ -178,289 +77,48 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. - * - * @param col1 - * the name of the column - * @param col2 - * the name of the column to calculate the correlation against - * @return - * The Pearson Correlation Coefficient as a Double. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2", "pearson") - * res1: Double = 0.613... - * }}} - * - * @since 3.4.0 - */ - def corr(col1: String, col2: String): Double = { - corr(col1, col2, "pearson") - } - - /** - * Computes a pair-wise frequency table of the given columns. Also known as a contingency table. - * The first column of each row will be the distinct values of `col1` and the column names will - * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts - * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. - * Null elements will be replaced by "null", and back ticks will be dropped from elements if - * they exist. - * - * @param col1 - * The name of the first column. Distinct items will make the first item of each row. - * @param col2 - * The name of the second column. Distinct items will make the column names of the DataFrame. - * @return - * A DataFrame containing for the contingency table. - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) - * .toDF("key", "value") - * val ct = df.stat.crosstab("key", "value") - * ct.show() - * +---------+---+---+---+ - * |key_value| 1| 2| 3| - * +---------+---+---+---+ - * | 2| 2| 0| 1| - * | 1| 1| 1| 0| - * | 3| 0| 1| 1| - * +---------+---+---+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def crosstab(col1: String, col2: String): DataFrame = { sparkSession.newDataFrame { builder => builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2) } } - /** - * Finding frequent items for columns, possibly with false positives. Using the frequent element - * count algorithm described in here, - * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. 
- * - * @param cols - * the names of the columns to search frequent items in. - * @param support - * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... | - * +----------+ - * }}} - * - * @since 3.4.0 - */ - def freqItems(cols: Array[String], support: Double): DataFrame = { - sparkSession.newDataFrame { builder => - val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support) - cols.foreach(freqItemsBuilder.addCols) - } - } + /** @inheritdoc */ + override def freqItems(cols: Array[String], support: Double): DataFrame = + super.freqItems(cols, support) - /** - * Finding frequent items for columns, possibly with false positives. Using the frequent element - * count algorithm described in here, - * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * @since 3.4.0 - */ - def freqItems(cols: Array[String]): DataFrame = { - freqItems(cols, 0.01) - } + /** @inheritdoc */ + override def freqItems(cols: Array[String]): DataFrame = super.freqItems(cols) + + /** @inheritdoc */ + override def freqItems(cols: Seq[String]): DataFrame = super.freqItems(cols) - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in here, proposed by Karp, Schenker, and - * Papadimitriou. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. 
- * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... | - * +----------+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def freqItems(cols: Seq[String], support: Double): DataFrame = { - freqItems(cols.toArray, support) + df.sparkSession.newDataFrame { builder => + val freqItemsBuilder = builder.getFreqItemsBuilder + .setInput(df.plan.getRoot) + .setSupport(support) + cols.foreach(freqItemsBuilder.addCols) + } } - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in here, proposed by Karp, Schenker, and - * Papadimitriou. Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * @since 3.4.0 - */ - def freqItems(cols: Seq[String]): DataFrame = { - freqItems(cols.toArray, 0.01) - } + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), - * (3, 3))).toDF("key", "value") - * val fractions = Map(1 -> 1.0, 3 -> 0.5) - * df.stat.sampleBy("key", fractions, 36L).show() - * +---+-----+ - * |key|value| - * +---+-----+ - * | 1| 1| - * | 1| 2| - * | 3| 2| - * +---+-----+ - * }}} - * - * @since 3.4.0 - */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { - sampleBy(Column(col), fractions, seed) - } + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. 
- * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * @since 3.4.0 - */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } + /** @inheritdoc */ + override def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * The stratified sample can be performed over multiple columns: - * {{{ - * import org.apache.spark.sql.Row - * import org.apache.spark.sql.functions.struct - * - * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), - * ("Alice", 10))).toDF("name", "age") - * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0) - * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show() - * +-----+---+ - * | name|age| - * +-----+---+ - * | Nico| 8| - * |Alice| 10| - * +-----+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + import sparkSession.RichColumn require( fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") @@ -478,180 +136,6 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } } } - - /** - * (Java-specific) Returns a stratified sample without replacement based on the fraction given - * on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * @since 3.4.0 - */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName - * name of the column over which the sketch is built - * @param depth - * depth of the sketch - * @param width - * width of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(Column(colName), depth, width, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName - * name of the column over which the sketch is built - * @param eps - * relative error of the sketch - * @param confidence - * confidence of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch( - colName: String, - eps: Double, - confidence: Double, - seed: Int): CountMinSketch = { - countMinSketch(Column(colName), eps, confidence, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. 
- * - * @param col - * the column over which the sketch is built - * @param depth - * depth of the sketch - * @param width - * width of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param col - * the column over which the sketch is built - * @param eps - * relative error of the sketch - * @param confidence - * confidence of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { - val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed)) - val ds = sparkSession.newDataset(BinaryEncoder) { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(agg.expr) - } - CountMinSketch.readFrom(ds.head()) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName - * name of the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param fpp - * expected false positive probability of the filter. - * @since 3.5.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, fpp) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col - * the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param fpp - * expected false positive probability of the filter. - * @since 3.5.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { - val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp) - bloomFilter(col, expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName - * name of the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param numBits - * expected number of bits of the filter. - * @since 3.5.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col - * the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param numBits - * expected number of bits of the filter. 
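A hedged sketch of the Count-min Sketch and Bloom filter builders documented above; `df` is a hypothetical DataFrame with a numeric column named "id":

{{{
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

val cms: CountMinSketch = df.stat.countMinSketch("id", depth = 10, width = 2000, seed = 42)
val estimatedCount = cms.estimateCount(1L)

val bf: BloomFilter = df.stat.bloomFilter("id", expectedNumItems = 1000L, fpp = 0.03)
val maybePresent = bf.mightContain(1L)
}}}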
- * @since 3.5.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { - val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) - val ds = sparkSession.newDataset(BinaryEncoder) { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(agg.expr) - } - BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) - } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index d6cea3723f956..631e9057f8d15 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql -import java.util.{Collections, Locale} +import java.util import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -26,17 +26,23 @@ import scala.util.control.NonFatal import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils +import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId -import org.apache.spark.sql.expressions.ScalaUserDefinedFunction +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{struct, to_json} +import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel @@ -133,11 +139,15 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends Serializable { + extends api.Dataset[T] { + type DS[U] = Dataset[U] + + import sparkSession.RichColumn + // Make sure we don't forget to set plan id. assert(plan.getRoot.getCommon.hasPlanId) - private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder) + private[sql] val agnosticEncoder: AgnosticEncoder[T] = agnosticEncoderFor(encoder) override def toString: String = { try { @@ -161,35 +171,10 @@ class Dataset[T] private[sql] ( } } - /** - * Converts this strongly typed collection of data to generic Dataframe. In contrast to the - * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] - * objects that allow fields to be accessed by ordinal or name. 
- * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def toDF(): DataFrame = new Dataset(sparkSession, plan, UnboundRowEncoder) - /** - * Returns a new Dataset where each record has been mapped on to the specified type. The method - * used to map columns depend on the type of `U`:
  • When `U` is a class, fields for the class will be mapped to columns of the same name (case sensitivity is determined by `spark.sql.caseSensitive`).
  • When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will be assigned to `_1`).
  • When `U` is a primitive type (i.e. String, Int, etc), then the first column of the `DataFrame` will be used.
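A usage sketch for the `as[U]` mapping rules listed above; it assumes a SparkSession `spark` with its implicits in scope:

{{{
import org.apache.spark.sql.Dataset
import spark.implicits._

case class Person(name: String, age: Long)

val df = Seq(("Alice", 29L), ("Bob", 31L)).toDF("name", "age")
val byName: Dataset[Person]            = df.as[Person]                // mapped by field name
val byOrdinal: Dataset[(String, Long)] = df.as[(String, Long)]        // mapped by ordinal
val primitive: Dataset[String]         = df.select("name").as[String] // first column only
}}}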
- * - * If the schema of the Dataset does not match the desired `U` type, you can use `select` along - * with `alias` or `as` to rearrange or rename as required. - * - * Note that `as[]` only changes the view of the data that is passed into typed operations, such - * as `map()`, and does not eagerly project away any columns that are not present in the - * specified class. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def as[U: Encoder]: Dataset[U] = { val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]] // We should add some validation/coercion here. We cannot use `to` @@ -197,19 +182,7 @@ class Dataset[T] private[sql] ( new Dataset[U](sparkSession, plan, encoder) } - /** - * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. - * This can be quite convenient in conversion from an RDD of tuples into a `DataFrame` with - * meaningful names. For example: - * {{{ - * val rdd: RDD[(Int, String)] = ... - * rdd.toDF() // this implicit conversion creates a DataFrame with column name `_1` and `_2` - * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" - * }}} - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def toDF(colNames: String*): DataFrame = sparkSession.newDataFrame { builder => builder.getToDfBuilder @@ -217,34 +190,14 @@ class Dataset[T] private[sql] ( .addAllColumnNames(colNames.asJava) } - /** - * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark - * will:
  • Reorder columns and/or inner fields by name to match the specified schema.
  • Project away columns and/or inner fields that are not needed by the specified schema. Missing columns and/or inner fields (present in the specified schema but not input DataFrame) lead to failures.
  • Cast the columns and/or inner fields to match the data types in the specified schema, if the types are compatible, e.g., numeric to numeric (error if overflows), but not string to int.
  • Carry over the metadata from the specified schema, while the columns and/or inner fields still keep their own metadata if not overwritten by the specified schema.
  • Fail if the nullability is not compatible. For example, the column and/or inner field is nullable but the specified schema requires them to be not nullable.
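A small sketch of `to(schema)` applying the reconciliation rules above; it assumes a SparkSession `spark` with its implicits in scope:

{{{
import org.apache.spark.sql.types._
import spark.implicits._

val df = Seq((1, "a")).toDF("i", "s")
val target = StructType(Seq(
  StructField("s", StringType, nullable = true),  // reordered by name
  StructField("i", LongType, nullable = true)))   // numeric widening Int -> Long
val reconciled = df.to(target)
reconciled.printSchema()
}}}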
- * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def to(schema: StructType): DataFrame = sparkSession.newDataFrame { builder => builder.getToSchemaBuilder .setInput(plan.getRoot) .setSchema(DataTypeProtoConverter.toConnectProtoType(schema)) } - /** - * Returns the schema of this Dataset. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def schema: StructType = cachedSchema /** @@ -266,38 +219,9 @@ class Dataset[T] private[sql] ( .asInstanceOf[StructType] } - /** - * Prints the schema to the console in a nice tree format. - * - * @group basic - * @since 3.4.0 - */ - def printSchema(): Unit = printSchema(Int.MaxValue) - - // scalastyle:off println - /** - * Prints the schema up to the given level to the console in a nice tree format. - * - * @group basic - * @since 3.4.0 - */ - def printSchema(level: Int): Unit = println(schema.treeString(level)) - // scalastyle:on println - - /** - * Prints the plans (logical and physical) with a format specified by a given explain mode. - * - * @param mode - * specifies the expected output format of plans.
  • `simple`: Print only a physical plan.
  • `extended`: Print both logical and physical plans.
  • `codegen`: Print a physical plan and generated codes if they are available.
  • `cost`: Print a logical plan and statistics if they are available.
  • `formatted`: Split explain output into two sections: a physical plan outline and node details.
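The explain modes listed above, exercised on a hypothetical DataFrame `df`:

{{{
df.explain()            // same as "simple": physical plan only
df.explain(true)        // same as "extended": logical and physical plans
df.explain("formatted") // physical plan outline plus node details
}}}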
- * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def explain(mode: String): Unit = { - val protoMode = mode.trim.toLowerCase(Locale.ROOT) match { + val protoMode = mode.trim.toLowerCase(util.Locale.ROOT) match { case "simple" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE case "extended" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED case "codegen" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_CODEGEN @@ -308,32 +232,6 @@ class Dataset[T] private[sql] ( explain(protoMode) } - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * - * @param extended - * default `false`. If `false`, prints only the physical plan. - * - * @group basic - * @since 3.4.0 - */ - def explain(extended: Boolean): Unit = { - val mode = if (extended) { - proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED - } else { - proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE - } - explain(mode) - } - - /** - * Prints the physical plan to the console for debugging purposes. - * - * @group basic - * @since 3.4.0 - */ - def explain(): Unit = explain(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE) - private def explain(mode: proto.AnalyzePlanRequest.Explain.ExplainMode): Unit = { // scalastyle:off println println( @@ -344,198 +242,31 @@ class Dataset[T] private[sql] ( // scalastyle:on println } - /** - * Returns all column names and their data types as an array. - * - * @group basic - * @since 3.4.0 - */ - def dtypes: Array[(String, String)] = schema.fields.map { field => - (field.name, field.dataType.toString) - } - - /** - * Returns all column names as an array. - * - * @group basic - * @since 3.4.0 - */ - def columns: Array[String] = schema.fields.map(_.name) - - /** - * Returns true if the `collect` and `take` methods can be run locally (without any Spark - * executors). - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def isLocal: Boolean = sparkSession .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL) .getIsLocal .getIsLocal - /** - * Returns true if the `Dataset` is empty. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def isEmpty: Boolean = select().limit(1).withResult { result => result.length == 0 } - /** - * Returns true if this Dataset contains one or more sources that continuously return data as it - * arrives. A Dataset that reads data from a streaming source must be executed as a - * `StreamingQuery` using the `start()` method in `DataStreamWriter`. - * - * @group streaming - * @since 3.4.0 - */ + /** @inheritdoc */ def isStreaming: Boolean = sparkSession .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING) .getIsStreaming .getIsStreaming - /** - * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, - * and all cells will be aligned right. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * - * @param numRows - * Number of rows to show - * - * @group action - * @since 3.4.0 - */ - def show(numRows: Int): Unit = show(numRows, truncate = true) - - /** - * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters will - * be truncated, and all cells will be aligned right. 
- * - * @group action - * @since 3.4.0 - */ - def show(): Unit = show(20) - - /** - * Displays the top 20 rows of Dataset in a tabular form. - * - * @param truncate - * Whether truncate long strings. If true, strings more than 20 characters will be truncated - * and all cells will be aligned right - * - * @group action - * @since 3.4.0 - */ - def show(truncate: Boolean): Unit = show(20, truncate) - - /** - * Displays the Dataset in a tabular form. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows - * Number of rows to show - * @param truncate - * Whether truncate long strings. If true, strings more than 20 characters will be truncated - * and all cells will be aligned right - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ // scalastyle:off println def show(numRows: Int, truncate: Boolean): Unit = { val truncateValue = if (truncate) 20 else 0 show(numRows, truncateValue, vertical = false) } - /** - * Displays the Dataset in a tabular form. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * - * @param numRows - * Number of rows to show - * @param truncate - * If set to more than 0, truncates strings to `truncate` characters and all cells will be - * aligned right. - * @group action - * @since 3.4.0 - */ - def show(numRows: Int, truncate: Int): Unit = show(numRows, truncate, vertical = false) - - /** - * Displays the Dataset in a tabular form. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * - * If `vertical` enabled, this command prints output rows vertically (one line per column - * value)? - * - * {{{ - * -RECORD 0------------------- - * year | 1980 - * month | 12 - * AVG('Adj Close) | 0.503218 - * AVG('Adj Close) | 0.595103 - * -RECORD 1------------------- - * year | 1981 - * month | 01 - * AVG('Adj Close) | 0.523289 - * AVG('Adj Close) | 0.570307 - * -RECORD 2------------------- - * year | 1982 - * month | 02 - * AVG('Adj Close) | 0.436504 - * AVG('Adj Close) | 0.475256 - * -RECORD 3------------------- - * year | 1983 - * month | 03 - * AVG('Adj Close) | 0.410516 - * AVG('Adj Close) | 0.442194 - * -RECORD 4------------------- - * year | 1984 - * month | 04 - * AVG('Adj Close) | 0.450090 - * AVG('Adj Close) | 0.483521 - * }}} - * - * @param numRows - * Number of rows to show - * @param truncate - * If set to more than 0, truncates strings to `truncate` characters and all cells will be - * aligned right. - * @param vertical - * If set to true, prints output rows vertically (one line per column value). 
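A short sketch of the show(...) overloads described above, again on a hypothetical `df`:

{{{
df.show()                      // first 20 rows, long strings truncated to 20 characters
df.show(5, truncate = false)   // 5 rows, no truncation
df.show(5, 0, vertical = true) // 5 rows, one line per column value
}}}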
- * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = { val df = sparkSession.newDataset(StringEncoder) { builder => builder.getShowStringBuilder @@ -547,35 +278,15 @@ class Dataset[T] private[sql] ( df.withResult { result => assert(result.length == 1) assert(result.schema.size == 1) - // scalastyle:off println - println(result.toArray.head) - // scalastyle:on println + print(result.toArray.head) } } - /** - * Returns a [[DataFrameNaFunctions]] for working with missing data. - * {{{ - * // Dropping rows containing any null values. - * ds.na.drop() - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def na: DataFrameNaFunctions = new DataFrameNaFunctions(sparkSession, plan.getRoot) - /** - * Returns a [[DataFrameStatFunctions]] for working statistic functions support. - * {{{ - * // Finding frequent items in column with name 'a'. - * ds.stat.freqItems(Seq("a")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot) + /** @inheritdoc */ + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = { checkSameSparkSession(right) @@ -587,7 +298,7 @@ class Dataset[T] private[sql] ( } private def toJoinType(name: String, skipSemiAnti: Boolean = false): proto.Join.JoinType = { - name.trim.toLowerCase(Locale.ROOT) match { + name.trim.toLowerCase(util.Locale.ROOT) match { case "inner" => proto.Join.JoinType.JOIN_TYPE_INNER case "cross" => @@ -607,170 +318,12 @@ class Dataset[T] private[sql] ( } } - /** - * Join with another `DataFrame`. - * - * Behaves as an INNER JOIN and requires a subsequent join predicate. - * - * @param right - * Right side of the join operation. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def join(right: Dataset[_]): DataFrame = buildJoin(right) { builder => builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER) } - /** - * Inner equi-join with another `DataFrame` using the given column. - * - * Different from other join functions, the join column will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * {{{ - * // Joining df1 and df2 using the column "user_id" - * df1.join(df2, "user_id") - * }}} - * - * @param right - * Right side of the join operation. - * @param usingColumn - * Name of the column to join on. This column must exist on both sides. - * - * @note - * If you perform a self-join using this function without aliasing the input `DataFrame`s, you - * will NOT be able to reference any columns after the join, since there is no way to - * disambiguate which side of the join you would like to reference. - * - * @group untypedrel - * @since 3.4.0 - */ - def join(right: Dataset[_], usingColumn: String): DataFrame = { - join(right, Seq(usingColumn)) - } - - /** - * (Java-specific) Inner equi-join with another `DataFrame` using the given columns. See the - * Scala-specific overload for more details. - * - * @param right - * Right side of the join operation. - * @param usingColumns - * Names of the columns to join on. This columns must exist on both sides. - * - * @group untypedrel - * @since 3.4.0 - */ - def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = { - join(right, usingColumns.toImmutableArraySeq) - } - - /** - * (Scala-specific) Inner equi-join with another `DataFrame` using the given columns. 
- * - * Different from other join functions, the join columns will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * {{{ - * // Joining df1 and df2 using the columns "user_id" and "user_name" - * df1.join(df2, Seq("user_id", "user_name")) - * }}} - * - * @param right - * Right side of the join operation. - * @param usingColumns - * Names of the columns to join on. This columns must exist on both sides. - * - * @note - * If you perform a self-join using this function without aliasing the input `DataFrame`s, you - * will NOT be able to reference any columns after the join, since there is no way to - * disambiguate which side of the join you would like to reference. - * - * @group untypedrel - * @since 3.4.0 - */ - def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = { - join(right, usingColumns, "inner") - } - - /** - * Equi-join with another `DataFrame` using the given column. A cross join with a predicate is - * specified as an inner join. If you would explicitly like to perform a cross join use the - * `crossJoin` method. - * - * Different from other join functions, the join column will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * @param right - * Right side of the join operation. - * @param usingColumn - * Name of the column to join on. This column must exist on both sides. - * @param joinType - * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, - * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, - * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, - * `left_anti`. - * - * @note - * If you perform a self-join using this function without aliasing the input `DataFrame`s, you - * will NOT be able to reference any columns after the join, since there is no way to - * disambiguate which side of the join you would like to reference. - * - * @group untypedrel - * @since 3.4.0 - */ - def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = { - join(right, Seq(usingColumn), joinType) - } - - /** - * (Java-specific) Equi-join with another `DataFrame` using the given columns. See the - * Scala-specific overload for more details. - * - * @param right - * Right side of the join operation. - * @param usingColumns - * Names of the columns to join on. This columns must exist on both sides. - * @param joinType - * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, - * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, - * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, - * `left_anti`. - * - * @group untypedrel - * @since 3.4.0 - */ - def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = { - join(right, usingColumns.toImmutableArraySeq, joinType) - } - - /** - * (Scala-specific) Equi-join with another `DataFrame` using the given columns. A cross join - * with a predicate is specified as an inner join. If you would explicitly like to perform a - * cross join use the `crossJoin` method. - * - * Different from other join functions, the join columns will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * @param right - * Right side of the join operation. - * @param usingColumns - * Names of the columns to join on. This columns must exist on both sides. - * @param joinType - * Type of join to perform. Default `inner`. 
Must be one of: `inner`, `cross`, `outer`, - * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, - * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, - * `left_anti`. - * - * @note - * If you perform a self-join using this function without aliasing the input `DataFrame`s, you - * will NOT be able to reference any columns after the join, since there is no way to - * disambiguate which side of the join you would like to reference. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { buildJoin(right) { builder => builder @@ -779,47 +332,7 @@ class Dataset[T] private[sql] ( } } - /** - * Inner join with another `DataFrame`, using the given join expression. - * - * {{{ - * // The following two are equivalent: - * df1.join(df2, $"df1Key" === $"df2Key") - * df1.join(df2).where($"df1Key" === $"df2Key") - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner") - - /** - * Join with another `DataFrame`, using the given join expression. The following performs a full - * outer join between `df1` and `df2`. - * - * {{{ - * // Scala: - * import org.apache.spark.sql.functions._ - * df1.join(df2, $"df1Key" === $"df2Key", "outer") - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); - * }}} - * - * @param right - * Right side of the join. - * @param joinExprs - * Join expression. - * @param joinType - * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, - * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, - * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, - * `left_anti`. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { buildJoin(right) { builder => builder @@ -828,54 +341,12 @@ class Dataset[T] private[sql] ( } } - /** - * Explicit cartesian join with another `DataFrame`. - * - * @param right - * Right side of the join operation. - * - * @note - * Cartesian joins are very expensive without an extra filter that can be pushed down. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def crossJoin(right: Dataset[_]): DataFrame = buildJoin(right) { builder => builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_CROSS) } - private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { - sparkSession.newDataset(agnosticEncoder) { builder => - builder.getSortBuilder - .setInput(plan.getRoot) - .setIsGlobal(global) - .addAllOrder(sortExprs.map(_.sortOrder).asJava) - } - } - - /** - * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true. - * - * This is similar to the relation `join` function with one important difference in the result - * schema. Since `joinWith` preserves objects present on either side of the join, the result - * schema is similarly nested into a tuple under the column names `_1` and `_2`. - * - * This type of join can be useful both for preserving type-safety with the original object - * types as well as working with relational data where either side of the join has column names - * in common. - * - * @param other - * Right side of the join. - * @param condition - * Join expression. 
- * @param joinType - * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, - * `full`, `fullouter`,`full_outer`, `left`, `leftouter`, `left_outer`, `right`, `rightouter`, - * `right_outer`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val joinTypeValue = toJoinType(joinType, skipSemiAnti = true) val (leftNullable, rightNullable) = joinTypeValue match { @@ -912,121 +383,19 @@ class Dataset[T] private[sql] ( } } - /** - * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where - * `condition` evaluates to true. - * - * @param other - * Right side of the join. - * @param condition - * Join expression. - * - * @group typedrel - * @since 3.5.0 - */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { - joinWith(other, condition, "inner") - } - - /** - * Returns a new Dataset with each partition sorted by the given expressions. - * - * This is the same operation as "SORT BY" in SQL (Hive QL). - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { - sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*) - } - - /** - * Returns a new Dataset with each partition sorted by the given expressions. - * - * This is the same operation as "SORT BY" in SQL (Hive QL). - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { - buildSort(global = false, sortExprs) - } - - /** - * Returns a new Dataset sorted by the specified column, all in ascending order. - * {{{ - * // The following 3 are equivalent - * ds.sort("sortcol") - * ds.sort($"sortcol") - * ds.sort($"sortcol".asc) - * }}} - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): Dataset[T] = { - sort((sortCol +: sortCols).map(Column(_)): _*) - } - - /** - * Returns a new Dataset sorted by the given expressions. For example: - * {{{ - * ds.sort($"col1", $"col2".desc) - * }}} - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def sort(sortExprs: Column*): Dataset[T] = { - buildSort(global = true, sortExprs) + override protected def sortInternal(global: Boolean, sortCols: Seq[Column]): Dataset[T] = { + val sortExprs = sortCols.map { c => + ColumnNodeToProtoConverter(c.sortOrder).getSortOrder + } + sparkSession.newDataset(agnosticEncoder) { builder => + builder.getSortBuilder + .setInput(plan.getRoot) + .setIsGlobal(global) + .addAllOrder(sortExprs.asJava) + } } - /** - * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort` - * function. - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols: _*) - - /** - * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort` - * function. - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs: _*) - - /** - * Selects column based on the column name and returns it as a [[Column]]. - * - * @note - * The column name can also reference to a nested column like `a.b`. 
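A sketch covering joinWith and the sort/orderBy family documented above; it assumes a SparkSession `spark` with its implicits in scope, and the data is made up:

{{{
import org.apache.spark.sql.Dataset
import spark.implicits._

val left  = Seq((1, "a"), (2, "b")).toDS()
val right = Seq((1, "x"), (3, "y")).toDS()

// Pairs of matching records, preserving the element type on both sides.
val pairs: Dataset[((Int, String), (Int, String))] =
  left.joinWith(right, left("_1") === right("_1"), "inner")

// Global sort by the first column, descending.
val sorted = left.sort($"_1".desc)
}}}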
- * - * @group untypedrel - * @since 3.4.0 - */ - def apply(colName: String): Column = col(colName) - - /** - * Specifies some hint on the current Dataset. As an example, the following code specifies that - * one of the plan can be broadcasted: - * - * {{{ - * df1.join(df2.hint("broadcast")) - * }}} - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => @@ -1043,161 +412,49 @@ class Dataset[T] private[sql] ( None } - /** - * Selects column based on the column name and returns it as a [[Column]]. - * - * @note - * The column name can also reference to a nested column like `a.b`. - * - * @group untypedrel - * @since 3.4.0 - */ - def col(colName: String): Column = { - Column.apply(colName, getPlanId) - } + /** @inheritdoc */ + def col(colName: String): Column = new Column(colName, getPlanId) - /** - * Selects a metadata column based on its logical column name, and returns it as a [[Column]]. - * - * A metadata column can be accessed this way even if the underlying data source defines a data - * column with a conflicting name. - * - * @group untypedrel - * @since 3.5.0 - */ - def metadataColumn(colName: String): Column = Column { builder => - val attributeBuilder = builder.getUnresolvedAttributeBuilder - .setUnparsedIdentifier(colName) - .setIsMetadataColumn(true) - getPlanId.foreach(attributeBuilder.setPlanId) + /** @inheritdoc */ + def metadataColumn(colName: String): Column = { + Column(UnresolvedAttribute(colName, getPlanId, isMetadataColumn = true)) } - /** - * Selects column based on the column name specified as a regex and returns it as [[Column]]. - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def colRegex(colName: String): Column = { - Column { builder => - val unresolvedRegexBuilder = builder.getUnresolvedRegexBuilder.setColName(colName) - getPlanId.foreach(unresolvedRegexBuilder.setPlanId) - } + Column(UnresolvedRegex(colName, getPlanId)) } - /** - * Returns a new Dataset with an alias set. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getSubqueryAliasBuilder .setInput(plan.getRoot) .setAlias(alias) } - /** - * (Scala-specific) Returns a new Dataset with an alias set. - * - * @group typedrel - * @since 3.4.0 - */ - def as(alias: Symbol): Dataset[T] = as(alias.name) - - /** - * Returns a new Dataset with an alias set. Same as `as`. - * - * @group typedrel - * @since 3.4.0 - */ - def alias(alias: String): Dataset[T] = as(alias) - - /** - * (Scala-specific) Returns a new Dataset with an alias set. Same as `as`. - * - * @group typedrel - * @since 3.4.0 - */ - def alias(alias: Symbol): Dataset[T] = as(alias) - - /** - * Selects a set of column based expressions. - * {{{ - * ds.select($"colA", $"colB" + 1) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def select(cols: Column*): DataFrame = selectUntyped(UnboundRowEncoder, cols).asInstanceOf[DataFrame] - /** - * Selects a set of columns. This is a variant of `select` that can only select existing columns - * using column names (i.e. cannot construct expressions). 
- * - * {{{ - * // The following two are equivalent: - * ds.select("colA", "colB") - * ds.select($"colA", $"colB") - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)): _*) - - /** - * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions. - * - * {{{ - * // The following are equivalent: - * ds.selectExpr("colA", "colB as newName", "abs(colC)") - * ds.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def selectExpr(exprs: String*): DataFrame = { - select(exprs.map(functions.expr): _*) - } - - /** - * Returns a new Dataset by computing the given [[Column]] expression for each element. - * - * {{{ - * val ds = Seq(1, 2, 3).toDS() - * val newDS = ds.select(expr("value + 1").as[Int]) - * }}} - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { - val encoder = c1.encoder - val expr = if (encoder.schema == encoder.dataType) { - functions.inline(functions.array(c1)).expr + val encoder = agnosticEncoderFor(c1.encoder) + val col = if (encoder.schema == encoder.dataType) { + functions.inline(functions.array(c1)) } else { - c1.expr + c1 } sparkSession.newDataset(encoder) { builder => builder.getProjectBuilder .setInput(plan.getRoot) - .addExpressions(expr) + .addExpressions(col.typedExpr(this.encoder)) } } - /** - * Internal helper function for building typed selects that return tuples. For simplicity and - * code reuse, we do this without the help of the type system and then use helper functions that - * cast appropriately for the user facing interface. - */ - private def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoder = ProductEncoder.tuple(columns.map(_.encoder)) + /** @inheritdoc */ + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoder = ProductEncoder.tuple(columns.map(c => agnosticEncoderFor(c.encoder))) selectUntyped(encoder, columns) } @@ -1209,109 +466,16 @@ class Dataset[T] private[sql] ( sparkSession.newDataset(encoder) { builder => builder.getProjectBuilder .setInput(plan.getRoot) - .addAllExpressions(cols.map(_.expr).asJava) + .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava) } } - /** - * Returns a new Dataset by computing the given [[Column]] expressions for each element. - * - * @group typedrel - * @since 3.4.0 - */ - def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = - selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] - - /** - * Returns a new Dataset by computing the given [[Column]] expressions for each element. - * - * @group typedrel - * @since 3.4.0 - */ - def select[U1, U2, U3]( - c1: TypedColumn[T, U1], - c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = - selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] - - /** - * Returns a new Dataset by computing the given [[Column]] expressions for each element. - * - * @group typedrel - * @since 3.4.0 - */ - def select[U1, U2, U3, U4]( - c1: TypedColumn[T, U1], - c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3], - c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = - selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] - - /** - * Returns a new Dataset by computing the given [[Column]] expressions for each element. 
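A sketch of the typed multi-column select above; it assumes a SparkSession `spark` with its implicits in scope:

{{{
import org.apache.spark.sql.Dataset
import spark.implicits._

val ds = Seq(("Alice", 29L), ("Bob", 31L)).toDS()
// Two typed columns projected into a Dataset of pairs.
val pairs: Dataset[(String, Long)] =
  ds.select($"_1".as[String], ($"_2" + 1).as[Long])
}}}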
- * - * @group typedrel - * @since 3.4.0 - */ - def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[T, U1], - c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3], - c4: TypedColumn[T, U4], - c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = - selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] - - /** - * Filters rows using the given condition. - * {{{ - * // The following are equivalent: - * peopleDs.filter($"age" > 15) - * peopleDs.where($"age" > 15) - * }}} - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr) } - /** - * Filters rows using the given SQL expression. - * {{{ - * peopleDs.filter("age > 15") - * }}} - * - * @group typedrel - * @since 3.4.0 - */ - def filter(conditionExpr: String): Dataset[T] = filter(functions.expr(conditionExpr)) - - /** - * Filters rows using the given condition. This is an alias for `filter`. - * {{{ - * // The following are equivalent: - * peopleDs.filter($"age" > 15) - * peopleDs.where($"age" > 15) - * }}} - * - * @group typedrel - * @since 3.4.0 - */ - def where(condition: Column): Dataset[T] = filter(condition) - - /** - * Filters rows using the given SQL expression. - * {{{ - * peopleDs.where("age > 15") - * }}} - * - * @group typedrel - * @since 3.4.0 - */ - def where(conditionExpr: String): Dataset[T] = filter(conditionExpr) - private def buildUnpivot( ids: Array[Column], valuesOption: Option[Array[Column]], @@ -1320,7 +484,7 @@ class Dataset[T] private[sql] ( val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava) - .setValueColumnName(variableColumnName) + .setVariableColumnName(variableColumnName) .setValueColumnName(valueColumnName) valuesOption.foreach { values => unpivot.getValuesBuilder @@ -1328,67 +492,23 @@ class Dataset[T] private[sql] ( } } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + private def buildTranspose(indices: Seq[Column]): DataFrame = + sparkSession.newDataFrame { builder => + val transpose = builder.getTransposeBuilder.setInput(plan.getRoot) + indices.foreach { indexColumn => + transpose.addIndexColumns(indexColumn.expr) + } + } + + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) } - /** - * Groups the Dataset using the specified columns, so that we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * This is a variant of groupBy that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy("department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. 
- * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) - } - - /** - * (Scala-specific) Reduces the elements of this Dataset using the specified binary function. - * The given `func` must be commutative and associative or the result may be non-deterministic. - * - * @group action - * @since 3.5.0 - */ + /** @inheritdoc */ def reduce(func: (T, T) => T): T = { - val udf = ScalaUserDefinedFunction( + val udf = SparkUserDefinedFunction( function = func, inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil, outputEncoder = agnosticEncoder) @@ -1406,164 +526,24 @@ class Dataset[T] private[sql] ( result(0) } - /** - * (Java-specific) Reduces the elements of this Dataset using the specified binary function. The - * given `func` must be commutative and associative or the result may be non-deterministic. - * - * @group action - * @since 3.5.0 - */ - def reduce(func: ReduceFunction[T]): T = reduce(UdfUtils.mapReduceFuncToScalaFunc(func)) - - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) + KeyValueGroupedDatasetImpl[K, T](this, agnosticEncoderFor[K], func) } - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup($"department", $"group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) } - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * This is a variant of rollup that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup("department", "group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. 
- * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def rollup(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) - } - - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube($"department", $"group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def cube(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_CUBE) } - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * This is a variant of cube that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube("department", "group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def cube(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_CUBE) - } - - /** - * Create multi-dimensional aggregation for the current Dataset using the specified grouping - * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the - * available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns group by specific grouping sets. - * ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg() - * - * // Compute the max age and average salary, group by specific grouping sets. - * ds.groupingSets(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = { val groupingSetMsgs = groupingSets.map { groupingSet => @@ -1580,119 +560,7 @@ class Dataset[T] private[sql] ( groupingSets = Some(groupingSetMsgs)) } - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) 
- * ds.agg("age" -> "max", "salary" -> "avg") - * ds.groupBy().agg("age" -> "max", "salary" -> "avg") - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs: _*) - } - - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * (Java-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(max($"age"), avg($"salary")) - * ds.groupBy().agg(max($"age"), avg($"salary")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) - - /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns - * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, - * which cannot be reversed. - * - * This function is useful to massage a DataFrame into a format where some columns are - * identifier columns ("ids"), while all other columns ("values") are "unpivoted" to the rows, - * leaving just two non-id columns, named as given by `variableColumnName` and - * `valueColumnName`. - * - * {{{ - * val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long") - * df.show() - * // output: - * // +---+---+----+ - * // | id|int|long| - * // +---+---+----+ - * // | 1| 11| 12| - * // | 2| 21| 22| - * // +---+---+----+ - * - * df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show() - * // output: - * // +---+--------+-----+ - * // | id|variable|value| - * // +---+--------+-----+ - * // | 1| int| 11| - * // | 1| long| 12| - * // | 2| int| 21| - * // | 2| long| 22| - * // +---+--------+-----+ - * // schema: - * //root - * // |-- id: integer (nullable = false) - * // |-- variable: string (nullable = false) - * // |-- value: long (nullable = true) - * }}} - * - * When no "id" columns are given, the unpivoted DataFrame consists of only the "variable" and - * "value" columns. - * - * All "value" columns must share a least common data type. Unless they are the same data type, - * all "value" columns are cast to the nearest common data type. For instance, types - * `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType` do - * not have a common data type and `unpivot` fails with an `AnalysisException`. 
- * - * @param ids - * Id columns - * @param values - * Value columns to unpivot - * @param variableColumnName - * Name of the variable column - * @param valueColumnName - * Name of the value column - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def unpivot( ids: Array[Column], values: Array[Column], @@ -1701,27 +569,7 @@ class Dataset[T] private[sql] ( buildUnpivot(ids, Option(values), variableColumnName, valueColumnName) } - /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns - * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, - * which cannot be reversed. - * - * @see - * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` - * - * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values` - * is set to all non-id columns that exist in the DataFrame. - * - * @param ids - * Id columns - * @param variableColumnName - * Name of the variable column - * @param valueColumnName - * Name of the value column - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def unpivot( ids: Array[Column], variableColumnName: String, @@ -1729,77 +577,35 @@ class Dataset[T] private[sql] ( buildUnpivot(ids, None, variableColumnName, valueColumnName) } - /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns - * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, - * which cannot be reversed. This is an alias for `unpivot`. - * - * @see - * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` - * - * @param ids - * Id columns - * @param values - * Value columns to unpivot - * @param variableColumnName - * Name of the variable column - * @param valueColumnName - * Name of the value column - * - * @group untypedrel - * @since 3.4.0 - */ - def melt( - ids: Array[Column], - values: Array[Column], - variableColumnName: String, - valueColumnName: String): DataFrame = - unpivot(ids, values, variableColumnName, valueColumnName) + /** @inheritdoc */ + def transpose(indexColumn: Column): DataFrame = + buildTranspose(Seq(indexColumn)) - /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns - * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, - * which cannot be reversed. This is an alias for `unpivot`. - * - * @see - * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` - * - * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values` - * is set to all non-id columns that exist in the DataFrame. - * - * @param ids - * Id columns - * @param variableColumnName - * Name of the variable column - * @param valueColumnName - * Name of the value column - * - * @group untypedrel - * @since 3.4.0 - */ - def melt(ids: Array[Column], variableColumnName: String, valueColumnName: String): DataFrame = - unpivot(ids, variableColumnName, valueColumnName) + /** @inheritdoc */ + def transpose(): DataFrame = + buildTranspose(Seq.empty) - /** - * Returns a new Dataset by taking the first `n` rows. The difference between this function and - * `head` is that `head` is an action and returns an array (by triggering query execution) while - * `limit` returns a new Dataset. 
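A small sketch of the distinction drawn above (assumes a SparkSession `spark`): `limit` stays lazy and can be composed further, whereas `head`/`take` trigger execution and pull rows back to the driver.
{{{
  import org.apache.spark.sql.functions.col

  val ds = spark.range(1000)

  val lazyTen  = ds.limit(10)                                 // still a Dataset, nothing runs yet
  val composed = ds.limit(100).filter(col("id") % 2 === 0)    // limit can keep being composed

  val eagerTen = ds.head(10)                                  // Array of rows, triggers a job
}}}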
- * - * @group typedrel - * @since 3.4.0 - */ + // TODO(SPARK-50134): Support scalar Subquery API in Spark Connect + // scalastyle:off not.implemented.error.usage + /** @inheritdoc */ + def scalar(): Column = { + ??? + } + + /** @inheritdoc */ + def exists(): Column = { + ??? + } + // scalastyle:on not.implemented.error.usage + + /** @inheritdoc */ def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getLimitBuilder .setInput(plan.getRoot) .setLimit(n) } - /** - * Returns a new Dataset by skipping the first `n` rows. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getOffsetBuilder .setInput(plan.getRoot) @@ -1827,260 +633,49 @@ class Dataset[T] private[sql] ( } } - /** - * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * - * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does - * deduplication of elements), use this function followed by a [[distinct]]. - * - * Also as standard in SQL, this function resolves columns by position (not by name): - * - * {{{ - * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") - * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") - * df1.union(df2).show - * - * // output: - * // +----+----+----+ - * // |col0|col1|col2| - * // +----+----+----+ - * // | 1| 2| 3| - * // | 4| 5| 6| - * // +----+----+----+ - * }}} - * - * Notice that the column positions in the schema aren't necessarily matched with the fields in - * the strongly typed objects in a Dataset. This function resolves columns by their positions in - * the schema, not the fields in the strongly typed objects. Use [[unionByName]] to resolve - * columns by field name in the typed objects. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def union(other: Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_UNION) { builder => builder.setIsAll(true) } } - /** - * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is - * an alias for `union`. - * - * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does - * deduplication of elements), use this function followed by a [[distinct]]. - * - * Also as standard in SQL, this function resolves columns by position (not by name). - * - * @group typedrel - * @since 3.4.0 - */ - def unionAll(other: Dataset[T]): Dataset[T] = union(other) - - /** - * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * - * This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set - * union (that does deduplication of elements), use this function followed by a [[distinct]]. - * - * The difference between this function and [[union]] is that this function resolves columns by - * name (not by position): - * - * {{{ - * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") - * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") - * df1.unionByName(df2).show - * - * // output: - * // +----+----+----+ - * // |col0|col1|col2| - * // +----+----+----+ - * // | 1| 2| 3| - * // | 6| 4| 5| - * // +----+----+----+ - * }}} - * - * Note that this supports nested columns in struct and array types. Nested columns in map types - * are not currently supported. 
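A sketch of name-based resolution (assumes `spark` and `spark.implicits._`; column names are illustrative): with `allowMissingColumns = true`, each side's missing columns are null-filled and appended to the result schema.
{{{
  import spark.implicits._

  val left  = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
  val right = Seq((4, 5, 6)).toDF("col1", "col0", "col3")

  left.unionByName(right, allowMissingColumns = true).show()
  // rows coming from `right` carry null in col2; rows coming from `left` carry null in col3
}}}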
- * - * @group typedrel - * @since 3.4.0 - */ - def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) - - /** - * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * - * The difference between this function and [[union]] is that this function resolves columns by - * name (not by position). - * - * When the parameter `allowMissingColumns` is `true`, the set of column names in this and other - * `Dataset` can differ; missing columns will be filled with null. Further, the missing columns - * of this `Dataset` will be added at the end in the schema of the union result: - * - * {{{ - * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") - * val df2 = Seq((4, 5, 6)).toDF("col1", "col0", "col3") - * df1.unionByName(df2, true).show - * - * // output: "col3" is missing at left df1 and added at the end of schema. - * // +----+----+----+----+ - * // |col0|col1|col2|col3| - * // +----+----+----+----+ - * // | 1| 2| 3|null| - * // | 5| 4|null| 6| - * // +----+----+----+----+ - * - * df2.unionByName(df1, true).show - * - * // output: "col2" is missing at left df2 and added at the end of schema. - * // +----+----+----+----+ - * // |col1|col0|col3|col2| - * // +----+----+----+----+ - * // | 4| 5| 6|null| - * // | 2| 1|null| 3| - * // +----+----+----+----+ - * }}} - * - * Note that this supports nested columns in struct and array types. With `allowMissingColumns`, - * missing nested columns of struct columns with the same name will also be filled with null - * values and added to the end of struct. Nested columns in map types are not currently - * supported. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_UNION) { builder => builder.setByName(true).setIsAll(true).setAllowMissingColumns(allowMissingColumns) } } - /** - * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is - * equivalent to `INTERSECT` in SQL. - * - * @note - * Equality checking is performed directly on the encoded representation of the data and thus - * is not affected by a custom `equals` function defined on `T`. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def intersect(other: Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_INTERSECT) { builder => builder.setIsAll(false) } } - /** - * Returns a new Dataset containing rows only in both this Dataset and another Dataset while - * preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL. - * - * @note - * Equality checking is performed directly on the encoded representation of the data and thus - * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this - * function resolves columns by position (not by name). - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def intersectAll(other: Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_INTERSECT) { builder => builder.setIsAll(true) } } - /** - * Returns a new Dataset containing rows in this Dataset but not in another Dataset. This is - * equivalent to `EXCEPT DISTINCT` in SQL. - * - * @note - * Equality checking is performed directly on the encoded representation of the data and thus - * is not affected by a custom `equals` function defined on `T`. 
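To make the DISTINCT vs ALL behaviour concrete, a minimal sketch (assumes `spark` and `spark.implicits._`):
{{{
  import spark.implicits._

  val a = Seq(1, 1, 2, 3).toDS()
  val b = Seq(1, 2).toDS()

  a.except(b).show()     // EXCEPT DISTINCT: 3
  a.exceptAll(b).show()  // EXCEPT ALL: 1 and 3 (one of the duplicate 1s survives)
}}}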
- * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def except(other: Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT) { builder => builder.setIsAll(false) } } - /** - * Returns a new Dataset containing rows in this Dataset but not in another Dataset while - * preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL. - * - * @note - * Equality checking is performed directly on the encoded representation of the data and thus - * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this - * function resolves columns by position (not by name). - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def exceptAll(other: Dataset[T]): Dataset[T] = { buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT) { builder => builder.setIsAll(true) } } - /** - * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a - * user-supplied seed. - * - * @param fraction - * Fraction of rows to generate, range [0.0, 1.0]. - * @param seed - * Seed for sampling. - * - * @note - * This is NOT guaranteed to provide exactly the fraction of the count of the given - * [[Dataset]]. - * - * @group typedrel - * @since 3.4.0 - */ - def sample(fraction: Double, seed: Long): Dataset[T] = { - sample(withReplacement = false, fraction = fraction, seed = seed) - } - - /** - * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a - * random seed. - * - * @param fraction - * Fraction of rows to generate, range [0.0, 1.0]. - * - * @note - * This is NOT guaranteed to provide exactly the fraction of the count of the given - * [[Dataset]]. - * - * @group typedrel - * @since 3.4.0 - */ - def sample(fraction: Double): Dataset[T] = { - sample(withReplacement = false, fraction = fraction) - } - - /** - * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. - * - * @param withReplacement - * Sample with replacement or not. - * @param fraction - * Fraction of rows to generate, range [0.0, 1.0]. - * @param seed - * Seed for sampling. - * - * @note - * This is NOT guaranteed to provide exactly the fraction of the count of the given - * [[Dataset]]. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { sparkSession.newDataset(agnosticEncoder) { builder => builder.getSampleBuilder @@ -2092,38 +687,7 @@ class Dataset[T] private[sql] ( } } - /** - * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. - * - * @param withReplacement - * Sample with replacement or not. - * @param fraction - * Fraction of rows to generate, range [0.0, 1.0]. - * - * @note - * This is NOT guaranteed to provide exactly the fraction of the total count of the given - * [[Dataset]]. - * - * @group typedrel - * @since 3.4.0 - */ - def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { - sample(withReplacement, fraction, SparkClassUtils.random.nextLong) - } - - /** - * Randomly splits this Dataset with the provided weights. - * - * @param weights - * weights for splits, will be normalized if they don't sum to 1. - * @param seed - * Seed for sampling. - * - * For Java API, use [[randomSplitAsList]]. 
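A sketch of the split behaviour (assumes `spark`): weights that do not sum to 1 are normalized, and a fixed seed makes the split reproducible for the same input data.
{{{
  val ds = spark.range(0, 1000)

  val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)

  // Java callers can use randomSplitAsList to obtain the same splits as a java.util.List.
}}}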
- * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { require( weights.forall(_ >= 0), @@ -2162,35 +726,20 @@ class Dataset[T] private[sql] ( .toArray } - /** - * Returns a Java list that contains randomly split Dataset with the provided weights. - * - * @param weights - * weights for splits, will be normalized if they don't sum to 1. - * @param seed - * Seed for sampling. - * - * @group typedrel - * @since 3.4.0 - */ - def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = { - val values = randomSplit(weights, seed) - java.util.Arrays.asList(values: _*) - } + /** @inheritdoc */ + override def randomSplitAsList(weights: Array[Double], seed: Long): util.List[Dataset[T]] = + util.Arrays.asList(randomSplit(weights, seed): _*) - /** - * Randomly splits this Dataset with the provided weights. - * - * @param weights - * weights for splits, will be normalized if they don't sum to 1. - * @group typedrel - * @since 3.4.0 - */ - def randomSplit(weights: Array[Double]): Array[Dataset[T]] = { - randomSplit(weights, SparkClassUtils.random.nextLong) - } + /** @inheritdoc */ + override def randomSplit(weights: Array[Double]): Array[Dataset[T]] = + randomSplit(weights, SparkClassUtils.random.nextLong()) - private def withColumns(names: Seq[String], values: Seq[Column]): DataFrame = { + /** @inheritdoc */ + protected def withColumns(names: Seq[String], values: Seq[Column]): DataFrame = { + require( + names.size == values.size, + s"The size of column names: ${names.size} isn't equal to " + + s"the size of columns: ${values.size}") val aliases = values.zip(names).map { case (value, name) => value.name(name).expr.getAlias } @@ -2201,109 +750,27 @@ class Dataset[T] private[sql] ( } } - /** - * Returns a new Dataset by adding a column or replacing the existing column that has the same - * name. - * - * `column`'s expression must only refer to attributes supplied by this Dataset. It is an error - * to add a column that refers to some other Dataset. - * - * @note - * this method introduces a projection internally. Therefore, calling it multiple times, for - * instance, via loops in order to add multiple columns can generate big plans which can cause - * performance issues and even `StackOverflowException`. To avoid this, use `select` with the - * multiple columns at once. - * - * @group untypedrel - * @since 3.4.0 - */ - def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) - - /** - * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns - * that has the same names. - * - * `colsMap` is a map of column name and column, the column must only refer to attributes - * supplied by this Dataset. It is an error to add columns that refers to some other Dataset. - * - * @group untypedrel - * @since 3.4.0 - */ - def withColumns(colsMap: Map[String, Column]): DataFrame = { - val (colNames, newCols) = colsMap.toSeq.unzip - withColumns(colNames, newCols) - } - - /** - * (Java-specific) Returns a new Dataset by adding columns or replacing the existing columns - * that has the same names. - * - * `colsMap` is a map of column name and column, the column must only refer to attribute - * supplied by this Dataset. It is an error to add columns that refers to some other Dataset. 
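A sketch of adding several columns in one projection (assumes `spark` and `spark.implicits._`; the new column names are purely illustrative), which avoids the plan blow-up of calling `withColumn` repeatedly in a loop.
{{{
  import org.apache.spark.sql.functions.{col, lit}
  import spark.implicits._

  val df = Seq((1, "a"), (2, "b")).toDF("id", "name")

  val enriched = df.withColumns(Map(
    "id_squared" -> (col("id") * col("id")),
    "source"     -> lit("example")))     // a single projection adds both columns
}}}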
- * - * @group untypedrel - * @since 3.4.0 - */ - def withColumns(colsMap: java.util.Map[String, Column]): DataFrame = withColumns( - colsMap.asScala.toMap) - - /** - * Returns a new Dataset with a column renamed. This is a no-op if schema doesn't contain - * existingName. - * - * @group untypedrel - * @since 3.4.0 - */ - def withColumnRenamed(existingName: String, newName: String): DataFrame = { - withColumnsRenamed(Collections.singletonMap(existingName, newName)) - } - - /** - * (Scala-specific) Returns a new Dataset with a columns renamed. This is a no-op if schema - * doesn't contain existingName. - * - * `colsMap` is a map of existing column name and new column name. - * - * @throws AnalysisException - * if there are duplicate names in resulting projection - * - * @group untypedrel - * @since 3.4.0 - */ - @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { - withColumnsRenamed(colsMap.asJava) - } - - /** - * (Java-specific) Returns a new Dataset with a columns renamed. This is a no-op if schema - * doesn't contain existingName. - * - * `colsMap` is a map of existing column name and new column name. - * - * @group untypedrel - * @since 3.4.0 - */ - def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame = { + override protected def withColumnsRenamed( + colNames: Seq[String], + newColNames: Seq[String]): DataFrame = { + require( + colNames.size == newColNames.size, + s"The size of existing column names: ${colNames.size} isn't equal to " + + s"the size of new column names: ${newColNames.size}") sparkSession.newDataFrame { builder => - builder.getWithColumnsRenamedBuilder + val b = builder.getWithColumnsRenamedBuilder .setInput(plan.getRoot) - .addAllRenames(colsMap.asScala.toSeq.map { case (colName, newColName) => + colNames.zip(newColNames).foreach { case (colName, newColName) => + b.addRenames( proto.WithColumnsRenamed.Rename .newBuilder() .setColName(colName) - .setNewColName(newColName) - .build() - }.asJava) + .setNewColName(newColName)) + } } } - /** - * Returns a new Dataset by updating an existing column with metadata. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def withMetadata(columnName: String, metadata: Metadata): DataFrame = { val newAlias = proto.Expression.Alias .newBuilder() @@ -2317,88 +784,7 @@ class Dataset[T] private[sql] ( } } - /** - * Registers this Dataset as a temporary table using the given name. The lifetime of this - * temporary table is tied to the [[SparkSession]] that was used to create this Dataset. - * - * @group basic - * @since 3.4.0 - */ - @deprecated("Use createOrReplaceTempView(viewName) instead.", "3.4.0") - def registerTempTable(tableName: String): Unit = { - createOrReplaceTempView(tableName) - } - - /** - * Creates a local temporary view using the given name. The lifetime of this temporary view is - * tied to the [[SparkSession]] that was used to create this Dataset. - * - * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that - * created it, i.e. it will be automatically dropped when the session terminates. It's not tied - * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. 
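A sketch of the two view scopes (assumes `spark`; the view names are illustrative): a local temp view is visible only to the creating session, while a global temp view must be referenced through the reserved `global_temp` database.
{{{
  val df = spark.range(10).toDF("id")

  df.createOrReplaceTempView("ids")                        // session-scoped
  spark.sql("SELECT count(*) FROM ids").show()

  df.createOrReplaceGlobalTempView("ids_global")           // application-scoped
  spark.sql("SELECT count(*) FROM global_temp.ids_global").show()
}}}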
- * - * @throws AnalysisException - * if the view name is invalid or already exists - * - * @group basic - * @since 3.4.0 - */ - @throws[AnalysisException] - def createTempView(viewName: String): Unit = { - buildAndExecuteTempView(viewName, replace = false, global = false) - } - - /** - * Creates a local temporary view using the given name. The lifetime of this temporary view is - * tied to the [[SparkSession]] that was used to create this Dataset. - * - * @group basic - * @since 3.4.0 - */ - def createOrReplaceTempView(viewName: String): Unit = { - buildAndExecuteTempView(viewName, replace = true, global = false) - } - - /** - * Creates a global temporary view using the given name. The lifetime of this temporary view is - * tied to this Spark application. - * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark - * application, i.e. it will be automatically dropped when the application terminates. It's tied - * to a system preserved database `global_temp`, and we must use the qualified name to refer a - * global temp view, e.g. `SELECT * FROM global_temp.view1`. - * - * @throws AnalysisException - * if the view name is invalid or already exists - * - * @group basic - * @since 3.4.0 - */ - @throws[AnalysisException] - def createGlobalTempView(viewName: String): Unit = { - buildAndExecuteTempView(viewName, replace = false, global = true) - } - - /** - * Creates or replaces a global temporary view using the given name. The lifetime of this - * temporary view is tied to this Spark application. - * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark - * application, i.e. it will be automatically dropped when the application terminates. It's tied - * to a system preserved database `global_temp`, and we must use the qualified name to refer a - * global temp view, e.g. `SELECT * FROM global_temp.view1`. - * - * @group basic - * @since 3.4.0 - */ - def createOrReplaceGlobalTempView(viewName: String): Unit = { - buildAndExecuteTempView(viewName, replace = true, global = true) - } - - private def buildAndExecuteTempView( - viewName: String, - replace: Boolean, - global: Boolean): Unit = { + protected def createTempView(viewName: String, replace: Boolean, global: Boolean): Unit = { val command = sparkSession.newCommand { builder => builder.getCreateDataframeViewBuilder .setInput(plan.getRoot) @@ -2409,56 +795,11 @@ class Dataset[T] private[sql] ( sparkSession.execute(command) } - /** - * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column - * name. - * - * This method can only be used to drop top level columns. the colName string is treated - * literally without further interpretation. - * - * @group untypedrel - * @since 3.4.0 - */ - def drop(colName: String): DataFrame = { - drop(Seq(colName): _*) - } - - /** - * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column - * name(s). - * - * This method can only be used to drop top level columns. the colName string is treated - * literally without further interpretation. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def drop(colNames: String*): DataFrame = buildDropByNames(colNames) - /** - * Returns a new Dataset with column dropped. - * - * This method can only be used to drop top level column. This version of drop accepts a - * [[Column]] rather than a name. This is a no-op if the Dataset doesn't have a column with an - * equivalent expression. 
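A sketch of the drop variants (assumes `spark` and `spark.implicits._`): dropping by name treats the string literally and is a no-op for unknown columns, while dropping by `Column` removes the column with an equivalent expression.
{{{
  import org.apache.spark.sql.functions.col
  import spark.implicits._

  val df = Seq((1, "a", true)).toDF("id", "name", "flag")

  df.drop("name")             // by name
  df.drop(col("flag"))        // by Column expression
  df.drop("does_not_exist")   // no-op, schema unchanged
}}}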
- * - * @group untypedrel - * @since 3.4.0 - */ - def drop(col: Column): DataFrame = { - buildDrop(col :: Nil) - } - - /** - * Returns a new Dataset with columns dropped. - * - * This method can only be used to drop top level columns. This is a no-op if the Dataset - * doesn't have a columns with an equivalent expression. - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols) @@ -2489,167 +830,32 @@ class Dataset[T] private[sql] ( } } - /** - * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias - * for `distinct`. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def dropDuplicates(): Dataset[T] = buildDropDuplicates(None, withinWaterMark = false) - /** - * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the - * subset of columns. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def dropDuplicates(colNames: Seq[String]): Dataset[T] = { buildDropDuplicates(Option(colNames), withinWaterMark = false) } - /** - * Returns a new Dataset with duplicate rows removed, considering only the subset of columns. - * - * @group typedrel - * @since 3.4.0 - */ - def dropDuplicates(colNames: Array[String]): Dataset[T] = - dropDuplicates(colNames.toImmutableArraySeq) - - /** - * Returns a new [[Dataset]] with duplicate rows removed, considering only the subset of - * columns. - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def dropDuplicates(col1: String, cols: String*): Dataset[T] = { - dropDuplicates(col1 +: cols) - } - + /** @inheritdoc */ def dropDuplicatesWithinWatermark(): Dataset[T] = buildDropDuplicates(None, withinWaterMark = true) + /** @inheritdoc */ def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = { buildDropDuplicates(Option(colNames), withinWaterMark = true) } - def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = { - dropDuplicatesWithinWatermark(colNames.toImmutableArraySeq) - } - - @scala.annotation.varargs - def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = { - dropDuplicatesWithinWatermark(col1 +: cols) - } - - /** - * Computes basic statistics for numeric and string columns, including count, mean, stddev, min, - * and max. If no columns are given, this function computes statistics for all numerical or - * string columns. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting Dataset. If you want to - * programmatically compute summary statistics, use the `agg` function instead. - * - * {{{ - * ds.describe("age", "height").show() - * - * // output: - * // summary age height - * // count 10.0 10.0 - * // mean 53.3 178.05 - * // stddev 11.6 15.7 - * // min 18.0 163.0 - * // max 92.0 192.0 - * }}} - * - * Use [[summary]] for expanded statistics and control over which statistics to compute. - * - * @param cols - * Columns to compute statistics on. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = sparkSession.newDataFrame { builder => + override def describe(cols: String*): DataFrame = sparkSession.newDataFrame { builder => builder.getDescribeBuilder .setInput(plan.getRoot) .addAllCols(cols.asJava) } - /** - * Computes specified statistics for numeric and string columns. Available statistics are:
<ul> - * <li>count</li> <li>mean</li> <li>stddev</li> <li>min</li> <li>max</li> <li>arbitrary - * approximate percentiles specified as a percentage (e.g. 75%)</li> <li>count_distinct</li> - * <li>approx_count_distinct</li> - * </ul>
- * - * If no statistics are given, this function computes count, mean, stddev, min, approximate - * quartiles (percentiles at 25%, 50%, and 75%), and max. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting Dataset. If you want to - * programmatically compute summary statistics, use the `agg` function instead. - * - * {{{ - * ds.summary().show() - * - * // output: - * // summary age height - * // count 10.0 10.0 - * // mean 53.3 178.05 - * // stddev 11.6 15.7 - * // min 18.0 163.0 - * // 25% 24.0 176.0 - * // 50% 24.0 176.0 - * // 75% 32.0 180.0 - * // max 92.0 192.0 - * }}} - * - * {{{ - * ds.summary("count", "min", "25%", "75%", "max").show() - * - * // output: - * // summary age height - * // count 10.0 10.0 - * // min 18.0 163.0 - * // 25% 24.0 176.0 - * // 75% 32.0 180.0 - * // max 92.0 192.0 - * }}} - * - * To do a summary for specific columns first select them: - * - * {{{ - * ds.select("age", "height").summary().show() - * }}} - * - * Specify statistics to output custom summaries: - * - * {{{ - * ds.summary("count", "count_distinct").show() - * }}} - * - * The distinct count isn't included by default. - * - * You can also run approximate distinct counts which are faster: - * - * {{{ - * ds.summary("count", "approx_count_distinct").show() - * }}} - * - * See also [[describe]] for basic statistics. - * - * @param statistics - * Statistics from above list to be computed. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def summary(statistics: String*): DataFrame = sparkSession.newDataFrame { builder => builder.getSummaryBuilder @@ -2657,56 +863,12 @@ class Dataset[T] private[sql] ( .addAllStatistics(statistics.asJava) } - /** - * Returns the first `n` rows. - * - * @note - * this method should only be used if the resulting array is expected to be small, as all the - * data is loaded into the driver's memory. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ def head(n: Int): Array[T] = limit(n).collect() - /** - * Returns the first row. - * @group action - * @since 3.4.0 - */ - def head(): T = head(1).head - - /** - * Returns the first row. Alias for head(). - * @group action - * @since 3.4.0 - */ - def first(): T = head() - - /** - * Concise syntax for chaining custom transformations. - * {{{ - * def featurize(ds: Dataset[T]): Dataset[U] = ... - * - * ds - * .transform(featurize) - * .transform(...) - * }}} - * - * @group typedrel - * @since 3.4.0 - */ - def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) - - /** - * (Scala-specific) Returns a new Dataset that only contains elements where `func` returns - * `true`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def filter(func: T => Boolean): Dataset[T] = { - val udf = ScalaUserDefinedFunction( + val udf = SparkUserDefinedFunction( function = func, inputEncoders = agnosticEncoder :: Nil, outputEncoder = PrimitiveBooleanEncoder) @@ -2717,49 +879,25 @@ class Dataset[T] private[sql] ( } } - /** - * (Java-specific) Returns a new Dataset that only contains elements where `func` returns - * `true`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def filter(f: FilterFunction[T]): Dataset[T] = { - filter(UdfUtils.filterFuncToScalaFunc(f)) + filter(ToScalaUDF(f)) } - /** - * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each - * element. 
- * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def map[U: Encoder](f: T => U): Dataset[U] = { - mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f)) + mapPartitions(UDFAdaptors.mapToMapPartitions(f)) } - /** - * (Java-specific) Returns a new Dataset that contains the result of applying `func` to each - * element. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - map(UdfUtils.mapFunctionToScalaFunc(f))(encoder) + mapPartitions(UDFAdaptors.mapToMapPartitions(f))(encoder) } - /** - * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each - * partition. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - val outputEncoder = encoderFor[U] - val udf = ScalaUserDefinedFunction( + val outputEncoder = agnosticEncoderFor[U] + val udf = SparkUserDefinedFunction( function = func, inputEncoders = agnosticEncoder :: Nil, outputEncoder = outputEncoder) @@ -2770,166 +908,34 @@ class Dataset[T] private[sql] ( } } - /** - * (Java-specific) Returns a new Dataset that contains the result of applying `f` to each - * partition. - * - * @group typedrel - * @since 3.5.0 - */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Returns a new Dataset by first applying a function to all elements of this - * Dataset, and then flattening the results. - * - * @group typedrel - * @since 3.5.0 - */ - def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = - mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func)) - - /** - * (Java-specific) Returns a new Dataset by first applying a function to all elements of this - * Dataset, and then flattening the results. - * - * @group typedrel - * @since 3.5.0 - */ - def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows - * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the - * input row are implicitly joined with each row that is output by the function. - * - * Given that this is deprecated, as an alternative, you can explode columns either using - * `functions.explode()` or `flatMap()`. 
The following example uses these alternatives to count - * the number of books that contain a given word: - * - * {{{ - * case class Book(title: String, words: String) - * val ds: Dataset[Book] - * - * val allWords = ds.select($"title", explode(split($"words", " ")).as("word")) - * - * val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title")) - * }}} - * - * Using `flatMap()` this can similarly be exploded as: - * - * {{{ - * ds.flatMap(_.words.split(" ")) - * }}} - * - * @group untypedrel - * @since 3.5.0 - */ + /** @inheritdoc */ @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0") def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = { - val generator = ScalaUserDefinedFunction( - UdfUtils.iterableOnceToSeq(f), + val generator = SparkUserDefinedFunction( + UDFAdaptors.iterableOnceToSeq(f), UnboundRowEncoder :: Nil, ScalaReflection.encoderFor[Seq[A]]) select(col("*"), functions.inline(generator(struct(input: _*)))) } - /** - * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or - * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All - * columns of the input row are implicitly joined with each value that is output by the - * function. - * - * Given that this is deprecated, as an alternative, you can explode columns either using - * `functions.explode()`: - * - * {{{ - * ds.select(explode(split($"words", " ")).as("word")) - * }}} - * - * or `flatMap()`: - * - * {{{ - * ds.flatMap(_.words.split(" ")) - * }}} - * - * @group untypedrel - * @since 3.5.0 - */ + /** @inheritdoc */ @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0") def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( f: A => IterableOnce[B]): DataFrame = { - val generator = ScalaUserDefinedFunction( - UdfUtils.iterableOnceToSeq(f), + val generator = SparkUserDefinedFunction( + UDFAdaptors.iterableOnceToSeq(f), Nil, ScalaReflection.encoderFor[Seq[B]]) select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn))) } - /** - * Applies a function `f` to all rows. - * - * @group action - * @since 3.5.0 - */ - def foreach(f: T => Unit): Unit = { - foreachPartition(UdfUtils.foreachFuncToForeachPartitionsAdaptor(f)) - } - - /** - * (Java-specific) Runs `func` on each element of this Dataset. - * - * @group action - * @since 3.5.0 - */ - def foreach(func: ForeachFunction[T]): Unit = foreach(UdfUtils.foreachFuncToScalaFunc(func)) - - /** - * Applies a function `f` to each partition of this Dataset. - * - * @group action - * @since 3.5.0 - */ - def foreachPartition(f: Iterator[T] => Unit): Unit = { - // Delegate to mapPartition with empty result. - mapPartitions(UdfUtils.foreachPartitionFuncToMapPartitionsAdaptor(f))(RowEncoder(Seq.empty)) - .collect() - } - - /** - * (Java-specific) Runs `func` on each partition of this Dataset. - * - * @group action - * @since 3.5.0 - */ - def foreachPartition(func: ForeachPartitionFunction[T]): Unit = { - foreachPartition(UdfUtils.foreachPartitionFuncToScalaFunc(func)) - } - - /** - * Returns the first `n` rows in the Dataset. - * - * Running take requires moving data into the application's driver process, and doing so with a - * very large `n` can crash the driver process with OutOfMemoryError. - * - * @group action - * @since 3.4.0 - */ - def take(n: Int): Array[T] = head(n) - - /** - * Returns the last `n` rows in the Dataset. 
- * - * Running tail requires moving data into the application's driver process, and doing so with a - * very large `n` can crash the driver process with OutOfMemoryError. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ + def foreachPartition(f: Iterator[T] => Unit): Unit = { + // Delegate to mapPartition with empty result. + mapPartitions(UDFAdaptors.foreachPartitionToMapPartitions(f))(NullEncoder).collect() + } + + /** @inheritdoc */ def tail(n: Int): Array[T] = { val lastN = sparkSession.newDataset(agnosticEncoder) { builder => builder.getTailBuilder @@ -2939,64 +945,22 @@ class Dataset[T] private[sql] ( lastN.collect() } - /** - * Returns the first `n` rows in the Dataset as a list. - * - * Running take requires moving data into the application's driver process, and doing so with a - * very large `n` can crash the driver process with OutOfMemoryError. - * - * @group action - * @since 3.4.0 - */ - def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n): _*) - - /** - * Returns an array that contains all rows in this Dataset. - * - * Running collect requires moving all the data into the application's driver process, and doing - * so on a very large dataset can crash the driver process with OutOfMemoryError. - * - * For Java API, use [[collectAsList]]. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ def collect(): Array[T] = withResult { result => result.toArray } - /** - * Returns a Java list that contains all rows in this Dataset. - * - * Running collect requires moving all the data into the application's driver process, and doing - * so on a very large dataset can crash the driver process with OutOfMemoryError. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ def collectAsList(): java.util.List[T] = { java.util.Arrays.asList(collect(): _*) } - /** - * Returns an iterator that contains all rows in this Dataset. - * - * The returned iterator implements [[AutoCloseable]]. For resource management it is better to - * close it once you are done. If you don't close it, it and the underlying data will be cleaned - * up once the iterator is garbage collected. - * - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ def toLocalIterator(): java.util.Iterator[T] = { collectResult().destructiveIterator.asJava } - /** - * Returns the number of rows in the Dataset. - * @group action - * @since 3.4.0 - */ + /** @inheritdoc */ def count(): Long = { groupBy().count().as(PrimitiveLongEncoder).collect().head } @@ -3020,17 +984,12 @@ class Dataset[T] private[sql] ( numPartitions.foreach(repartitionBuilder.setNumPartitions) } - /** - * Returns a new Dataset that has exactly `numPartitions` partitions. - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def repartition(numPartitions: Int): Dataset[T] = { buildRepartition(numPartitions, shuffle = true) } - private def repartitionByExpression( + protected def repartitionByExpression( numPartitions: Option[Int], partitionExprs: Seq[Column]): Dataset[T] = { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. @@ -3045,36 +1004,7 @@ class Dataset[T] private[sql] ( buildRepartitionByExpression(numPartitions, partitionExprs) } - /** - * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`. - * The resulting Dataset is hash partitioned. - * - * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). 
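A sketch of expression-based repartitioning (assumes `spark`): `repartition` hash-partitions on the given expressions (the DataFrame counterpart of DISTRIBUTE BY), while `repartitionByRange` samples the data to derive range boundaries.
{{{
  import org.apache.spark.sql.functions.col

  val df = spark.range(0, 1000).toDF("id")

  val byHash  = df.repartition(8, col("id"))           // 8 hash partitions on id
  val byRange = df.repartitionByRange(8, col("id"))    // 8 range partitions on id
}}}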
- * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { - repartitionByExpression(Some(numPartitions), partitionExprs) - } - - /** - * Returns a new Dataset partitioned by the given partitioning expressions, using - * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is hash - * partitioned. - * - * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). - * - * @group typedrel - * @since 2.0.0 - */ - @scala.annotation.varargs - def repartition(partitionExprs: Column*): Dataset[T] = { - repartitionByExpression(None, partitionExprs) - } - - private def repartitionByRange( + protected def repartitionByRange( numPartitions: Option[Int], partitionExprs: Seq[Column]): Dataset[T] = { require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") @@ -3085,93 +1015,12 @@ class Dataset[T] private[sql] ( buildRepartitionByExpression(numPartitions, sortExprs) } - /** - * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`. - * The resulting Dataset is range partitioned. - * - * At least one partition-by expression must be specified. When no explicit sort order is - * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each - * partition of the resulting Dataset. - * - * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence, - * the output may not be consistent, since sampling can return different values. The sample size - * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`. - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { - repartitionByRange(Some(numPartitions), partitionExprs) - } - - /** - * Returns a new Dataset partitioned by the given partitioning expressions, using - * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is range - * partitioned. - * - * At least one partition-by expression must be specified. When no explicit sort order is - * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each - * partition of the resulting Dataset. - * - * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence, - * the output may not be consistent, since sampling can return different values. The sample size - * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`. - * - * @group typedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def repartitionByRange(partitionExprs: Column*): Dataset[T] = { - repartitionByRange(None, partitionExprs) - } - - /** - * Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions - * are requested. If a larger number of partitions is requested, it will stay at the current - * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in a - * narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a - * shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. - * - * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, this may result in - * your computation taking place on fewer nodes than you like (e.g. one node in the case of - * numPartitions = 1). 
To avoid this, you can call repartition. This will add a shuffle step, - * but means the current upstream partitions will be executed in parallel (per whatever the - * current partitioning is). - * - * @group typedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def coalesce(numPartitions: Int): Dataset[T] = { buildRepartition(numPartitions, shuffle = false) } - /** - * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias - * for `dropDuplicates`. - * - * Note that for a streaming [[Dataset]], this method returns distinct rows only once regardless - * of the output mode, which the behavior may not be same with `DISTINCT` in SQL against - * streaming [[Dataset]]. - * - * @note - * Equality checking is performed directly on the encoded representation of the data and thus - * is not affected by a custom `equals` function defined on `T`. - * - * @group typedrel - * @since 3.4.0 - */ - def distinct(): Dataset[T] = dropDuplicates() - - /** - * Returns a best-effort snapshot of the files that compose this Dataset. This method simply - * asks each constituent BaseRelation for its respective files and takes the union of all - * results. Depending on the source relations, this may not find all input files. Duplicates are - * removed. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def inputFiles: Array[String] = sparkSession .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES) @@ -3180,61 +1029,17 @@ class Dataset[T] private[sql] ( .asScala .toArray - /** - * Interface for saving the content of the non-streaming Dataset out into external storage. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def write: DataFrameWriter[T] = { - new DataFrameWriter[T](this) + new DataFrameWriterImpl[T](this) } - /** - * Create a write configuration builder for v2 sources. - * - * This builder is used to configure and execute write operations. For example, to append to an - * existing table, run: - * - * {{{ - * df.writeTo("catalog.db.table").append() - * }}} - * - * This can also be used to create or replace existing tables: - * - * {{{ - * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() - * }}} - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { - new DataFrameWriterV2[T](table, this) + new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into a target - * table. - * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { throw new AnalysisException( @@ -3242,25 +1047,18 @@ class Dataset[T] private[sql] ( messageParameters = Map("methodName" -> toSQLId("mergeInto"))) } - new MergeIntoWriter[T](table, this, condition) + new MergeIntoWriterImpl[T](table, this, condition) } - /** - * Interface for saving the content of the streaming Dataset out into external storage. 
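A sketch of the v2 `writeTo` builder described above (assumes `spark` and a configured catalog; the table name `catalog.db.events` is hypothetical): the builder either appends to an existing table or (re)creates it with a partitioning spec.
{{{
  import org.apache.spark.sql.functions.{col, current_date}

  val df = spark.range(100).withColumn("day", current_date())

  // append to an existing table:
  df.writeTo("catalog.db.events").append()
  // or create / replace the table, partitioned by day:
  df.writeTo("catalog.db.events").partitionedBy(col("day")).createOrReplace()
}}}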
- * - * @group basic - * @since 3.5.0 - */ + /** @inheritdoc */ def writeStream: DataStreamWriter[T] = { new DataStreamWriter[T](this) } - /** - * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ + override def cache(): this.type = persist() + + /** @inheritdoc */ def persist(): this.type = { sparkSession.analyze { builder => builder.getPersistBuilder.setRelation(plan.getRoot) @@ -3268,15 +1066,7 @@ class Dataset[T] private[sql] ( this } - /** - * Persist this Dataset with the given storage level. - * - * @param newLevel - * One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`, - * `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc. - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def persist(newLevel: StorageLevel): this.type = { sparkSession.analyze { builder => builder.getPersistBuilder @@ -3286,15 +1076,7 @@ class Dataset[T] private[sql] ( this } - /** - * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This - * will not un-persist any cached data that is built upon this Dataset. - * - * @param blocking - * Whether to block until all blocks are deleted. - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def unpersist(blocking: Boolean): this.type = { sparkSession.analyze { builder => builder.getUnpersistBuilder @@ -3304,29 +1086,10 @@ class Dataset[T] private[sql] ( this } - /** - * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This - * will not un-persist any cached data that is built upon this Dataset. - * - * @group basic - * @since 3.4.0 - */ - def unpersist(): this.type = unpersist(blocking = false) - - /** - * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). - * - * @group basic - * @since 3.4.0 - */ - def cache(): this.type = persist() + /** @inheritdoc */ + override def unpersist(): this.type = unpersist(blocking = false) - /** - * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def storageLevel: StorageLevel = { StorageLevelProtoConverter.toStorageLevel( sparkSession @@ -3337,30 +1100,7 @@ class Dataset[T] private[sql] ( .getStorageLevel) } - /** - * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time - * before which we assume no more late data is going to arrive. - * - * Spark will use this watermark for several purposes:
<ul> <li>To know when a given time window - * aggregation can be finalized and thus can be emitted when using output modes that do not - * allow updates.</li>
<li>To minimize the amount of state that we need to keep for on-going - * aggregations, `mapGroupsWithState` and `dropDuplicates` operators.</li> </ul>
The current - * watermark is computed by looking at the `MAX(eventTime)` seen across all of the partitions in - * the query minus a user specified `delayThreshold`. Due to the cost of coordinating this value - * across partitions, the actual watermark used is only guaranteed to be at least - * `delayThreshold` behind the actual event time. In some cases we may still process records - * that arrive more than `delayThreshold` late. - * - * @param eventTime - * the name of the column that contains the event time of the row. - * @param delayThreshold - * the minimum delay to wait to data to arrive late, relative to the latest record that has - * been processed in the form of an interval (e.g. "1 minute" or "5 hours"). NOTE: This should - * not be negative. - * - * @group streaming - * @since 3.5.0 - */ + /** @inheritdoc */ def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = { sparkSession.newDataset(agnosticEncoder) { builder => builder.getWithWatermarkBuilder @@ -3370,32 +1110,7 @@ class Dataset[T] private[sql] ( } } - /** - * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset - * that returns the same result as the input, with the following guarantees:
<ul> <li>It will - * compute the defined aggregates (metrics) on all the data that is flowing through the Dataset - * at that point.</li>
<li>It will report the value of the defined aggregate columns as soon as - * we reach a completion point. A completion point is currently defined as the end of a - * query.</li> </ul>
Please note that continuous execution is currently not supported. - * - * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or - * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that - * contain references to the input Dataset's columns must always be wrapped in an aggregate - * function. - * - * A user can retrieve the metrics by calling - * `org.apache.spark.sql.Dataset.collectResult().getObservedMetrics`. - * - * {{{ - * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it - * val observed_ds = ds.observe("my_metrics", count(lit(1)).as("rows"), max($"id").as("maxid")) - * observed_ds.write.parquet("ds.parquet") - * val metrics = observed_ds.collectResult().getObservedMetrics - * }}} - * - * @group typedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = { sparkSession.newDataset(agnosticEncoder) { builder => @@ -3406,28 +1121,7 @@ class Dataset[T] private[sql] ( } } - /** - * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is - * equivalent to calling `observe(String, Column, Column*)` but does not require to collect all - * results before returning the metrics - the metrics are filled during iterating the results, - * as soon as they are available. This method does not support streaming datasets. - * - * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`. - * - * {{{ - * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it - * val observation = Observation("my_metrics") - * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid")) - * observed_ds.write.parquet("ds.parquet") - * val metrics = observation.get - * }}} - * - * @throws IllegalArgumentException - * If this is a streaming Dataset (this.isStreaming == true) - * - * @group typedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = { val df = observe(observation.name, expr, exprs: _*) @@ -3435,88 +1129,21 @@ class Dataset[T] private[sql] ( df } - /** - * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to - * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms - * where the plan may grow exponentially. It will be saved to files inside the checkpoint - * directory set with `SparkContext#setCheckpointDir`. - * - * @group basic - * @since 4.0.0 - */ - def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) - - /** - * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the - * logical plan of this Dataset, which is especially useful in iterative algorithms where the - * plan may grow exponentially. It will be saved to files inside the checkpoint directory set - * with `SparkContext#setCheckpointDir`. - * - * @param eager - * Whether to checkpoint this dataframe immediately - * - * @note - * When checkpoint is used with eager = false, the final data that is checkpointed after the - * first action may be different from the data that was used during the job due to - * non-determinism of the underlying operation and retries. If checkpoint is used to achieve - * saving a deterministic snapshot of the data, eager = true should be used. 
Otherwise, it is - * only deterministic after the first execution, after the checkpoint was finalized. - * - * @group basic - * @since 4.0.0 - */ - def checkpoint(eager: Boolean): Dataset[T] = - checkpoint(eager = eager, reliableCheckpoint = true) - - /** - * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used - * to truncate the logical plan of this Dataset, which is especially useful in iterative - * algorithms where the plan may grow exponentially. Local checkpoints are written to executor - * storage and despite potentially faster they are unreliable and may compromise job completion. - * - * @group basic - * @since 4.0.0 - */ - def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) - - /** - * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to - * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms - * where the plan may grow exponentially. Local checkpoints are written to executor storage and - * despite potentially faster they are unreliable and may compromise job completion. - * - * @param eager - * Whether to checkpoint this dataframe immediately - * - * @note - * When checkpoint is used with eager = false, the final data that is checkpointed after the - * first action may be different from the data that was used during the job due to - * non-determinism of the underlying operation and retries. If checkpoint is used to achieve - * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is - * only deterministic after the first execution, after the checkpoint was finalized. - * - * @group basic - * @since 4.0.0 - */ - def localCheckpoint(eager: Boolean): Dataset[T] = - checkpoint(eager = eager, reliableCheckpoint = false) - - /** - * Returns a checkpointed version of this Dataset. - * - * @param eager - * Whether to checkpoint this dataframe immediately - * @param reliableCheckpoint - * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If - * false creates a local checkpoint using the caching subsystem - */ - private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + /** @inheritdoc */ + protected def checkpoint( + eager: Boolean, + reliableCheckpoint: Boolean, + storageLevel: Option[StorageLevel]): Dataset[T] = { sparkSession.newDataset(agnosticEncoder) { builder => val command = sparkSession.newCommand { builder => - builder.getCheckpointCommandBuilder + val checkpointBuilder = builder.getCheckpointCommandBuilder .setLocal(!reliableCheckpoint) .setEager(eager) .setRelation(this.plan.getRoot) + storageLevel.foreach { storageLevel => + checkpointBuilder.setStorageLevel( + StorageLevelProtoConverter.toConnectProtoType(storageLevel)) + } } val responseIter = sparkSession.execute(command) try { @@ -3536,38 +1163,19 @@ class Dataset[T] private[sql] ( } } - /** - * Returns `true` when the logical query plans inside both [[Dataset]]s are equal and therefore - * return same results. - * - * @note - * The equality comparison here is simplified by tolerating the cosmetic differences such as - * attribute names. - * @note - * This API can compare both [[Dataset]]s but can still return `false` on the [[Dataset]] that - * return the same results, for instance, from different plans. Such false negative semantic - * can be useful when caching as an example. This comparison may not be fast because it will - * execute a RPC call. 
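As the removed notes explain, `sameSemantics` and `semanticHash` compare logical plans after normalizing cosmetic differences, and on the Connect client this involves an RPC. A small sketch of the expected behaviour, assuming `people` is an existing DataFrame (illustrative name):

{{{
import org.apache.spark.sql.functions.col

val a = people.filter(col("age") > 21).select(col("name"))
val b = people.filter(col("age") > 21).select(col("name"))

// Expected to be true: the two plans differ at most cosmetically.
val equal    = a.sameSemantics(b)
val sameHash = a.semanticHash() == b.semanticHash()
}}}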
- * @since 3.4.0 - */ + /** @inheritdoc */ @DeveloperApi def sameSemantics(other: Dataset[T]): Boolean = { sparkSession.sameSemantics(this.plan, other.plan) } - /** - * Returns a `hashCode` of the logical query plan against this [[Dataset]]. - * - * @note - * Unlike the standard `hashCode`, the hash is calculated against the query plan simplified by - * tolerating the cosmetic differences such as attribute names. - * @since 3.4.0 - */ + /** @inheritdoc */ @DeveloperApi def semanticHash(): Int = { sparkSession.semanticHash(this.plan) } + /** @inheritdoc */ def toJSON: Dataset[String] = { select(to_json(struct(col("*")))).as(StringEncoder) } @@ -3590,5 +1198,306 @@ class Dataset[T] private[sql] ( * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We * null out the instance for now. */ + @scala.annotation.unused("this is used by java serialization") private def writeReplace(): Any = null + + //////////////////////////////////////////////////////////////////////////// + // Return type overrides to make sure we return the implementation instead + // of the interface. This is done for a couple of reasons: + // - Retain the old signatures for binary compatibility; + // - Java compatibility . The java compiler uses the byte code signatures, + // and those would point to api.Dataset being returned instead of Dataset. + // This causes issues when the java code tries to materialize results, or + // tries to use functionality that is implementation specfic. + // - Scala method resolution runs into problems when the ambiguous methods are + // scattered across the interface and implementation. `drop` and `select` + // suffered from this. + //////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def drop(colName: String): DataFrame = super.drop(colName) + + /** @inheritdoc */ + override def drop(col: Column): DataFrame = super.drop(col) + + /** @inheritdoc */ + override def join(right: Dataset[_], usingColumn: String): DataFrame = + super.join(right, usingColumn) + + /** @inheritdoc */ + override def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = + super.join(right, usingColumns) + + /** @inheritdoc */ + override def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = + super.join(right, usingColumns) + + /** @inheritdoc */ + override def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = + super.join(right, usingColumn, joinType) + + /** @inheritdoc */ + override def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = + super.join(right, usingColumns, joinType) + + /** @inheritdoc */ + override def join(right: Dataset[_], joinExprs: Column): DataFrame = + super.join(right, joinExprs) + + /** @inheritdoc */ + @scala.annotation.varargs + override def select(col: String, cols: String*): DataFrame = super.select(col, cols: _*) + + /** @inheritdoc */ + override def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = + super.select(c1, c2) + + /** @inheritdoc */ + override def select[U1, U2, U3]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = + super.select(c1, c2, c3) + + /** @inheritdoc */ + override def select[U1, U2, U3, U4]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = + super.select(c1, c2, c3, c4) + + /** @inheritdoc */ + override def select[U1, U2, 
U3, U4, U5]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = + super.select(c1, c2, c3, c4, c5) + + override def melt( + ids: Array[Column], + values: Array[Column], + variableColumnName: String, + valueColumnName: String): DataFrame = + super.melt(ids, values, variableColumnName, valueColumnName) + + /** @inheritdoc */ + override def melt( + ids: Array[Column], + variableColumnName: String, + valueColumnName: String): DataFrame = + super.melt(ids, variableColumnName, valueColumnName) + + /** @inheritdoc */ + override def withColumn(colName: String, col: Column): DataFrame = + super.withColumn(colName, col) + + /** @inheritdoc */ + override def withColumns(colsMap: Map[String, Column]): DataFrame = + super.withColumns(colsMap) + + /** @inheritdoc */ + override def withColumns(colsMap: util.Map[String, Column]): DataFrame = + super.withColumns(colsMap) + + /** @inheritdoc */ + override def withColumnRenamed(existingName: String, newName: String): DataFrame = + super.withColumnRenamed(existingName, newName) + + /** @inheritdoc */ + override def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = + super.withColumnsRenamed(colsMap) + + /** @inheritdoc */ + override def withColumnsRenamed(colsMap: util.Map[String, String]): DataFrame = + super.withColumnsRenamed(colsMap) + + /** @inheritdoc */ + override def checkpoint(): Dataset[T] = super.checkpoint() + + /** @inheritdoc */ + override def checkpoint(eager: Boolean): Dataset[T] = super.checkpoint(eager) + + /** @inheritdoc */ + override def localCheckpoint(): Dataset[T] = super.localCheckpoint() + + /** @inheritdoc */ + override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager) + + /** @inheritdoc */ + override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] = + super.localCheckpoint(eager, storageLevel) + + /** @inheritdoc */ + override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = + super.joinWith(other, condition) + + /** @inheritdoc */ + @scala.annotation.varargs + override def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = + super.sortWithinPartitions(sortCol, sortCols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def sortWithinPartitions(sortExprs: Column*): Dataset[T] = + super.sortWithinPartitions(sortExprs: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def sort(sortCol: String, sortCols: String*): Dataset[T] = + super.sort(sortCol, sortCols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def sort(sortExprs: Column*): Dataset[T] = super.sort(sortExprs: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def orderBy(sortCol: String, sortCols: String*): Dataset[T] = + super.orderBy(sortCol, sortCols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def orderBy(sortExprs: Column*): Dataset[T] = super.orderBy(sortExprs: _*) + + /** @inheritdoc */ + override def as(alias: Symbol): Dataset[T] = super.as(alias) + + /** @inheritdoc */ + override def alias(alias: String): Dataset[T] = super.alias(alias) + + /** @inheritdoc */ + override def alias(alias: Symbol): Dataset[T] = super.alias(alias) + + /** @inheritdoc */ + @scala.annotation.varargs + override def selectExpr(exprs: String*): DataFrame = super.selectExpr(exprs: _*) + + /** @inheritdoc */ + override def filter(conditionExpr: String): Dataset[T] = 
super.filter(conditionExpr) + + /** @inheritdoc */ + override def where(condition: Column): Dataset[T] = super.where(condition) + + /** @inheritdoc */ + override def where(conditionExpr: String): Dataset[T] = super.where(conditionExpr) + + /** @inheritdoc */ + override def unionAll(other: Dataset[T]): Dataset[T] = super.unionAll(other) + + /** @inheritdoc */ + override def unionByName(other: Dataset[T]): Dataset[T] = super.unionByName(other) + + /** @inheritdoc */ + override def sample(fraction: Double, seed: Long): Dataset[T] = super.sample(fraction, seed) + + /** @inheritdoc */ + override def sample(fraction: Double): Dataset[T] = super.sample(fraction) + + /** @inheritdoc */ + override def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = + super.sample(withReplacement, fraction) + + /** @inheritdoc */ + override def dropDuplicates(colNames: Array[String]): Dataset[T] = + super.dropDuplicates(colNames) + + /** @inheritdoc */ + @scala.annotation.varargs + override def dropDuplicates(col1: String, cols: String*): Dataset[T] = + super.dropDuplicates(col1, cols: _*) + + /** @inheritdoc */ + override def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = + super.dropDuplicatesWithinWatermark(colNames) + + /** @inheritdoc */ + @scala.annotation.varargs + override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = + super.dropDuplicatesWithinWatermark(col1, cols: _*) + + /** @inheritdoc */ + override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.mapPartitions(f, encoder) + + /** @inheritdoc */ + override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = + super.flatMap(func) + + /** @inheritdoc */ + override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.flatMap(f, encoder) + + /** @inheritdoc */ + override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = + super.foreachPartition(func) + + /** @inheritdoc */ + @scala.annotation.varargs + override def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = + super.repartition(numPartitions, partitionExprs: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def repartition(partitionExprs: Column*): Dataset[T] = + super.repartition(partitionExprs: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = + super.repartitionByRange(numPartitions, partitionExprs: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def repartitionByRange(partitionExprs: Column*): Dataset[T] = + super.repartitionByRange(partitionExprs: _*) + + /** @inheritdoc */ + override def distinct(): Dataset[T] = super.distinct() + + /** @inheritdoc */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): RelationalGroupedDataset = + super.groupBy(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def rollup(col1: String, cols: String*): RelationalGroupedDataset = + super.rollup(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def cube(col1: String, cols: String*): RelationalGroupedDataset = + super.cube(col1, cols: _*) + + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) + + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + 
override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + + /** @inheritdoc */ + override def rdd: RDD[T] = throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def toJavaRDD: JavaRDD[T] = throw ConnectClientUnsupportedErrors.rdd() + + override def queryExecution: QueryExecution = + throw ConnectClientUnsupportedErrors.queryExecution() } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index d2875dfe48c88..63b5f27c4745e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,16 +19,19 @@ package org.apache.spark.sql import java.util.Arrays +import scala.annotation.unused import scala.jdk.CollectionConverters._ -import scala.language.existentials import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.UdfUtils -import org.apache.spark.sql.expressions.ScalaUserDefinedFunction +import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr +import org.apache.spark.sql.internal.UDFAdaptors import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} /** @@ -38,7 +41,10 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 3.5.0 */ -class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { +class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] { + type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] + + private def unsupported(): Nothing = throw new UnsupportedOperationException() /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -47,499 +53,52 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * * @since 3.5.0 */ - def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = { - throw new UnsupportedOperationException - } - - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to - * the data. The grouping key is unchanged by this. - * - * {{{ - * // Create values grouped by key from a Dataset[(K, V)] - * ds.groupByKey(_._1).mapValues(_._2) // Scala - * }}} - * - * @since 3.5.0 - */ - def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = { - throw new UnsupportedOperationException - } - - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to - * the data. 
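Note that the Dataset overrides above route `rdd`, `toJavaRDD` and `queryExecution` to `ConnectClientUnsupportedErrors`, since there is no local RDD or query execution on the Connect client. A hedged sketch of what client code can fall back to instead (assuming `ds` is an existing Dataset):

{{{
// These now raise ConnectClientUnsupportedErrors on the Connect client:
//   ds.rdd
//   ds.toJavaRDD
//   ds.queryExecution

// Driver-side access still goes through the regular Dataset actions, for example:
val rows = ds.collect()
val iter = ds.toLocalIterator()
}}}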
The grouping key is unchanged by this. - * - * {{{ - * // Create Integer values grouped by String key from a Dataset> - * Dataset> ds = ...; - * KeyValueGroupedDataset grouped = - * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); - * }}} - * - * @since 3.5.0 - */ - def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { - mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - } - - /** - * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over - * the Dataset to extract the keys and then running a distinct operation on those. - * - * @since 3.5.0 - */ - def keys: Dataset[K] = { - throw new UnsupportedOperationException - } + def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an iterator containing elements of an arbitrary type which - * will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - flatMapSortedGroups()(f) - } + /** @inheritdoc */ + def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = + unsupported() - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an iterator containing elements of an arbitrary type which - * will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) - } + /** @inheritdoc */ + def keys: Dataset[K] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and a sorted iterator that contains all of the elements - * in the group. 
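The `flatMapGroups` documentation removed above describes per-group processing without partial aggregation. A short sketch of the Scala variant, assuming an `events: Dataset[Event]` and an active `spark` session (both illustrative):

{{{
import spark.implicits._

case class Event(user: String, action: String)

val summaries: Dataset[String] =
  events
    .groupByKey(_.user)
    .flatMapGroups { (user, it) =>
      // Stream through the group's iterator; avoid materializing it (e.g. via toList).
      it.map(e => s"$user:${e.action}")
    }
}}}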
The function can return an iterator containing elements of an arbitrary type - * which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( - f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and a sorted iterator that contains all of the elements - * in the group. The function can return an iterator containing elements of an arbitrary type - * which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ - def flatMapSortedGroups[U]( - SortExprs: Array[Column], - f: FlatMapGroupsFunction[K, V, U], - encoder: Encoder[U]): Dataset[U] = { - import org.apache.spark.util.ArrayImplicits._ - flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)( - UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an element of arbitrary type which will be returned as a - * new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. 
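`flatMapSortedGroups` keeps the same contract but hands the function an iterator ordered by the given sort expressions. A sketch under assumed names (`clicks: Dataset[Click]`, active `spark` session):

{{{
import spark.implicits._
import org.apache.spark.sql.functions.col

case class Click(user: String, ts: Long, url: String)

val visits: Dataset[String] =
  clicks
    .groupByKey(_.user)
    .flatMapSortedGroups(col("ts")) { (user, it) =>
      // The iterator arrives sorted by ts within each user's group.
      it.map(c => s"$user -> ${c.url}")
    }
}}}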
However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { - flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f)) - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an element of arbitrary type which will be returned as a - * new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Reduces the elements of each group of data using the specified binary - * function. The given function must be commutative and associative or the result may be - * non-deterministic. - * - * @since 3.5.0 - */ - def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Reduces the elements of each group of data using the specified binary - * function. The given function must be commutative and associative or the result may be - * non-deterministic. - * - * @since 3.5.0 - */ - def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f)) - } - - /** - * Internal helper function for building typed aggregations that return tuples. For simplicity - * and code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. - */ - protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - throw new UnsupportedOperationException - } - - /** - * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the - * result of computing this aggregation over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. 
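The removed docs for `reduceGroups` and the typed `agg` overloads describe per-key reduction and tuple-returning aggregations. A brief sketch, assuming `sales: Dataset[Sale]` (illustrative) and the usual implicits:

{{{
import spark.implicits._
import org.apache.spark.sql.functions.{avg, count}

case class Sale(region: String, amount: Double)

val byRegion = sales.groupByKey(_.region)

// Binary reduce per key; the function should be commutative and associative.
val totals: Dataset[(String, Sale)] =
  byRegion.reduceGroups((a, b) => Sale(a.region, a.amount + b.amount))

// Typed aggregation: one tuple per key holding the requested aggregates.
val stats: Dataset[(String, Long, Double)] =
  byRegion.agg(count("*").as[Long], avg($"amount").as[Double])
}}}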
- * - * @since 3.5.0 - */ - def agg[U1, U2, U3]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = - aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = - aggUntyped(col1, col2, col3, col4, col5, col6) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6], - col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7, U8]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6], - col7: TypedColumn[V, U7], - col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] - - /** - * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for - * that key. - * - * @since 3.5.0 - */ - def count(): Dataset[(K, Long)] = agg(functions.count("*")) + f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + unsupported() - /** - * (Scala-specific) Applies the given function to each cogrouped data. For each unique group, - * the function will be passed the grouping key and 2 iterators containing all elements in the - * group from [[Dataset]] `this` and `other`. 
The function can return an iterator containing - * elements of an arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 3.5.0 - */ - def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - cogroupSorted(other)()()(f) - } + /** @inheritdoc */ + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = unsupported() - /** - * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the - * function will be passed the grouping key and 2 iterators containing all elements in the group - * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements - * of an arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 3.5.0 - */ - def cogroup[U, R]( - other: KeyValueGroupedDataset[K, U], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) - } + /** @inheritdoc */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = unsupported() - /** - * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique - * group, the function will be passed the grouping key and 2 sorted iterators containing all - * elements in the group from [[Dataset]] `this` and `other`. The function can return an - * iterator containing elements of an arbitrary type which will be returned as a new - * [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)( - otherSortExprs: Column*)( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique - * group, the function will be passed the grouping key and 2 sorted iterators containing all - * elements in the group from [[Dataset]] `this` and `other`. The function can return an - * iterator containing elements of an arbitrary type which will be returned as a new - * [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. 
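`cogroup` pairs the groups of two keyed Datasets under the same key. A compact sketch with assumed inputs `orders: Dataset[Order]` and `payments: Dataset[Payment]`:

{{{
import spark.implicits._

case class Order(customer: String, amount: Double)
case class Payment(customer: String, amount: Double)

// For every customer, compare the ordered total with the paid total.
val balances: Dataset[(String, Double)] =
  orders.groupByKey(_.customer).cogroup(payments.groupByKey(_.customer)) {
    (customer, os, ps) =>
      Iterator(customer -> (os.map(_.amount).sum - ps.map(_.amount).sum))
  }
}}}

`cogroupSorted` takes the same shape, with additional sort expressions for each side.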
- * - * @since 3.5.0 - */ - def cogroupSorted[U, R]( - other: KeyValueGroupedDataset[K, U], - thisSortExprs: Array[Column], - otherSortExprs: Array[Column], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - import org.apache.spark.util.ArrayImplicits._ - cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)( - otherSortExprs.toImmutableArraySeq: _*)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) - } + otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + unsupported() protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( outputMode: Option[OutputMode], timeoutConf: GroupStateTimeout, initialState: Option[KeyValueGroupedDataset[K, S]], isMapGroupWithState: Boolean)( - func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { - throw new UnsupportedOperationException - } + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { mapGroupsWithState(GroupStateTimeout.NoTimeout)(func) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)( - UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. 
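`mapGroupsWithState` carries user-defined state across invocations of a streaming query (and is called once per group on a batch Dataset). A hedged sketch, with `clicks: Dataset[ClickEvent]` and the case classes below purely illustrative:

{{{
import spark.implicits._
import org.apache.spark.sql.streaming.GroupState

case class ClickEvent(user: String, clicks: Int)
case class RunningCount(total: Long)

val counts: Dataset[(String, Long)] =
  clicks
    .groupByKey(_.user)
    .mapGroupsWithState[RunningCount, (String, Long)] { (user, it, state) =>
      // Read the previous state (if any), fold in the new rows, and persist it back.
      val previous = state.getOption.map(_.total).getOrElse(0L)
      val updated  = RunningCount(previous + it.map(_.clicks).sum)
      state.update(updated)
      (user, updated.total)
    }
}}}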
The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param timeoutConf - * Timeout Conf, see GroupStateTimeout for more details - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to - * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( @@ -548,134 +107,10 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { None, timeoutConf, Some(initialState), - isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( - stateEncoder, - outputEncoder) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. 
- * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( - stateEncoder, - outputEncoder) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf, initialState)( - UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder) + isMapGroupWithState = true)(UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
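`flatMapGroupsWithState` additionally takes an output mode and a timeout configuration and lets the function emit zero or more rows per group. A sketch under assumed names (`readings: Dataset[Reading]`):

{{{
import spark.implicits._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}

case class Reading(device: String, value: Double)
case class MaxSoFar(max: Double)

// Emit a record only when a device reports a new maximum value.
val maxima: Dataset[(String, Double)] =
  readings
    .groupByKey(_.device)
    .flatMapGroupsWithState[MaxSoFar, (String, Double)](
      OutputMode.Update(), GroupStateTimeout.NoTimeout) { (device, it, state) =>
      val seen   = state.getOption.map(_.max).getOrElse(Double.MinValue)
      val newMax = (Iterator(seen) ++ it.map(_.value)).max
      if (newMax > seen) {
        state.update(MaxSoFar(newMax))
        Iterator(device -> newMax)
      } else {
        Iterator.empty
      }
    }
}}}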
- * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -687,33 +122,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { isMapGroupWithState = false)(func) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To covert a Dataset `ds` of type of type - * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use - * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what - * types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -726,201 +135,244 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { isMapGroupWithState = false)(func) } - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
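The removed documentation for the initial-state variants points out that a keyed `Dataset[(K, S)]` can be turned into the required `KeyValueGroupedDataset[K, S]` with `groupByKey` plus `mapValues`. A tiny sketch (names and values illustrative):

{{{
import spark.implicits._

case class Count(total: Long)

// Initial state loaded or constructed elsewhere; here just an in-memory Seq.
val initialCounts: Dataset[(String, Count)] =
  Seq("alice" -> Count(10L), "bob" -> Count(3L)).toDS()

// The conversion the Scaladoc describes: key by the first element, keep the second as value.
val initialState: KeyValueGroupedDataset[String, Count] =
  initialCounts.groupByKey(_._1).mapValues(_._2)
}}}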
- * @since 3.5.0 - */ - def flatMapGroupsWithState[S, U]( + /** @inheritdoc */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode): Dataset[U] = + unsupported() + + /** @inheritdoc */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + unsupported() + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + eventTimeColumnName: String, + outputMode: OutputMode): Dataset[U] = unsupported() + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + eventTimeColumnName: String, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() + + // Overrides... + /** @inheritdoc */ + override def mapValues[W]( + func: MapFunction[V, W], + encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder) + + /** @inheritdoc */ + override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + super.flatMapGroups(f) + + /** @inheritdoc */ + override def flatMapGroups[U]( + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder) + + /** @inheritdoc */ + override def flatMapSortedGroups[U]( + SortExprs: Array[Column], + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder) + + /** @inheritdoc */ + override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f) + + /** @inheritdoc */ + override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = + super.mapGroups(f, encoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState) + + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) - flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) - } + timeoutConf: GroupStateTimeout): Dataset[U] = + super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf) - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined 
per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To covert a Dataset `ds` of type of type - * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use - * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def flatMapGroupsWithState[S, U]( + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) - flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)( - stateEncoder, - outputEncoder) - } - - /** - * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state - * API v2. We allow the user to act on per-group set of input rows along with keyed state and - * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will - * repeatedly invoke the interface methods for new rows in each trigger and the user's - * state/state variables will be stored persistently across invocations. Currently this operator - * is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - */ - def transformWithState[U: Encoder]( + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState( + func, + outputMode, + stateEncoder, + outputEncoder, + timeoutConf, + initialState) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, - outputMode: OutputMode): Dataset[U] = { - throw new UnsupportedOperationException - } + outputMode: OutputMode, + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder) - /** - * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API - * v2. 
We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. For a streaming dataframe, we will - * repeatedly invoke the interface methods for new rows in each trigger and the user's - * state/state variables will be stored persistently across invocations. Currently this operator - * is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param outputEncoder - * Encoder for the output type. - */ - def transformWithState[U: Encoder]( + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], - timeMode: TimeMode, + eventTimeColumnName: String, outputMode: OutputMode, - outputEncoder: Encoder[U]): Dataset[U] = { - throw new UnsupportedOperationException - } + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder) - /** - * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state - * API v2. Functions as the function above, but with additional initial state. Currently this - * operator is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S - * The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param initialState - * User provided initial state that will be used to initiate state for the query in the first - * batch. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - def transformWithState[U: Encoder, S: Encoder]( + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API - * v2. Functions as the function above, but with additional initial state. Currently this - * operator is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S - * The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param initialState - * User provided initial state that will be used to initiate state for the query in the first - * batch. - * @param outputEncoder - * Encoder for the output type. - * @param initialStateEncoder - * Encoder for the initial state type. 
- * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + initialState: KeyValueGroupedDataset[K, S], + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + timeMode, + outputMode, + initialState, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], - timeMode: TimeMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S], + eventTimeColumnName: String, outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): Dataset[U] = { - throw new UnsupportedOperationException - } + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + outputMode, + initialState, + eventTimeColumnName, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f) + + /** @inheritdoc */ + override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1) + + /** @inheritdoc */ + override def agg[U1, U2]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2) + + /** @inheritdoc */ + override def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + super.agg(col1, col2, col3, col4, col5) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + super.agg(col1, col2, col3, col4, col5, col6) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + super.agg(col1, col2, col3, col4, col5, col6, col7) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + super.agg(col1, col2, col3, col4, col5, col6, col7, col8) + + /** @inheritdoc */ + override def count(): Dataset[(K, Long)] = super.count() + + /** @inheritdoc */ + override def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + super.cogroup(other)(f) + + /** @inheritdoc */ + override def cogroup[U, R]( + other: KeyValueGroupedDataset[K, U], + f: CoGroupFunction[K, V, 
U, R], + encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder) + + /** @inheritdoc */ + override def cogroupSorted[U, R]( + other: KeyValueGroupedDataset[K, U], + thisSortExprs: Array[Column], + otherSortExprs: Array[Column], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = + super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder) } /** @@ -933,21 +385,20 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( private val sparkSession: SparkSession, private val plan: proto.Plan, - private val ikEncoder: AgnosticEncoder[IK], private val kEncoder: AgnosticEncoder[K], private val ivEncoder: AgnosticEncoder[IV], private val vEncoder: AgnosticEncoder[V], private val groupingExprs: java.util.List[proto.Expression], - private val valueMapFunc: IV => V, + private val valueMapFunc: Option[IV => V], private val keysFunc: () => Dataset[IK]) extends KeyValueGroupedDataset[K, V] { + import sparkSession.RichColumn override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = { new KeyValueGroupedDatasetImpl[L, V, IK, IV]( sparkSession, plan, - ikEncoder, - encoderFor[L], + agnosticEncoderFor[L], ivEncoder, vEncoder, groupingExprs, @@ -959,12 +410,13 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new KeyValueGroupedDatasetImpl[K, W, IK, IV]( sparkSession, plan, - ikEncoder, kEncoder, ivEncoder, - encoderFor[W], + agnosticEncoderFor[W], groupingExprs, - valueMapFunc.andThen(valueFunc), + valueMapFunc + .map(_.andThen(valueFunc)) + .orElse(Option(valueFunc.asInstanceOf[IV => W])), keysFunc) } @@ -977,9 +429,8 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { // Apply mapValues changes to the udf - val nf = - if (valueMapFunc == UdfUtils.identical()) f else UdfUtils.mapValuesAdaptor(f, valueMapFunc) - val outputEncoder = encoderFor[U] + val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc) + val outputEncoder = agnosticEncoderFor[U] sparkSession.newDataset[U](outputEncoder) { builder => builder.getGroupMapBuilder .setInput(plan.getRoot) @@ -992,11 +443,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( thisSortExprs: Column*)(otherSortExprs: Column*)( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - assert(other.isInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]]) - val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]] + val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]] // Apply mapValues changes to the udf - val nf = UdfUtils.mapValuesAdaptor(f, valueMapFunc, otherImpl.valueMapFunc) - val outputEncoder = encoderFor[R] + val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc) + val outputEncoder = agnosticEncoderFor[R] sparkSession.newDataset[R](outputEncoder) { builder => builder.getCoGroupMapBuilder .setInput(plan.getRoot) @@ -1010,26 +460,26 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( } override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - // TODO(SPARK-43415): For each column, apply the valueMap func first - val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(_.encoder)) // apply keyAs change + // TODO(SPARK-43415): For each column, apply the valueMap func first... 
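// Illustration only: a standalone sketch of the Option-based composition used by
// the mapValues override above. None means "no value mapping yet", so downstream
// UDF adaptors can skip an extra function call; the cast is sound only under the
// invariant that `current` is None exactly when IV == V, which is what the
// implementation relies on. `composeValueMap` is a hypothetical helper name.
def composeValueMap[IV, V, W](current: Option[IV => V], next: V => W): Option[IV => W] =
  current.map(_.andThen(next)).orElse(Option(next.asInstanceOf[IV => W]))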
+ val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => agnosticEncoderFor(c.encoder))) sparkSession.newDataset(rEnc) { builder => builder.getAggregateBuilder .setInput(plan.getRoot) .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) .addAllGroupingExpressions(groupingExprs) - .addAllAggregateExpressions(columns.map(_.expr).asJava) + .addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava) } } override def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { val inputEncoders = Seq(vEncoder, vEncoder) - val udf = ScalaUserDefinedFunction( + val udf = SparkUserDefinedFunction( function = f, inputEncoders = inputEncoders, outputEncoder = vEncoder) val input = udf.apply(inputEncoders.map(_ => col("*")): _*) - val expr = Column.fn("reduce", input).expr - val aggregator: TypedColumn[V, V] = new TypedColumn[V, V](expr, vEncoder) + val expr = Column.fn("reduce", input) + val aggregator: TypedColumn[V, V] = new TypedColumn[V, V](expr.node, vEncoder) agg(aggregator) } @@ -1044,22 +494,15 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( throw new IllegalArgumentException("The output mode of function should be append or update") } - if (initialState.isDefined) { - assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) - } - val initialStateImpl = if (initialState.isDefined) { + assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]] } else { null } - val outputEncoder = encoderFor[U] - val nf = if (valueMapFunc == UdfUtils.identical()) { - func - } else { - UdfUtils.mapValuesAdaptor(func, valueMapFunc) - } + val outputEncoder = agnosticEncoderFor[U] + val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc) sparkSession.newDataset[U](outputEncoder) { builder => val groupMapBuilder = builder.getGroupMapBuilder @@ -1083,7 +526,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])( inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = { val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder - val udf = ScalaUserDefinedFunction( + val udf = SparkUserDefinedFunction( function = nf, inputEncoders = inputEncoders, outputEncoder = outputEncoder) @@ -1094,6 +537,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the * server side. We null out the instance for now. 
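// Illustration only: a hedged usage sketch of reduceGroups, which the
// implementation above rewrites into the typed "reduce" aggregate sent to the
// server. Assumes a session named `spark` with its implicits imported.
import spark.implicits._
val totals = Seq("a" -> 1, "a" -> 2, "b" -> 3).toDS()
  .groupByKey(_._1)
  .reduceGroups((x, y) => (x._1, x._2 + y._2))
// totals: Dataset[(String, (String, Int))] => ("a", ("a", 3)), ("b", ("b", 3))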
*/ + @unused("this is used by java serialization") private def writeReplace(): Any = null } @@ -1102,19 +546,19 @@ private object KeyValueGroupedDatasetImpl { ds: Dataset[V], kEncoder: AgnosticEncoder[K], groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = { - val gf = ScalaUserDefinedFunction( + val gf = SparkUserDefinedFunction( function = groupingFunc, inputEncoders = ds.agnosticEncoder :: Nil, // Using the original value and key encoders outputEncoder = kEncoder) + val session = ds.sparkSession new KeyValueGroupedDatasetImpl( - ds.sparkSession, + session, ds.plan, kEncoder, - kEncoder, ds.agnosticEncoder, ds.agnosticEncoder, - Arrays.asList(gf.apply(col("*")).expr), - UdfUtils.identical(), + Arrays.asList(toExpr(gf.apply(col("*")))), + None, () => ds.map(groupingFunc)(kEncoder)) } @@ -1124,20 +568,19 @@ private object KeyValueGroupedDatasetImpl { vEncoder: AgnosticEncoder[V], groupingExprs: Seq[Column]): KeyValueGroupedDatasetImpl[K, V, K, V] = { // Use a dummy udf to pass the K V encoders - val dummyGroupingFunc = ScalaUserDefinedFunction( + val dummyGroupingFunc = SparkUserDefinedFunction( function = UdfUtils.noOp[V, K](), inputEncoders = vEncoder :: Nil, outputEncoder = kEncoder).apply(col("*")) - + val session = df.sparkSession new KeyValueGroupedDatasetImpl( - df.sparkSession, + session, df.plan, kEncoder, - kEncoder, vEncoder, vEncoder, - (Seq(dummyGroupingFunc) ++ groupingExprs).map(_.expr).asJava, - UdfUtils.identical(), + (Seq(dummyGroupingFunc) ++ groupingExprs).map(toExpr).asJava, + None, () => df.select(groupingExprs: _*).as(kEncoder)) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 754906f736cf0..5bded40b0d132 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql -import java.util.Locale - import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor +import org.apache.spark.sql.connect.ConnectConversions._ /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -36,343 +36,125 @@ import org.apache.spark.connect.proto * @since 3.4.0 */ class RelationalGroupedDataset private[sql] ( - private[sql] val df: DataFrame, + protected val df: DataFrame, private[sql] val groupingExprs: Seq[Column], groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, - groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) { + groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) + extends api.RelationalGroupedDataset { + import df.sparkSession.RichColumn - private[this] def toDF(aggExprs: Seq[Column]): DataFrame = { + protected def toDF(aggExprs: Seq[Column]): DataFrame = { df.sparkSession.newDataFrame { builder => - builder.getAggregateBuilder + val aggBuilder = builder.getAggregateBuilder .setInput(df.plan.getRoot) - .addAllGroupingExpressions(groupingExprs.map(_.expr).asJava) - .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava) + groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr)) + aggExprs.foreach(c => aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder))) groupType match { case 
proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP => - builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) case proto.Aggregate.GroupType.GROUP_TYPE_CUBE => - builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE) case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => - builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT => assert(pivot.isDefined) - builder.getAggregateBuilder + aggBuilder .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT) .setPivot(pivot.get) case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS => assert(groupingSets.isDefined) - val aggBuilder = builder.getAggregateBuilder - .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS) groupingSets.get.foreach(aggBuilder.addGroupingSets) case g => throw new UnsupportedOperationException(g.toString) } } } - /** - * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of - * current `RelationalGroupedDataset`. - * - * @since 3.5.0 - */ - def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { - KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) + protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] = { + // This behaves different than the classic implementation. The classic implementation validates + // if a column is actually a number, and if it is not it throws an error immediately. In connect + // it depends on the input type (casting) rules for the method invoked. If the input violates + // the a different error will be thrown. However it is also possible to get a result for a + // non-numeric column in connect, for example when you use min/max. + colNames.map(df.col) } - /** - * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The - * resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} - * - * @since 3.4.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - toDF((aggExpr +: aggExprs).map { case (colName, expr) => - strToColumn(expr, df(colName)) - }) + /** @inheritdoc */ + def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { + KeyValueGroupedDatasetImpl[K, T]( + df, + agnosticEncoderFor[K], + agnosticEncoderFor[T], + groupingExprs) } - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate - * methods. The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. 
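// Illustration only: a rough sketch of the behavioural note in selectNumericColumns
// above. Only `spark` is assumed; exact outcomes on the Connect side depend on the
// server's cast rules.
import spark.implicits._
val people = Seq(("Alice", 30), ("Bob", 25)).toDF("name", "age")
people.groupBy().mean("age")   // fine in both classic and Connect
people.groupBy().mean("name")  // classic: fails eagerly ("name" is not a numeric column);
                               // Connect: deferred to the server, which may raise a
                               // different error or, for min/max, even return a value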
- * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max", - * "expense" -> "sum" - * )) - * }}} - * - * @since 3.4.0 - */ - def agg(exprs: Map[String, String]): DataFrame = { - toDF(exprs.map { case (colName, expr) => - strToColumn(expr, df(colName)) - }.toSeq) - } + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) - /** - * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. - * The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); - * }}} - * - * @since 3.4.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.asScala.toMap) - } + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) - private[this] def strToColumn(expr: String, inputExpr: Column): Column = { - expr.toLowerCase(Locale.ROOT) match { - case "avg" | "average" | "mean" => functions.avg(inputExpr) - case "stddev" | "std" => functions.stddev(inputExpr) - case "count" | "size" => functions.count(inputExpr) - case name => Column.fn(name, inputExpr) - } - } + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) - /** - * Compute aggregates by specifying a series of aggregate columns. Note that this function by - * default retains the grouping columns in its output. To not retain grouping columns, set - * `spark.sql.retainGroupColumns` to false. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change - * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. - * {{{ - * // Scala, 1.3.x: - * df.groupBy("department").agg($"department", max("age"), sum("expense")) - * - * // Java, 1.3.x: - * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map { case c => - c - // TODO: deal with typed columns. - }) - } + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) - /** - * Count the number of rows for each group. The resulting `DataFrame` will also contain the - * grouping columns. - * - * @since 3.4.0 - */ - def count(): DataFrame = toDF(Seq(functions.count(functions.lit(1)).alias("count"))) + /** @inheritdoc */ + override def count(): DataFrame = super.count() - /** - * Compute the average value for each numeric columns for each group. This is an alias for - * `avg`. 
The resulting `DataFrame` will also contain the grouping columns. When specified - * columns are given, only compute the average values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def mean(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.mean(colName))) - } + override def mean(colNames: String*): DataFrame = super.mean(colNames: _*) - /** - * Compute the max value for each numeric columns for each group. The resulting `DataFrame` will - * also contain the grouping columns. When specified columns are given, only compute the max - * values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def max(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.max(colName))) - } + override def max(colNames: String*): DataFrame = super.max(colNames: _*) - /** - * Compute the mean value for each numeric columns for each group. The resulting `DataFrame` - * will also contain the grouping columns. When specified columns are given, only compute the - * mean values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def avg(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.avg(colName))) - } + override def avg(colNames: String*): DataFrame = super.avg(colNames: _*) - /** - * Compute the min value for each numeric column for each group. The resulting `DataFrame` will - * also contain the grouping columns. When specified columns are given, only compute the min - * values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def min(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.min(colName))) - } + override def min(colNames: String*): DataFrame = super.min(colNames: _*) - /** - * Compute the sum for each numeric columns for each group. The resulting `DataFrame` will also - * contain the grouping columns. When specified columns are given, only compute the sum for - * them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.sum(colName))) - } + override def sum(colNames: String*): DataFrame = super.sum(colNames: _*) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the - * resulting schema of the transformation. To avoid any eager computations, provide an explicit - * list of values via `pivot(pivotColumn: String, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @since 3.4.0 - */ - def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + /** @inheritdoc */ + override def pivot(pivotColumn: String): RelationalGroupedDataset = super.pivot(pivotColumn) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are - * two versions of pivot function: one that requires the caller to specify the list of distinct - * values to pivot on, and one that does not. 
The latter is more concise but less efficient, - * because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by - * multiple columns, use the `struct` function to combine the columns and values: - * - * {{{ - * df.groupBy("year") - * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) - * .agg(sum($"earnings")) - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) - } + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) + + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list of - * distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) + /** @inheritdoc */ + override def pivot( + pivotColumn: Column, + values: java.util.List[Any]): RelationalGroupedDataset = { + super.pivot(pivotColumn, values) } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an - * overloaded version of the `pivot` method with `pivotColumn` of the `String` type. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. 
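// Illustration only: a hedged sketch of the pivot value rules kept above. Values
// may be plain Scala literals or literal Columns; the Connect implementation of
// pivot(Column, Seq[Any]) below rejects non-literal Columns with an
// IllegalArgumentException. `sales` (year/course/earnings) is a hypothetical DataFrame.
import org.apache.spark.sql.functions._
sales.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
sales.groupBy("year").pivot(col("course"), Seq(lit("dotNET"), lit("Java"))).sum("earnings")
// sales.groupBy("year").pivot(col("course"), Seq(col("earnings")))  // would be rejected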
- * @since 3.4.0 - */ + /** @inheritdoc */ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => - val valueExprs = values.map(_ match { + val valueExprs = values.map { case c: Column if c.expr.hasLiteral => c.expr.getLiteral case c: Column if !c.expr.hasLiteral => throw new IllegalArgumentException("values only accept literal Column") case v => functions.lit(v).expr.getLiteral - }) + } new RelationalGroupedDataset( df, groupingExprs, @@ -388,46 +170,8 @@ class RelationalGroupedDataset private[sql] ( } } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the - * resulting schema of the transformation. To avoid any eager computations, provide an explicit - * list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course").sum($"earnings"); - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * he column to pivot. - * @since 3.4.0 - */ + /** @inheritdoc */ def pivot(pivotColumn: Column): RelationalGroupedDataset = { - pivot(pivotColumn, Seq()) - } - - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of the - * `String` type. - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(pivotColumn, values.asScala.toSeq) + pivot(pivotColumn, Nil) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 7799d395d5c6a..4690253da808b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -16,283 +16,8 @@ */ package org.apache.spark.sql -import scala.collection.Map -import scala.language.implicitConversions -import scala.reflect.classTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ - -/** - * A collection of implicit methods for converting names and Symbols into [[Column]]s, and for - * converting common Scala objects into [[Dataset]]s. - * - * @since 3.4.0 - */ -abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrioritySQLImplicits { - - /** - * Converts $"col name" into a [[Column]]. - * - * @since 3.4.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. 
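// Illustration only: a brief sketch of the implicit conversions documented above,
// both of which yield Columns. Assumes a session `spark` and a DataFrame `df`
// with columns "k" and "v".
import spark.implicits._
df.select($"k")         // StringToColumn: $"..." becomes a ColumnName
df.select(Symbol("v"))  // symbolToColumn: a Scala Symbol becomes a ColumnName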
- * @since 3.4.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** @since 3.4.0 */ - implicit val newIntEncoder: Encoder[Int] = PrimitiveIntEncoder - - /** @since 3.4.0 */ - implicit val newLongEncoder: Encoder[Long] = PrimitiveLongEncoder - - /** @since 3.4.0 */ - implicit val newDoubleEncoder: Encoder[Double] = PrimitiveDoubleEncoder - - /** @since 3.4.0 */ - implicit val newFloatEncoder: Encoder[Float] = PrimitiveFloatEncoder - - /** @since 3.4.0 */ - implicit val newByteEncoder: Encoder[Byte] = PrimitiveByteEncoder - - /** @since 3.4.0 */ - implicit val newShortEncoder: Encoder[Short] = PrimitiveShortEncoder - - /** @since 3.4.0 */ - implicit val newBooleanEncoder: Encoder[Boolean] = PrimitiveBooleanEncoder - - /** @since 3.4.0 */ - implicit val newStringEncoder: Encoder[String] = StringEncoder - - /** @since 3.4.0 */ - implicit val newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = - AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER - - /** @since 3.4.0 */ - implicit val newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = - AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER - - /** @since 3.4.0 */ - implicit val newDateEncoder: Encoder[java.sql.Date] = AgnosticEncoders.STRICT_DATE_ENCODER - - /** @since 3.4.0 */ - implicit val newLocalDateEncoder: Encoder[java.time.LocalDate] = - AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER - - /** @since 3.4.0 */ - implicit val newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = - AgnosticEncoders.LocalDateTimeEncoder - - /** @since 3.4.0 */ - implicit val newTimeStampEncoder: Encoder[java.sql.Timestamp] = - AgnosticEncoders.STRICT_TIMESTAMP_ENCODER - - /** @since 3.4.0 */ - implicit val newInstantEncoder: Encoder[java.time.Instant] = - AgnosticEncoders.STRICT_INSTANT_ENCODER - - /** @since 3.4.0 */ - implicit val newDurationEncoder: Encoder[java.time.Duration] = DayTimeIntervalEncoder - - /** @since 3.4.0 */ - implicit val newPeriodEncoder: Encoder[java.time.Period] = YearMonthIntervalEncoder - - /** @since 3.4.0 */ - implicit def newJavaEnumEncoder[A <: java.lang.Enum[_]: TypeTag]: Encoder[A] = { - ScalaReflection.encoderFor[A] - } - - // Boxed primitives - - /** @since 3.4.0 */ - implicit val newBoxedIntEncoder: Encoder[java.lang.Integer] = BoxedIntEncoder - - /** @since 3.4.0 */ - implicit val newBoxedLongEncoder: Encoder[java.lang.Long] = BoxedLongEncoder - - /** @since 3.4.0 */ - implicit val newBoxedDoubleEncoder: Encoder[java.lang.Double] = BoxedDoubleEncoder - - /** @since 3.4.0 */ - implicit val newBoxedFloatEncoder: Encoder[java.lang.Float] = BoxedFloatEncoder - - /** @since 3.4.0 */ - implicit val newBoxedByteEncoder: Encoder[java.lang.Byte] = BoxedByteEncoder - - /** @since 3.4.0 */ - implicit val newBoxedShortEncoder: Encoder[java.lang.Short] = BoxedShortEncoder - - /** @since 3.4.0 */ - implicit val newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = BoxedBooleanEncoder - - // Seqs - private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = { - IterableEncoder( - classTag[Seq[E]], - elementEncoder, - elementEncoder.nullable, - elementEncoder.lenientSerialization) - } - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newLongSeqEncoder: Encoder[Seq[Long]] = 
newSeqEncoder(PrimitiveLongEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] = - newSeqEncoder(ScalaReflection.encoderFor[A]) - - /** @since 3.4.0 */ - implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] = - ScalaReflection.encoderFor[T] - - // Maps - /** @since 3.4.0 */ - implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] - - /** - * Notice that we serialize `Set` to Catalyst array. The set property is only kept when - * manipulating the domain objects. The serialization format doesn't keep the set property. When - * we have a Catalyst array which contains duplicated elements and convert it to - * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. 
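// Illustration only: a rough sketch of the Set round-trip note above. The
// Catalyst representation is an array, so duplicates collapse when the data is
// read back as a Set. Assumes `spark` with its implicits imported.
import spark.implicits._
val sets = Seq(Seq(1, 1, 2)).toDS().as[Set[Int]]
sets.head()  // Set(1, 2)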
- * - * @since 3.4.0 - */ - implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] - - // Arrays - private def newArrayEncoder[E]( - elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = { - ArrayEncoder(elementEncoder, elementEncoder.nullable) - } - - /** @since 3.4.0 */ - implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder) - - /** @since 3.4.0 */ - implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder) - - /** @since 3.4.0 */ - implicit val newDoubleArrayEncoder: Encoder[Array[Double]] = - newArrayEncoder(PrimitiveDoubleEncoder) - - /** @since 3.4.0 */ - implicit val newFloatArrayEncoder: Encoder[Array[Float]] = newArrayEncoder( - PrimitiveFloatEncoder) - - /** @since 3.4.0 */ - implicit val newByteArrayEncoder: Encoder[Array[Byte]] = BinaryEncoder - - /** @since 3.4.0 */ - implicit val newShortArrayEncoder: Encoder[Array[Short]] = newArrayEncoder( - PrimitiveShortEncoder) - - /** @since 3.4.0 */ - implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] = - newArrayEncoder(PrimitiveBooleanEncoder) - - /** @since 3.4.0 */ - implicit val newStringArrayEncoder: Encoder[Array[String]] = newArrayEncoder(StringEncoder) - - /** @since 3.4.0 */ - implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = { - newArrayEncoder(ScalaReflection.encoderFor[A]) - } - - /** - * Creates a [[Dataset]] from a local Seq. - * @since 3.4.0 - */ - implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(s)) - } -} - -/** - * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. Conflicting - * implicits are placed here to disambiguate resolution. 
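// Illustration only: a sketch of the ambiguity this low-priority trait resolves
// (see the newProductEncoder reason given just below). List is both a Seq and a
// Product, so the Seq encoder must win by implicit priority. Assumes
// `spark.implicits._` in scope.
import spark.implicits._
val lists = Seq(List(1, 2), List(3)).toDS()  // resolves newSequenceEncoder, not newProductEncoder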
- * - * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which - * are both `Seq` and `Product` - */ -trait LowPrioritySQLImplicits { - - /** @since 3.4.0 */ - implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = - ScalaReflection.encoderFor[T] +/** @inheritdoc */ +abstract class SQLImplicits private[sql] (override val session: SparkSession) + extends api.SQLImplicits { + type DS[U] = Dataset[U] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2e54617928aad..b74d0c2ff2243 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -16,35 +16,42 @@ */ package org.apache.spark.sql -import java.io.Closeable import java.net.URI +import java.nio.file.{Files, Paths} import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.{AtomicLong, AtomicReference} +import java.util.concurrent.atomic.AtomicLong import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import com.google.common.cache.{CacheBuilder, CacheLoader} import io.grpc.ClientInterceptor import org.apache.arrow.memory.RootAllocator +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder} +import org.apache.spark.sql.connect.ConnectClientUnsupportedErrors import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SessionState, SharedState, SqlApiConf} +import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} +import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.ArrayImplicits._ /** @@ -68,8 +75,7 @@ import org.apache.spark.util.ArrayImplicits._ class SparkSession private[sql] ( private[sql] val client: SparkConnectClient, private val planIdGenerator: AtomicLong) - extends Serializable - with Closeable + extends api.SparkSession with Logging { private[this] val allocator = new RootAllocator() @@ -84,50 +90,22 @@ class SparkSession private[sql] ( private[sql] 
val observationRegistry = new ConcurrentHashMap[Long, Observation]() - private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = { + private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = { client.hijackServerSideSessionIdForTesting(suffix) } - /** - * Runtime configuration interface for Spark. - * - * This is the interface through which the user can get and set all Spark configurations that - * are relevant to Spark SQL. When getting the value of a config, his defaults to the value set - * in server, if any. - * - * @since 3.4.0 - */ - val conf: RuntimeConfig = new RuntimeConfig(client) + /** @inheritdoc */ + override def sparkContext: SparkContext = + throw ConnectClientUnsupportedErrors.sparkContext() - /** - * Executes some code block and prints to stdout the time taken to execute the block. This is - * available in Scala only and is used primarily for interactive testing and debugging. - * - * @since 3.4.0 - */ - def time[T](f: => T): T = { - val start = System.nanoTime() - val ret = f - val end = System.nanoTime() - // scalastyle:off println - println(s"Time taken: ${NANOSECONDS.toMillis(end - start)} ms") - // scalastyle:on println - ret - } + /** @inheritdoc */ + val conf: RuntimeConfig = new ConnectRuntimeConfig(client) - /** - * Returns a `DataFrame` with no rows or columns. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @transient val emptyDataFrame: DataFrame = emptyDataset(UnboundRowEncoder) - /** - * Creates a new [[Dataset]] of type T containing zero elements. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def emptyDataset[T: Encoder]: Dataset[T] = createDataset[T](Nil) private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = { @@ -150,113 +128,120 @@ class SparkSession private[sql] ( } } - /** - * Creates a `DataFrame` from a local Seq of Product. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = { createDataset(ScalaReflection.encoderFor[A], data.iterator).toDF() } - /** - * :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing [[Row]]s using - * the given schema. It is important to make sure that the structure of every [[Row]] of the - * provided List matches the provided schema. Otherwise, there will be runtime exception. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { createDataset(RowEncoder.encoderFor(schema), rows.iterator().asScala).toDF() } - /** - * Applies a schema to a List of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries - * will return the columns in an undefined order. - * @since 3.4.0 - */ + /** @inheritdoc */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]]) createDataset(encoder, data.iterator().asScala).toDF() } - /** - * Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an - * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL - * representation) that is generally created automatically through implicits from a - * `SparkSession`, or can be created explicitly by calling static methods on [[Encoders]]. 
- * - * ==Example== - * - * {{{ - * - * import spark.implicits._ - * case class Person(name: String, age: Long) - * val data = Seq(Person("Michael", 29), Person("Andy", 30), Person("Justin", 19)) - * val ds = spark.createDataset(data) - * - * ds.show() - * // +-------+---+ - * // | name|age| - * // +-------+---+ - * // |Michael| 29| - * // | Andy| 30| - * // | Justin| 19| - * // +-------+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = { - createDataset(encoderFor[T], data.iterator) + createDataset(agnosticEncoderFor[T], data.iterator) } - /** - * Creates a [[Dataset]] from a `java.util.List` of a given type. This method requires an - * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL - * representation) that is generally created automatically through implicits from a - * `SparkSession`, or can be created explicitly by calling static methods on [[Encoders]]. - * - * ==Java Example== - * - * {{{ - * List data = Arrays.asList("hello", "world"); - * Dataset ds = spark.createDataset(data, Encoders.STRING()); - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala.toSeq) } - /** - * Executes a SQL query substituting positional parameters by the given arguments, returning the - * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries. - * - * @param sqlText - * A SQL statement with positional parameters to execute. - * @param args - * An array of Java/Scala objects that can be converted to SQL literal expressions. See Supported Data - * Types for supported value types in Scala/Java. For example: 1, "Steven", - * LocalDate.of(2023, 4, 2). A value can be also a `Column` of a literal or collection - * constructor functions such as `map()`, `array()`, `struct()`, in that case it is taken as - * is. 
- * - * @since 3.5.0 - */ - @Experimental - def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder => + /** @inheritdoc */ + override def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def createDataset[T: Encoder](data: RDD[T]): Dataset[T] = + throw ConnectClientUnsupportedErrors.rdd() + + /** @inheritdoc */ + override def sharedState: SharedState = + throw ConnectClientUnsupportedErrors.sharedState() + + /** @inheritdoc */ + override def sessionState: SessionState = + throw ConnectClientUnsupportedErrors.sessionState() + + /** @inheritdoc */ + override def sqlContext: SQLContext = + throw ConnectClientUnsupportedErrors.sqlContext() + + /** @inheritdoc */ + override def listenerManager: ExecutionListenerManager = + throw ConnectClientUnsupportedErrors.listenerManager() + + /** @inheritdoc */ + override def experimental: ExperimentalMethods = + throw ConnectClientUnsupportedErrors.experimental() + + /** @inheritdoc */ + override def baseRelationToDataFrame(baseRelation: BaseRelation): api.Dataset[Row] = + throw ConnectClientUnsupportedErrors.baseRelationToDataFrame() + + /** @inheritdoc */ + override def executeCommand( + runner: String, + command: String, + options: Map[String, String]): DataFrame = + throw ConnectClientUnsupportedErrors.executeCommand() + + /** @inheritdoc */ + def sql(sqlText: String, args: Array[_]): DataFrame = { + val sqlCommand = proto.SqlCommand + .newBuilder() + .setSql(sqlText) + .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava) + .build() + sql(sqlCommand) + } + + /** @inheritdoc */ + def sql(sqlText: String, args: Map[String, Any]): DataFrame = { + sql(sqlText, args.asJava) + } + + /** @inheritdoc */ + override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = { + val sqlCommand = proto.SqlCommand + .newBuilder() + .setSql(sqlText) + .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava) + .build() + sql(sqlCommand) + } + + /** @inheritdoc */ + override def sql(query: String): DataFrame = { + sql(query, Array.empty) + } + + private def sql(sqlCommand: proto.SqlCommand): DataFrame = newDataFrame { builder => // Send the SQL once to the server and then check the output. - val cmd = newCommand(b => - b.setSqlCommand( - proto.SqlCommand - .newBuilder() - .setSql(sqlText) - .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) + val cmd = newCommand(b => b.setSqlCommand(sqlCommand)) val plan = proto.Plan.newBuilder().setCommand(cmd) val responseIter = client.execute(plan.build()) @@ -272,206 +257,53 @@ class SparkSession private[sql] ( } } - /** - * Executes a SQL query substituting named parameters by the given arguments, returning the - * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries. 
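// Illustration only: a hedged usage sketch of the parameterized sql overloads
// consolidated above. Positional parameters bind to `?`, named parameters to
// `:name`. Assumes a Connect SparkSession named `spark`.
val byPosition = spark.sql("SELECT * FROM range(10) WHERE id > ? AND id < ?", Array(2, 7))
val byName     = spark.sql("SELECT * FROM range(10) WHERE id = :target", Map("target" -> 5))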
- * - * @param sqlText - * A SQL statement with named parameters to execute. - * @param args - * A map of parameter names to Java/Scala objects that can be converted to SQL literal - * expressions. See - * Supported Data Types for supported value types in Scala/Java. For example, map keys: - * "rank", "name", "birthdate"; map values: 1, "Steven", LocalDate.of(2023, 4, 2). Map value - * can be also a `Column` of a literal or collection constructor functions such as `map()`, - * `array()`, `struct()`, in that case it is taken as is. - * - * @since 3.4.0 - */ - @Experimental - def sql(sqlText: String, args: Map[String, Any]): DataFrame = { - sql(sqlText, args.asJava) - } - - /** - * Executes a SQL query substituting named parameters by the given arguments, returning the - * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries. - * - * @param sqlText - * A SQL statement with named parameters to execute. - * @param args - * A map of parameter names to Java/Scala objects that can be converted to SQL literal - * expressions. See - * Supported Data Types for supported value types in Scala/Java. For example, map keys: - * "rank", "name", "birthdate"; map values: 1, "Steven", LocalDate.of(2023, 4, 2). Map value - * can be also a `Column` of a literal or collection constructor functions such as `map()`, - * `array()`, `struct()`, in that case it is taken as is. - * - * @since 3.4.0 - */ - @Experimental - def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame { - builder => - // Send the SQL once to the server and then check the output. - val cmd = newCommand(b => - b.setSqlCommand( - proto.SqlCommand - .newBuilder() - .setSql(sqlText) - .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava))) - val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseIter = client.execute(plan.build()) - - try { - val response = responseIter - .find(_.hasSqlCommandResult) - .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) - // Update the builder with the values from the result. - builder.mergeFrom(response.getSqlCommandResult.getRelation) - } finally { - // consume the rest of the iterator - responseIter.foreach(_ => ()) - } - } - - /** - * Executes a SQL query using Spark, returning the result as a `DataFrame`. This API eagerly - * runs DDL/DML commands, but not for SELECT queries. - * - * @since 3.4.0 - */ - def sql(query: String): DataFrame = { - sql(query, Array.empty) - } - - /** - * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a - * `DataFrame`. - * {{{ - * sparkSession.read.parquet("/path/to/file.parquet") - * sparkSession.read.schema(schema).json("/path/to/file.json") - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(this) - /** - * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. 
- * {{{ - * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") - * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") - * }}} - * - * @since 3.5.0 - */ + /** @inheritdoc */ def readStream: DataStreamReader = new DataStreamReader(this) + /** @inheritdoc */ + def tvf: TableValuedFunction = new TableValuedFunction(this) + + /** @inheritdoc */ lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) - /** - * Interface through which the user may create, drop, alter or query underlying databases, - * tables, functions etc. - * - * @since 3.5.0 - */ + /** @inheritdoc */ lazy val catalog: Catalog = new CatalogImpl(this) - /** - * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch - * reading and the returned DataFrame is the batch scan query plan of this table. If it's a - * view, the returned DataFrame is simply the query plan of the view, which can either be a - * batch or streaming query plan. - * - * @param tableName - * is either a qualified or unqualified name that designates a table or view. If a database is - * specified, it identifies the table/view from the database. Otherwise, it first attempts to - * find a temporary view with the given name and then match the table/view from the current - * database. Note that, the global temporary view database is also valid here. - * @since 3.4.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { read.table(tableName) } - /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a - * range from 0 to `end` (exclusive) with step value 1. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def range(end: Long): Dataset[java.lang.Long] = range(0, end) - /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a - * range from `start` to `end` (exclusive) with step value 1. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1) } - /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a - * range from `start` to `end` (exclusive) with a step value. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, None) } - /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a - * range from `start` to `end` (exclusive) with a step value, with partition number specified. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { range(start, end, step, Option(numPartitions)) } - /** - * A collection of methods for registering user-defined functions (UDF). - * - * The following example registers a Scala closure as UDF: - * {{{ - * sparkSession.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) - * }}} - * - * The following example registers a UDF in Java: - * {{{ - * sparkSession.udf().register("myUDF", - * (Integer arg1, String arg2) -> arg2 + arg1, - * DataTypes.StringType); - * }}} - * - * @note - * The user-defined functions must be deterministic. Due to optimization, duplicate - * invocations may be eliminated or the function may even be invoked more times than it is - * present in the query. 
- * - * @since 3.5.0 - */ + /** @inheritdoc */ lazy val udf: UDFRegistration = new UDFRegistration(this) // scalastyle:off - // Disable style checker so "implicits" object can start with lowercase i - /** - * (Scala-specific) Implicit methods available in Scala for converting common names and Symbols - * into [[Column]]s, and for converting common Scala objects into DataFrame`s. - * - * {{{ - * val sparkSession = SparkSession.builder.getOrCreate() - * import sparkSession.implicits._ - * }}} - * - * @since 3.4.0 - */ - object implicits extends SQLImplicits(this) with Serializable + /** @inheritdoc */ + object implicits extends SQLImplicits(this) // scalastyle:on + /** @inheritdoc */ def newSession(): SparkSession = { SparkSession.builder().client(client.copy()).create() } @@ -574,77 +406,30 @@ class SparkSession private[sql] ( execute(command) } - /** - * Add a single artifact to the client session. - * - * Currently only local files with extensions .jar and .class are supported. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(path: String): Unit = client.addArtifact(path) + override def addArtifact(path: String): Unit = client.addArtifact(path) - /** - * Add a single artifact to the client session. - * - * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs - * - * @since 3.4.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(uri: URI): Unit = client.addArtifact(uri) + override def addArtifact(uri: URI): Unit = client.addArtifact(uri) - /** - * Add a single in-memory artifact to the session while preserving the directory structure - * specified by `target` under the session's working directory of that particular file - * extension. - * - * Supported target file extensions are .jar and .class. - * - * ==Example== - * {{{ - * addArtifact(bytesBar, "foo/bar.class") - * addArtifact(bytesFlat, "flat.class") - * // Directory structure of the session's working directory for class files would look like: - * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class - * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class - * }}} - * - * @since 4.0.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(bytes: Array[Byte], target: String): Unit = client.addArtifact(bytes, target) + override def addArtifact(bytes: Array[Byte], target: String): Unit = { + client.addArtifact(bytes, target) + } - /** - * Add a single artifact to the session while preserving the directory structure specified by - * `target` under the session's working directory of that particular file extension. - * - * Supported target file extensions are .jar and .class. - * - * ==Example== - * {{{ - * addArtifact("/Users/dummyUser/files/foo/bar.class", "foo/bar.class") - * addArtifact("/Users/dummyUser/files/flat.class", "flat.class") - * // Directory structure of the session's working directory for class files would look like: - * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class - * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class - * }}} - * - * @since 4.0.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(source: String, target: String): Unit = client.addArtifact(source, target) + override def addArtifact(source: String, target: String): Unit = { + client.addArtifact(source, target) + } - /** - * Add one or more artifacts to the session. 
- * - * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs - * - * @since 3.4.0 - */ + /** @inheritdoc */ @Experimental @scala.annotation.varargs - def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri) + override def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri) /** * Register a ClassFinder for dynamically generated classes. @@ -671,7 +456,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptAll(): Seq[String] = { + override def interruptAll(): Seq[String] = { client.interruptAll().getInterruptedIdsList.asScala.toSeq } @@ -684,7 +469,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptTag(tag: String): Seq[String] = { + override def interruptTag(tag: String): Seq[String] = { client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq } @@ -697,17 +482,10 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptOperation(operationId: String): Seq[String] = { + override def interruptOperation(operationId: String): Seq[String] = { client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } - /** - * Synonym for `close()`. - * - * @since 3.4.0 - */ - def stop(): Unit = close() - /** * Close the [[SparkSession]]. * @@ -735,65 +513,17 @@ class SparkSession private[sql] ( SparkSession.onSessionClose(this) } - /** - * Add a tag to be assigned to all the operations started by this thread in this session. - * - * Often, a unit of execution in an application consists of multiple Spark executions. - * Application programmers can use this method to group all those jobs together and give a group - * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running running executions with this tag. For example: - * {{{ - * // In the main thread: - * spark.addTag("myjobs") - * spark.range(10).map(i => { Thread.sleep(10); i }).collect() - * - * // In a separate thread: - * spark.interruptTag("myjobs") - * }}} - * - * There may be multiple tags present at the same time, so different parts of application may - * use different tags to perform cancellation at different levels of granularity. - * - * @param tag - * The tag to be added. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def addTag(tag: String): Unit = { - client.addTag(tag) - } + /** @inheritdoc */ + override def addTag(tag: String): Unit = client.addTag(tag) - /** - * Remove a tag previously added to be assigned to all the operations started by this thread in - * this session. Noop if such a tag was not added earlier. - * - * @param tag - * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def removeTag(tag: String): Unit = { - client.removeTag(tag) - } + /** @inheritdoc */ + override def removeTag(tag: String): Unit = client.removeTag(tag) - /** - * Get the tags that are currently set to be assigned to all the operations started by this - * thread. - * - * @since 3.5.0 - */ - def getTags(): Set[String] = { - client.getTags() - } + /** @inheritdoc */ + override def getTags(): Set[String] = client.getTags() - /** - * Clear the current thread's operation tags. - * - * @since 3.5.0 - */ - def clearTags(): Unit = { - client.clearTags() - } + /** @inheritdoc */ + override def clearTags(): Unit = client.clearTags() /** * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. 
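The interrupt methods keep their behaviour and only gain `override`, while the tag methods collapse to `/** @inheritdoc */` one-liners. A sketch of the cancellation flow they enable, assuming a session `spark` shared between two threads and an illustrative tag name:
{{{
  // In the thread that submits work:
  spark.addTag("nightly-etl")
  try {
    spark.range(1000000L).count()
  } finally {
    spark.clearTags()
  }

  // In another thread: cancel everything tagged above and get the interrupted operation ids.
  val interrupted: Seq[String] = spark.interruptTag("nightly-etl")
}}}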
@@ -807,26 +537,36 @@ class SparkSession private[sql] ( private[sql] var releaseSessionOnClose = true private[sql] def registerObservation(planId: Long, observation: Observation): Unit = { - if (observationRegistry.putIfAbsent(planId, observation) != null) { - throw new IllegalArgumentException("An Observation can be used with a Dataset only once") - } + observation.markRegistered() + observationRegistry.putIfAbsent(planId, observation) } - private[sql] def setMetricsAndUnregisterObservation( - planId: Long, - metrics: Map[String, Any]): Unit = { + private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = { val observationOrNull = observationRegistry.remove(planId) if (observationOrNull != null) { - observationOrNull.setMetricsAndNotify(Some(metrics)) + observationOrNull.setMetricsAndNotify(metrics) } } + + override private[sql] def isUsable: Boolean = client.isSessionValid + + implicit class RichColumn(c: Column) { + def expr: proto.Expression = toExpr(c) + def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e) + } } // The minimal builder needed to create a spark session. // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]] -object SparkSession extends Logging { +object SparkSession extends api.BaseSparkSessionCompanion with Logging { + override private[sql] type Session = SparkSession + private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong + private var server: Option[Process] = None + private[sql] val sparkOptions = sys.props.filter { p => + p._1.startsWith("spark.") && p._2.nonEmpty + }.toMap private val sessions = CacheBuilder .newBuilder() @@ -836,27 +576,49 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) - /** The active SparkSession for the current thread. */ - private val activeThreadSession = new InheritableThreadLocal[SparkSession] + /** + * Create a new Spark Connect server to connect locally. + */ + private[sql] def withLocalConnectServer[T](f: => T): T = { + synchronized { + val remoteString = sparkOptions + .get("spark.remote") + .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit + .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) + + val maybeConnectScript = + Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh")) + + if (server.isEmpty && + remoteString.exists(_.startsWith("local")) && + maybeConnectScript.exists(Files.exists(_))) { + server = Some { + val args = + Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions + .filter(p => !p._1.startsWith("spark.remote")) + .flatMap { case (k, v) => Seq("--conf", s"$k=$v") } + val pb = new ProcessBuilder(args: _*) + // So don't exclude spark-sql jar in classpath + pb.environment().remove(SparkConnectClient.SPARK_REMOTE) + pb.start() + } - /** Reference to the root SparkSession. */ - private val defaultSession = new AtomicReference[SparkSession] + // Let the server start. We will directly request to set the configurations + // and this sleep makes less noisy with retries. + Thread.sleep(2000L) + System.setProperty("spark.remote", "sc://localhost") - /** - * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when - * they are not set yet or the associated [[SparkConnectClient]] is unusable. 
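`setMetricsAndUnregisterObservation` now hands the observed metrics to the `Observation` as a `Row` instead of a pre-built map. A rough sketch of the user-facing flow this backs, assuming a session `spark`:
{{{
  import org.apache.spark.sql.Observation
  import org.apache.spark.sql.functions.{count, max}

  val obs = Observation("stats")
  val df  = spark.range(100).observe(obs, count("id").as("rows"), max("id").as("max_id"))
  df.collect()                             // execution publishes the observed metrics
  val metrics: Map[String, Any] = obs.get  // blocks until the metrics Row has arrived
  val rowCount = metrics("rows")
}}}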
- */ - private def setDefaultAndActiveSession(session: SparkSession): Unit = { - val currentDefault = defaultSession.getAcquire - if (currentDefault == null || !currentDefault.client.isSessionValid) { - // Update `defaultSession` if it is null or the contained session is not valid. There is a - // chance that the following `compareAndSet` fails if a new default session has just been set, - // but that does not matter since that event has happened after this method was invoked. - defaultSession.compareAndSet(currentDefault, session) - } - if (getActiveSession.isEmpty) { - setActiveSession(session) + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = if (server.isDefined) { + new ProcessBuilder(maybeConnectScript.get.toString) + .start() + } + }) + // scalastyle:on runtimeaddshutdownhook + } } + f } /** @@ -866,17 +628,6 @@ object SparkSession extends Logging { new SparkSession(configuration.toSparkConnectClient, planIdGenerator) } - /** - * Hook called when a session is closed. - */ - private[sql] def onSessionClose(session: SparkSession): Unit = { - sessions.invalidate(session.client.configuration) - defaultSession.compareAndSet(session, null) - if (getActiveSession.contains(session)) { - clearActiveSession() - } - } - /** * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. * @@ -884,15 +635,15 @@ object SparkSession extends Logging { */ def builder(): Builder = new Builder() - class Builder() extends Logging { + class Builder() extends api.SparkSessionBuilder { // Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE // by default, if it exists. The connection string can be overridden using // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable. private val builder = SparkConnectClient.builder().loadFromEnvironment() private var client: SparkConnectClient = _ - private[this] val options = new scala.collection.mutable.HashMap[String, String] - def remote(connectionString: String): Builder = { + /** @inheritdoc */ + def remote(connectionString: String): this.type = { builder.connectionString(connectionString) this } @@ -904,93 +655,52 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def interceptor(interceptor: ClientInterceptor): Builder = { + def interceptor(interceptor: ClientInterceptor): this.type = { builder.interceptor(interceptor) this } - private[sql] def client(client: SparkConnectClient): Builder = { + private[sql] def client(client: SparkConnectClient): this.type = { this.client = client this } - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: String): Builder = synchronized { - options += key -> value - this - } + /** @inheritdoc */ + override def config(key: String, value: String): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Long): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Long): this.type = super.config(key, value) - /** - * Sets a config option. 
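In the builder, `remote`, `interceptor`, `client` and the `config` overloads now return `this.type`, so chaining preserves the concrete builder type, and most of the option bookkeeping moves to the shared api.SparkSessionBuilder. A minimal sketch with an illustrative endpoint:
{{{
  val spark = SparkSession
    .builder()
    .remote("sc://localhost:15002")                // Spark Connect endpoint
    .config("spark.sql.shuffle.partitions", 4L)    // runtime options forwarded to the session
    .config("spark.sql.ansi.enabled", true)
    .getOrCreate()
}}}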
Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Double): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Double): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Boolean): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Boolean): this.type = super.config(key, value) - /** - * Sets a config a map of options. Options set using this method are automatically propagated - * to the Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(map: Map[String, Any]): Builder = synchronized { - map.foreach { kv: (String, Any) => - { - options += kv._1 -> kv._2.toString - } - } - this - } + /** @inheritdoc */ + override def config(map: Map[String, Any]): this.type = super.config(map) - /** - * Sets a config option. Options set using this method are automatically propagated to both - * `SparkConf` and SparkSession's own configuration. - * - * @since 3.5.0 - */ - def config(map: java.util.Map[String, Any]): Builder = synchronized { - config(map.asScala.toMap) - } + /** @inheritdoc */ + override def config(map: java.util.Map[String, Any]): this.type = super.config(map) + /** @inheritdoc */ + override def config(conf: SparkConf): Builder.this.type = super.config(conf) + + /** @inheritdoc */ @deprecated("enableHiveSupport does not work in Spark Connect") - def enableHiveSupport(): Builder = this + override def enableHiveSupport(): this.type = this + /** @inheritdoc */ @deprecated("master does not work in Spark Connect, please use remote instead") - def master(master: String): Builder = this + override def master(master: String): this.type = this + /** @inheritdoc */ @deprecated("appName does not work in Spark Connect") - def appName(name: String): Builder = this + override def appName(name: String): this.type = this + + /** @inheritdoc */ + @deprecated("withExtensions does not work in Spark Connect") + override def withExtensions(f: SparkSessionExtensions => Unit): this.type = this private def tryCreateSessionFromClient(): Option[SparkSession] = { if (client != null && client.isSessionValid) { @@ -1001,6 +711,16 @@ object SparkSession extends Logging { } private def applyOptions(session: SparkSession): Unit = { + // Only attempts to set Spark SQL configurations. + // If the configurations are static, it might throw an exception so + // simply ignore it for now. 
+ sparkOptions + .filter { case (k, _) => + k.startsWith("spark.sql.") + } + .foreach { case (key, value) => + Try(session.conf.set(key, value)) + } options.foreach { case (key, value) => session.conf.set(key, value) } @@ -1023,7 +743,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def create(): SparkSession = { + def create(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse(SparkSession.this.create(builder.configuration)) setDefaultAndActiveSession(session) @@ -1043,7 +763,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def getOrCreate(): SparkSession = { + def getOrCreate(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse({ var existingSession = sessions.get(builder.configuration) @@ -1061,71 +781,12 @@ object SparkSession extends Logging { } } - /** - * Returns the default SparkSession. If the previously set default SparkSession becomes - * unusable, returns None. - * - * @since 3.5.0 - */ - def getDefaultSession: Option[SparkSession] = - Option(defaultSession.get()).filter(_.client.isSessionValid) + /** @inheritdoc */ + override def getActiveSession: Option[SparkSession] = super.getActiveSession - /** - * Sets the default SparkSession. - * - * @since 3.5.0 - */ - def setDefaultSession(session: SparkSession): Unit = { - defaultSession.set(session) - } + /** @inheritdoc */ + override def getDefaultSession: Option[SparkSession] = super.getDefaultSession - /** - * Clears the default SparkSession. - * - * @since 3.5.0 - */ - def clearDefaultSession(): Unit = { - defaultSession.set(null) - } - - /** - * Returns the active SparkSession for the current thread. If the previously set active - * SparkSession becomes unusable, returns None. - * - * @since 3.5.0 - */ - def getActiveSession: Option[SparkSession] = - Option(activeThreadSession.get()).filter(_.client.isSessionValid) - - /** - * Changes the SparkSession that will be returned in this thread and its children when - * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives - * an isolated SparkSession. - * - * @since 3.5.0 - */ - def setActiveSession(session: SparkSession): Unit = { - activeThreadSession.set(session) - } - - /** - * Clears the active SparkSession for current thread. - * - * @since 3.5.0 - */ - def clearActiveSession(): Unit = { - activeThreadSession.remove() - } - - /** - * Returns the currently active SparkSession, otherwise the default one. If there is no default - * SparkSession, throws an exception. - * - * @since 3.5.0 - */ - def active: SparkSession = { - getActiveSession - .orElse(getDefaultSession) - .getOrElse(throw new IllegalStateException("No active or default Spark session found")) - } + /** @inheritdoc */ + override def active: SparkSession = super.active } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala new file mode 100644 index 0000000000000..4f2687b537862 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr + +class TableValuedFunction(sparkSession: SparkSession) extends api.TableValuedFunction { + + /** @inheritdoc */ + override def range(end: Long): Dataset[java.lang.Long] = { + sparkSession.range(end) + } + + /** @inheritdoc */ + override def range(start: Long, end: Long): Dataset[java.lang.Long] = { + sparkSession.range(start, end) + } + + /** @inheritdoc */ + override def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { + sparkSession.range(start, end, step) + } + + /** @inheritdoc */ + override def range( + start: Long, + end: Long, + step: Long, + numPartitions: Int): Dataset[java.lang.Long] = { + sparkSession.range(start, end, step, numPartitions) + } + + private def fn(name: String, args: Seq[Column]): Dataset[Row] = { + sparkSession.newDataFrame { builder => + builder.getUnresolvedTableValuedFunctionBuilder + .setFunctionName(name) + .addAllArguments(args.map(toExpr).asJava) + } + } + + /** @inheritdoc */ + override def explode(collection: Column): Dataset[Row] = + fn("explode", Seq(collection)) + + /** @inheritdoc */ + override def explode_outer(collection: Column): Dataset[Row] = + fn("explode_outer", Seq(collection)) + + /** @inheritdoc */ + override def inline(input: Column): Dataset[Row] = + fn("inline", Seq(input)) + + /** @inheritdoc */ + override def inline_outer(input: Column): Dataset[Row] = + fn("inline_outer", Seq(input)) + + /** @inheritdoc */ + override def json_tuple(input: Column, fields: Column*): Dataset[Row] = + fn("json_tuple", input +: fields) + + /** @inheritdoc */ + override def posexplode(collection: Column): Dataset[Row] = + fn("posexplode", Seq(collection)) + + /** @inheritdoc */ + override def posexplode_outer(collection: Column): Dataset[Row] = + fn("posexplode_outer", Seq(collection)) + + /** @inheritdoc */ + override def stack(n: Column, fields: Column*): Dataset[Row] = + fn("stack", n +: fields) + + /** @inheritdoc */ + override def collations(): Dataset[Row] = + fn("collations", Seq.empty) + + /** @inheritdoc */ + override def sql_keywords(): Dataset[Row] = + fn("sql_keywords", Seq.empty) + + /** @inheritdoc */ + override def variant_explode(input: Column): Dataset[Row] = + fn("variant_explode", Seq(input)) + + /** @inheritdoc */ + override def variant_explode_outer(input: Column): Dataset[Row] = + fn("variant_explode_outer", Seq(input)) +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 5965a2b7a61de..93d085a25c7b5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql -import 
scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.api.java._ -import org.apache.spark.sql.connect.common.UdfUtils -import org.apache.spark.sql.expressions.{ScalaUserDefinedFunction, UserDefinedFunction} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.internal.UdfToProtoUtils import org.apache.spark.sql.types.DataType /** @@ -34,1261 +30,19 @@ import org.apache.spark.sql.types.DataType * * @since 3.5.0 */ -class UDFRegistration(session: SparkSession) extends Logging { - - /** - * Registers a user-defined function (UDF), for a UDF that's already defined using the Dataset - * API (i.e. of type UserDefinedFunction). To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. To change a UDF to nonNullable, call the API - * `UserDefinedFunction.asNonNullable()`. - * - * Example: - * {{{ - * val foo = udf(() => Math.random()) - * spark.udf.register("random", foo.asNondeterministic()) - * - * val bar = udf(() => "bar") - * spark.udf.register("stringLit", bar.asNonNullable()) - * }}} - * - * @param name - * the name of the UDF. - * @param udf - * the UDF needs to be registered. - * @return - * the registered UDF. - * - * @since 3.5.0 - */ - def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { - udf.withName(name) match { - case scalarUdf: ScalaUserDefinedFunction => - session.registerUdf(scalarUdf.toProto) - scalarUdf - case other => - throw new UnsupportedOperationException( - s"Registering a UDF of type " + - s"${other.getClass.getSimpleName} is currently unsupported.") - } - } - - // scalastyle:off line.size.limit - - /* register 0-22 were generated by this script: - (0 to 22).foreach { x => - val params = (1 to x).map(num => s"A$num").mkString(", ") - val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - println(s""" - |/** - | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). - | * @tparam RT return type of UDF. - | * @since 3.5.0 - | */ - |def register[$typeTags](name: String, func: ($params) => RT): UserDefinedFunction = { - | register(name, functions.udf(func)) - |}""".stripMargin) - } - */ - - /** - * Registers a deterministic Scala closure of 0 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction(func, typeTag[RT]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 1 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[RT: TypeTag, A1: TypeTag](name: String, func: (A1) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction(func, typeTag[RT], typeTag[A1]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 2 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag]( - name: String, - func: (A1, A2) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction(func, typeTag[RT], typeTag[A1], typeTag[A2]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 3 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. 
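The typed `register[RT, A1, ...]` overloads removed from this file are not dropped from the API; they now come from the shared api.UDFRegistration that the rewritten class extends (see the new protected `register` hook further down), so call sites stay unchanged. A sketch, assuming a session `spark`:
{{{
  spark.udf.register("plusOne", (i: Int) => i + 1)
  spark.udf.register("fullName", (first: String, last: String) => s"$first $last")
  spark.sql("SELECT plusOne(41), fullName('Grace', 'Hopper')").show()
}}}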
- * @since 3.5.0 - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag]( - name: String, - func: (A1, A2, A3) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction(func, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 4 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag]( - name: String, - func: (A1, A2, A3, A4) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 5 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 6 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 7 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 8 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 9 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. 
- * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 10 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 11 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 12 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 13 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => RT) - : UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 14 arguments as user-defined function (UDF). 
- * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14) => RT) - : UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 15 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15) => RT) - : UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 16 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16) => RT) - : UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 17 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag, - A17: TypeTag]( - name: String, - func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17) => RT) - : UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16], - typeTag[A17]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 18 arguments as user-defined function (UDF). 
- * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag, - A17: TypeTag, - A18: TypeTag]( - name: String, - func: ( - A1, - A2, - A3, - A4, - A5, - A6, - A7, - A8, - A9, - A10, - A11, - A12, - A13, - A14, - A15, - A16, - A17, - A18) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16], - typeTag[A17], - typeTag[A18]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 19 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag, - A17: TypeTag, - A18: TypeTag, - A19: TypeTag]( - name: String, - func: ( - A1, - A2, - A3, - A4, - A5, - A6, - A7, - A8, - A9, - A10, - A11, - A12, - A13, - A14, - A15, - A16, - A17, - A18, - A19) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16], - typeTag[A17], - typeTag[A18], - typeTag[A19]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 20 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag, - A17: TypeTag, - A18: TypeTag, - A19: TypeTag, - A20: TypeTag]( - name: String, - func: ( - A1, - A2, - A3, - A4, - A5, - A6, - A7, - A8, - A9, - A10, - A11, - A12, - A13, - A14, - A15, - A16, - A17, - A18, - A19, - A20) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16], - typeTag[A17], - typeTag[A18], - typeTag[A19], - typeTag[A20]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 21 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. 
- * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag, - A17: TypeTag, - A18: TypeTag, - A19: TypeTag, - A20: TypeTag, - A21: TypeTag]( - name: String, - func: ( - A1, - A2, - A3, - A4, - A5, - A6, - A7, - A8, - A9, - A10, - A11, - A12, - A13, - A14, - A15, - A16, - A17, - A18, - A19, - A20, - A21) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16], - typeTag[A17], - typeTag[A18], - typeTag[A19], - typeTag[A20], - typeTag[A21]) - register(name, udf) - } - - /** - * Registers a deterministic Scala closure of 22 arguments as user-defined function (UDF). - * @tparam RT - * return type of UDF. - * @since 3.5.0 - */ - def register[ - RT: TypeTag, - A1: TypeTag, - A2: TypeTag, - A3: TypeTag, - A4: TypeTag, - A5: TypeTag, - A6: TypeTag, - A7: TypeTag, - A8: TypeTag, - A9: TypeTag, - A10: TypeTag, - A11: TypeTag, - A12: TypeTag, - A13: TypeTag, - A14: TypeTag, - A15: TypeTag, - A16: TypeTag, - A17: TypeTag, - A18: TypeTag, - A19: TypeTag, - A20: TypeTag, - A21: TypeTag, - A22: TypeTag]( - name: String, - func: ( - A1, - A2, - A3, - A4, - A5, - A6, - A7, - A8, - A9, - A10, - A11, - A12, - A13, - A14, - A15, - A16, - A17, - A18, - A19, - A20, - A21, - A22) => RT): UserDefinedFunction = { - val udf = ScalaUserDefinedFunction( - func, - typeTag[RT], - typeTag[A1], - typeTag[A2], - typeTag[A3], - typeTag[A4], - typeTag[A5], - typeTag[A6], - typeTag[A7], - typeTag[A8], - typeTag[A9], - typeTag[A10], - typeTag[A11], - typeTag[A12], - typeTag[A13], - typeTag[A14], - typeTag[A15], - typeTag[A16], - typeTag[A17], - typeTag[A18], - typeTag[A19], - typeTag[A20], - typeTag[A21], - typeTag[A22]) - register(name, udf) - } - - // (0 to 22).foreach { i => - // val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") - // val version = "3.5.0" - // println(s""" - // |/** - // | * Register a deterministic Java UDF$i instance as user-defined function (UDF). - // | * @since $version - // | */ - // |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { - // | val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - // | register(name, udf) - // |}""".stripMargin) - // } - - /** - * Register a deterministic Java UDF0 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF0[_], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF1 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF2 instance as user-defined function (UDF). 
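The Java-flavoured overloads follow the same pattern: the connect-local copies of `register(name, UDFn, returnType)` are removed here and the inherited variants are expected to take over. A sketch with a two-argument Java UDF; the function name is illustrative:
{{{
  import org.apache.spark.sql.api.java.UDF2
  import org.apache.spark.sql.types.IntegerType

  spark.udf.register(
    "addInts",
    new UDF2[Integer, Integer, Integer] {
      override def call(a: Integer, b: Integer): Integer = Integer.valueOf(a + b)
    },
    IntegerType)
}}}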
- * @since 3.5.0 - */ - def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF3 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF4 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF5 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF6 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF7 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF8 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF9 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF9[_, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF10 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF10[_, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF11 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF12 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF13 instance as user-defined function (UDF). 
- * @since 3.5.0 - */ - def register( - name: String, - f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF14 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF15 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF16 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF17 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF18 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF19 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF20 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) - } - - /** - * Register a deterministic Java UDF21 instance as user-defined function (UDF). - * @since 3.5.0 - */ - def register( - name: String, - f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) +class UDFRegistration(session: SparkSession) extends api.UDFRegistration { + override def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + throw new UnsupportedOperationException( + "registerJava is currently not supported in Spark Connect.") } - /** - * Register a deterministic Java UDF22 instance as user-defined function (UDF). 
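The new implementation funnels every public overload through a single protected `register` hook (shown just below) and converts the named UDF to its proto form via UdfToProtoUtils; only `registerJava` is rejected on Connect. A sketch of the plain `UserDefinedFunction` path, assuming a session `spark`:
{{{
  import org.apache.spark.sql.functions.udf

  val strLen = udf((s: String) => s.length)
  spark.udf.register("strLen", strLen.asNonNullable())
  spark.sql("SELECT strLen('hello')").show()

  // Not supported from a Connect client (throws UnsupportedOperationException):
  // spark.udf.registerJava("myJavaUdf", "com.example.MyUdf", IntegerType)
}}}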
- * @since 3.5.0 - */ - def register( + override protected def register( name: String, - f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], - returnType: DataType): Unit = { - val udf = ScalaUserDefinedFunction(UdfUtils.wrap(f), returnType) - register(name, udf) + udf: UserDefinedFunction, + source: String, + validateParameterCount: Boolean): UserDefinedFunction = { + val named = udf.withName(name) + session.registerUdf(UdfToProtoUtils.toProto(named)) + named } - // scalastyle:on line.size.limit } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index 9fd3ae4368f4c..bff6db25a21f2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -23,11 +23,16 @@ import scala.util.control.NonFatal import ammonite.compiler.CodeClassWrapper import ammonite.compiler.iface.CodeWrapper -import ammonite.util.{Bind, Imports, Name, Util} +import ammonite.interp.{Interpreter, Watchable} +import ammonite.main.Defaults +import ammonite.repl.Repl +import ammonite.util.{Bind, Imports, Name, PredefInfo, Ref, Res, Util} +import ammonite.util.Util.newLine +import org.apache.spark.SparkBuildInfo.spark_version import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SparkSession.withLocalConnectServer import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser} /** @@ -37,29 +42,29 @@ import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClie object ConnectRepl { private val name = "Spark Connect REPL" - private val splash = - """ - |Spark session available as 'spark'. - | _____ __ ______ __ - | / ___/____ ____ ______/ /__ / ____/___ ____ ____ ___ _____/ /_ - | \__ \/ __ \/ __ `/ ___/ //_/ / / / __ \/ __ \/ __ \/ _ \/ ___/ __/ - | ___/ / /_/ / /_/ / / / ,< / /___/ /_/ / / / / / / / __/ /__/ /_ - |/____/ .___/\__,_/_/ /_/|_| \____/\____/_/ /_/_/ /_/\___/\___/\__/ - | /_/ - |""".stripMargin + private val splash: String = """Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + +Type in expressions to have them evaluated. +Spark connect server version %s. +Spark session available as 'spark'. + """ def main(args: Array[String]): Unit = doMain(args) + private var server: Option[Process] = None + private val sparkHome = System.getenv("SPARK_HOME") + private[application] def doMain( args: Array[String], semaphore: Option[Semaphore] = None, inputStream: InputStream = System.in, outputStream: OutputStream = System.out, - errorStream: OutputStream = System.err): Unit = { - // For interpreters, structured logging is disabled by default to avoid generating mixed - // plain text and structured logs on the same console. - Logging.disableStructuredLogging() - + errorStream: OutputStream = System.err): Unit = withLocalConnectServer { // Build the client. val client = try { @@ -97,14 +102,65 @@ object ConnectRepl { |""".stripMargin // Please note that we make ammonite generate classes instead of objects. // Classes tend to have superior serialization behavior when using UDFs. 
- val main = ammonite.Main( - welcomeBanner = Option(splash), + val main = new ammonite.Main( + welcomeBanner = Option(splash.format(spark_version, spark.version)), predefCode = predefCode, replCodeWrapper = ExtendedCodeClassWrapper, scriptCodeWrapper = ExtendedCodeClassWrapper, inputStream = inputStream, outputStream = outputStream, - errorStream = errorStream) + errorStream = errorStream) { + + override def instantiateRepl(replArgs: IndexedSeq[Bind[_]] = Vector.empty) + : Either[(Res.Failure, Seq[(Watchable.Path, Long)]), Repl] = { + loadedPredefFile.map { predefFileInfoOpt => + val augmentedImports = + if (defaultPredef) Defaults.replImports ++ Interpreter.predefImports + else Imports() + + val argString = replArgs.zipWithIndex + .map { case (b, idx) => + s""" + val ${b.name} = ammonite + .repl + .ReplBridge + .value + .Internal + .replArgs($idx) + .value + .asInstanceOf[${b.typeName.value}] + """ + } + .mkString(newLine) + + new Repl( + this.inputStream, + this.outputStream, + this.errorStream, + storage = storageBackend, + baseImports = augmentedImports, + basePredefs = Seq(PredefInfo(Name("ArgsPredef"), argString, false, None)), + customPredefs = predefFileInfoOpt.toSeq ++ Seq( + PredefInfo(Name("CodePredef"), this.predefCode, false, Some(wd / "(console)"))), + wd = wd, + welcomeBanner = welcomeBanner, + replArgs = replArgs, + initialColors = colors, + replCodeWrapper = replCodeWrapper, + scriptCodeWrapper = scriptCodeWrapper, + alreadyLoadedDependencies = alreadyLoadedDependencies, + importHooks = importHooks, + compilerBuilder = compilerBuilder, + parser = parser(), + initialClassLoader = initialClassLoader, + classPathWhitelist = classPathWhitelist, + warnings = warnings) { + override val prompt = Ref("scala> ") + } + } + } + } + if (semaphore.nonEmpty) { // Used for testing. main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index cf0fef147ee84..86b1dbe4754e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -17,660 +17,152 @@ package org.apache.spark.sql.catalog -import scala.jdk.CollectionConverters._ +import java.util -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} +import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.types.StructType -import org.apache.spark.storage.StorageLevel -/** - * Catalog interface for Spark. To access this, use `SparkSession.catalog`. - * - * @since 3.5.0 - */ -abstract class Catalog { - - /** - * Returns the current database (namespace) in this session. - * - * @since 3.5.0 - */ - def currentDatabase: String - - /** - * Sets the current database (namespace) in this session. - * - * @since 3.5.0 - */ - def setCurrentDatabase(dbName: String): Unit - - /** - * Returns a list of databases (namespaces) available within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(): Dataset[Database] - - /** - * Returns a list of databases (namespaces) which name match the specify pattern and available - * within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(pattern: String): Dataset[Database] - - /** - * Returns a list of tables/views in the current database (namespace). 
This includes all - * temporary views. - * - * @since 3.5.0 - */ - def listTables(): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) (the name can be - * qualified with catalog). This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) which name match the - * specify pattern (the name can be qualified with catalog). This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String, pattern: String): Dataset[Table] - - /** - * Returns a list of functions registered in the current database (namespace). This includes all - * temporary functions. - * - * @since 3.5.0 - */ - def listFunctions(): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) (the name can be - * qualified with catalog). This includes all built-in and temporary functions. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) which name match - * the specify pattern (the name can be qualified with catalog). This includes all built-in and - * temporary functions. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String, pattern: String): Dataset[Function] - - /** - * Returns a list of columns for the given table/view or temporary view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("table does not exist") - def listColumns(tableName: String): Dataset[Column] - - /** - * Returns a list of columns for the given table/view in the specified database under the Hive - * Metastore. - * - * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with - * qualified table/view name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param tableName - * is an unqualified name that designates a table/view. - * @since 3.5.0 - */ - @throws[AnalysisException]("database or table does not exist") - def listColumns(dbName: String, tableName: String): Dataset[Column] - - /** - * Get the database (namespace) with the specified name (can be qualified with catalog). This - * throws an AnalysisException when the database (namespace) cannot be found. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def getDatabase(dbName: String): Database - - /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view. This throws an AnalysisException when no Table can be found. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). 
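These catalog methods keep their semantics; only their documentation moves to the shared api.Catalog. A short sketch of the listing surface, assuming a session `spark` and a database `sales` that exist on the server (both illustrative):
{{{
  spark.catalog.setCurrentDatabase("sales")
  val dbs    = spark.catalog.listDatabases()              // Dataset[Database]
  val tables = spark.catalog.listTables("sales", "ord*")  // pattern-filtered Dataset[Table]
  val cols   = spark.catalog.listColumns("sales", "orders")
  cols.show()
}}}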
- * @since 3.5.0 - */ - @throws[AnalysisException]("table does not exist") - def getTable(tableName: String): Table - - /** - * Get the table or view with the specified name in the specified database under the Hive - * Metastore. This throws an AnalysisException when no Table can be found. - * - * To get table/view in other catalogs, please use `getTable(tableName)` with qualified - * table/view name instead. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database or table does not exist") - def getTable(dbName: String, tableName: String): Table - - /** - * Get the function with the specified name. This function can be a temporary function or a - * function. This throws an AnalysisException when the function cannot be found. - * - * @param functionName - * is either a qualified or unqualified name that designates a function. It follows the same - * resolution rule with SQL: search for built-in/temp functions first then functions in the - * current database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("function does not exist") - def getFunction(functionName: String): Function - - /** - * Get the function with the specified name in the specified database under the Hive Metastore. - * This throws an AnalysisException when the function cannot be found. - * - * To get functions in other catalogs, please use `getFunction(functionName)` with qualified - * function name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param functionName - * is an unqualified name that designates a function in the specified database - * @since 3.5.0 - */ - @throws[AnalysisException]("database or function does not exist") - def getFunction(dbName: String, functionName: String): Function - - /** - * Check if the database (namespace) with the specified name exists (the name can be qualified - * with catalog). - * - * @since 3.5.0 - */ - def databaseExists(dbName: String): Boolean - - /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - def tableExists(tableName: String): Boolean - - /** - * Check if the table or view with the specified name exists in the specified database under the - * Hive Metastore. - * - * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with - * qualified table/view name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param tableName - * is an unqualified name that designates a table. - * @since 3.5.0 - */ - def tableExists(dbName: String, tableName: String): Boolean - - /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function. - * - * @param functionName - * is either a qualified or unqualified name that designates a function. It follows the same - * resolution rule with SQL: search for built-in/temp functions first then functions in the - * current database (namespace). - * @since 3.5.0 - */ - def functionExists(functionName: String): Boolean - - /** - * Check if the function with the specified name exists in the specified database under the Hive - * Metastore. 
- * - * To check existence of functions in other catalogs, please use `functionExists(functionName)` - * with qualified function name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param functionName - * is an unqualified name that designates a function. - * @since 3.5.0 - */ - def functionExists(dbName: String, functionName: String): Boolean - - /** - * Creates a table from the given path and returns the corresponding DataFrame. It will use the - * default data source configured by spark.sql.sources.default. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String): DataFrame = { - createTable(tableName, path) - } - - /** - * Creates a table from the given path and returns the corresponding DataFrame. It will use the - * default data source configured by spark.sql.sources.default. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, path: String): DataFrame - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String, source: String): DataFrame = { - createTable(tableName, path, source) - } - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, path: String, source: String): DataFrame - - /** - * Creates a table from the given path based on a data source and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( +/** @inheritdoc */ +abstract class Catalog extends api.Catalog { + + /** @inheritdoc */ + override def listDatabases(): Dataset[Database] + + /** @inheritdoc */ + override def listDatabases(pattern: String): Dataset[Database] + + /** @inheritdoc */ + override def listTables(): Dataset[Table] + + /** @inheritdoc */ + override def listTables(dbName: String): Dataset[Table] + + /** @inheritdoc */ + override def listTables(dbName: String, pattern: String): Dataset[Table] + + /** @inheritdoc */ + override def listFunctions(): Dataset[Function] + + /** @inheritdoc */ + override def listFunctions(dbName: String): Dataset[Function] + + /** @inheritdoc */ + override def listFunctions(dbName: String, pattern: String): Dataset[Function] + + /** @inheritdoc */ + override def listColumns(tableName: String): Dataset[Column] + + /** @inheritdoc */ + override def listColumns(dbName: String, tableName: String): Dataset[Column] + + /** @inheritdoc */ + override def createTable(tableName: String, path: String): DataFrame + + /** @inheritdoc */ + override def createTable(tableName: String, path: String, source: String): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } - - /** - * Creates a table based on the dataset in a data source and a set of options. Then, returns the - * corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options.asScala.toMap) - } - - /** - * (Scala-specific) Creates a table from the given path based on a data source and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + description: String, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } - - /** - * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, source: String, options: Map[String, String]): DataFrame - - /** - * Create a table from the given path based on a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + schema: StructType, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } - - /** - * Creates a table based on the dataset in a data source and a set of options. Then, returns the - * corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + description: String, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def listCatalogs(): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def listCatalogs(pattern: String): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String): DataFrame = + super.createExternalTable(tableName, path) + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String, source: String): DataFrame = + super.createExternalTable(tableName, path, source) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - description = description, - options = options.asScala.toMap) - } - - /** - * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - description: String, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options.asScala.toMap) - } - - /** - * (Scala-specific) Create a table from the given path based on a data source, a schema and a - * set of options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } - - /** - * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) + + /** @inheritdoc */ + override def createTable( + tableName: String, + source: String, + description: String, + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, description, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - schema = schema, - description = description, - options = options.asScala.toMap) - } - - /** - * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, description: String, - options: Map[String, String]): DataFrame - - /** - * Drops the local temporary view with the given view name in the catalog. If the view has been - * cached before, then it will also be uncached. - * - * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that - * created it, i.e. it will be automatically dropped when the session terminates. It's not tied - * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. - * - * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean in - * Spark 2.1. - * - * @param viewName - * the name of the temporary view to be dropped. - * @return - * true if the view is dropped successfully, false otherwise. 
- * @since 3.5.0 - */ - def dropTempView(viewName: String): Boolean - - /** - * Drops the global temporary view with the given view name in the catalog. If the view has been - * cached before, then it will also be uncached. - * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark - * application, i.e. it will be automatically dropped when the application terminates. It's tied - * to a system preserved database `global_temp`, and we must use the qualified name to refer a - * global temp view, e.g. `SELECT * FROM global_temp.view1`. - * - * @param viewName - * the unqualified name of the temporary view to be dropped. - * @return - * true if the view is dropped successfully, false otherwise. - * @since 3.5.0 - */ - def dropGlobalTempView(viewName: String): Boolean - - /** - * Recovers all the partitions in the directory of a table and update the catalog. Only works - * with a partitioned table, and not a view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def recoverPartitions(tableName: String): Unit - - /** - * Returns true if the table is currently cached in-memory. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def isCached(tableName: String): Boolean - - /** - * Caches the specified table in-memory. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def cacheTable(tableName: String): Unit - - /** - * Caches the specified table with the given storage level. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @param storageLevel - * storage level to cache table. - * @since 3.5.0 - */ - def cacheTable(tableName: String, storageLevel: StorageLevel): Unit - - /** - * Removes the specified table from the in-memory cache. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def uncacheTable(tableName: String): Unit - - /** - * Removes all cached tables from the in-memory cache. - * - * @since 3.5.0 - */ - def clearCache(): Unit - - /** - * Invalidates and refreshes all the cached data and metadata of the given table. For - * performance reasons, Spark SQL or the external data source library it uses might cache - * certain metadata about a table, such as the location of blocks. When those change outside of - * Spark SQL, users should call this function to invalidate the cache. - * - * If this table is cached as an InMemoryRelation, drop the original cached version and make the - * new version cached lazily. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. 
- * @since 3.5.0 - */ - def refreshTable(tableName: String): Unit - - /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` - * that contains the given data source path. Path matching is by prefix, i.e. "/" would - * invalidate everything that is cached. - * - * @since 3.5.0 - */ - def refreshByPath(path: String): Unit - - /** - * Returns the current catalog in this session. - * - * @since 3.5.0 - */ - def currentCatalog(): String - - /** - * Sets the current catalog in this session. - * - * @since 3.5.0 - */ - def setCurrentCatalog(catalogName: String): Unit - - /** - * Returns a list of catalogs available in this session. - * - * @since 3.5.0 - */ - def listCatalogs(): Dataset[CatalogMetadata] - - /** - * Returns a list of catalogs which name match the specify pattern and available in this - * session. - * - * @since 3.5.0 - */ - def listCatalogs(pattern: String): Dataset[CatalogMetadata] + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, description, options) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala new file mode 100644 index 0000000000000..e73bcb8a0059d --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectClientUnsupportedErrors.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import org.apache.spark.SparkUnsupportedOperationException + +private[sql] object ConnectClientUnsupportedErrors { + + private def unsupportedFeatureException( + subclass: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + "UNSUPPORTED_CONNECT_FEATURE." 
+ subclass, + Map.empty[String, String]) + } + + def rdd(): SparkUnsupportedOperationException = + unsupportedFeatureException("RDD") + + def queryExecution(): SparkUnsupportedOperationException = + unsupportedFeatureException("DATASET_QUERY_EXECUTION") + + def executeCommand(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_EXECUTE_COMMAND") + + def baseRelationToDataFrame(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_BASE_RELATION_TO_DATAFRAME") + + def experimental(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_EXPERIMENTAL_METHODS") + + def listenerManager(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_LISTENER_MANAGER") + + def sessionState(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_SESSION_STATE") + + def sharedState(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_SHARED_STATE") + + def sparkContext(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_SPARK_CONTEXT") + + def sqlContext(): SparkUnsupportedOperationException = + unsupportedFeatureException("SESSION_SQL_CONTEXT") +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala new file mode 100644 index 0000000000000..0344152be86e6 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.connect.proto +import org.apache.spark.sql._ +import org.apache.spark.sql.internal.ProtoColumnNode + +/** + * Conversions from sql interfaces to the Connect specific implementation. + * + * This class is mainly used by the implementation. It is also meant to be used by extension + * developers. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They + * can create a shim for these conversions, the Spark 4+ version of the shim implements this + * trait, and shims for older versions do not. 
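+ *
+ * A minimal, hypothetical sketch of such a shim (the `ConnectShim` name is illustrative only and
+ * is not part of this API):
+ * {{{
+ *   // Spark 4+ build of the shim: mixing in the trait picks up all of the conversions below.
+ *   object ConnectShim extends org.apache.spark.sql.connect.ConnectConversions
+ *
+ *   // Shims built against older Spark versions would provide their own casts instead.
+ * }}}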
+ */ +@DeveloperApi +trait ConnectConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V]( + kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = + kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] + + /** + * Create a [[Column]] from a [[proto.Expression]] + * + * This method is meant to be used by Connect plugins. We do not guarantee any compatibility + * between (minor) versions. + */ + @DeveloperApi + def column(expr: proto.Expression): Column = { + Column(ProtoColumnNode(expr)) + } + + /** + * Create a [[Column]] using a function that manipulates an [[proto.Expression.Builder]]. + * + * This method is meant to be used by Connect plugins. We do not guarantee any compatibility + * between (minor) versions. + */ + @DeveloperApi + def column(f: proto.Expression.Builder => Unit): Column = { + val builder = proto.Expression.newBuilder() + f(builder) + column(builder.build()) + } + + /** + * Implicit helper that makes it easy to construct a Column from an Expression or an Expression + * builder. This allows developers to create a Column in the same way as in earlier versions of + * Spark (before 4.0). + */ + @DeveloperApi + implicit class ColumnConstructorExt(val c: Column.type) { + def apply(e: proto.Expression): Column = column(e) + } +} + +object ConnectConversions extends ConnectConversions diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala index 4ebc22202b0b7..b359a871d8c28 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala @@ -22,6 +22,8 @@ import java.nio.file.Paths import ammonite.repl.api.Session import ammonite.runtime.SpecialClassLoader +import org.apache.spark.sql.Artifact + /** * A special [[ClassFinder]] for the Ammonite REPL to handle in-memory class files. * diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala deleted file mode 100644 index b177c110285dd..0000000000000 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.expressions - -import scala.reflect.runtime.universe._ - -import org.apache.spark.connect.proto -import org.apache.spark.sql.{encoderFor, Encoder, TypedColumn} -import org.apache.spark.sql.catalyst.ScalaReflection - -/** - * A base class for user-defined aggregations, which can be used in `Dataset` operations to take - * all of the elements of a group and reduce them to a single value. - * - * For example, the following aggregator extracts an `int` from a specific class and adds them up: - * {{{ - * case class Data(i: Int) - * - * val customSummer = new Aggregator[Data, Int, Int] { - * def zero: Int = 0 - * def reduce(b: Int, a: Data): Int = b + a.i - * def merge(b1: Int, b2: Int): Int = b1 + b2 - * def finish(r: Int): Int = r - * def bufferEncoder: Encoder[Int] = Encoders.scalaInt - * def outputEncoder: Encoder[Int] = Encoders.scalaInt - * } - * - * spark.udf.register("customSummer", udaf(customSummer)) - * val ds: Dataset[Data] = ... - * val aggregated = ds.selectExpr("customSummer(i)") - * }}} - * - * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird - * - * @tparam IN - * The input type for the aggregation. - * @tparam BUF - * The type of the intermediate value of the reduction. - * @tparam OUT - * The type of the final output result. - * @since 4.0.0 - */ -@SerialVersionUID(2093413866369130093L) -abstract class Aggregator[-IN, BUF, OUT] extends Serializable { - - /** - * A zero value for this aggregation. Should satisfy the property that any b + zero = b. - * @since 4.0.0 - */ - def zero: BUF - - /** - * Combine two values to produce a new value. For performance, the function may modify `b` and - * return it instead of constructing new object for b. - * @since 4.0.0 - */ - def reduce(b: BUF, a: IN): BUF - - /** - * Merge two intermediate values. - * @since 4.0.0 - */ - def merge(b1: BUF, b2: BUF): BUF - - /** - * Transform the output of the reduction. - * @since 4.0.0 - */ - def finish(reduction: BUF): OUT - - /** - * Specifies the `Encoder` for the intermediate value type. - * @since 4.0.0 - */ - def bufferEncoder: Encoder[BUF] - - /** - * Specifies the `Encoder` for the final output value type. - * @since 4.0.0 - */ - def outputEncoder: Encoder[OUT] - - /** - * Returns this `Aggregator` as a `TypedColumn` that can be used in `Dataset` operations. 
- * @since 4.0.0 - */ - def toColumn: TypedColumn[IN, OUT] = { - val ttpe = getInputTypeTag[IN] - val inputEncoder = ScalaReflection.encoderFor(ttpe) - val udaf = - ScalaUserDefinedFunction( - this, - Seq(inputEncoder), - encoderFor(outputEncoder), - aggregate = true) - - val builder = proto.TypedAggregateExpression.newBuilder() - builder.setScalarScalaUdf(udaf.udf) - val expr = proto.Expression.newBuilder().setTypedAggregateExpression(builder).build() - - new TypedColumn(expr, encoderFor(outputEncoder)) - } - - private final def getInputTypeTag[T]: TypeTag[T] = { - val mirror = runtimeMirror(this.getClass.getClassLoader) - val tpe = mirror.classSymbol(this.getClass).toType - // Find the most generic (last in the tree) Aggregator class - val baseAgg = - tpe.baseClasses - .findLast(_.asClass.toType <:< typeOf[Aggregator[_, _, _]]) - .getOrElse(throw new IllegalStateException("Could not find the Aggregator base class.")) - val typeArgs = tpe.baseType(baseAgg).typeArgs - assert( - typeArgs.length == 3, - s"Aggregator should have 3 type arguments, " + - s"but found ${typeArgs.length}: ${typeArgs.mkString}.") - val inType = typeArgs.head - - import scala.reflect.api._ - def areCompatibleMirrors(one: Mirror[_], another: Mirror[_]): Boolean = { - def checkAllParents(target: JavaMirror, candidate: JavaMirror): Boolean = { - var current = candidate.classLoader - while (current != null) { - if (current == target.classLoader) { - return true - } - current = current.getParent - } - false - } - - (one, another) match { - case (a: JavaMirror, b: JavaMirror) => - a == b || checkAllParents(a, b) || checkAllParents(b, a) - case _ => one == another - } - } - - TypeTag( - mirror, - new TypeCreator { - def apply[U <: Universe with Singleton](m: Mirror[U]): U#Type = - if (areCompatibleMirrors(m, mirror)) { - inType.asInstanceOf[U#Type] - } else { - throw new IllegalArgumentException( - s"Type tag defined in [$mirror] cannot be migrated to another mirror [$m].") - } - }) - } -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala deleted file mode 100644 index dcf7f67551d30..0000000000000 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.expressions - -import scala.collection.mutable -import scala.jdk.CollectionConverters._ -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal - -import com.google.protobuf.ByteString - -import org.apache.spark.SparkException -import org.apache.spark.connect.proto -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket} -import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils} - -/** - * A user-defined function. To create one, use the `udf` functions in `functions`. - * - * As an example: - * {{{ - * // Define a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => score > 0.5) - * - * // Projects a column that adds a prediction column based on the score column. - * df.select( predict(df("score")) ) - * }}} - * - * @since 3.4.0 - */ -sealed abstract class UserDefinedFunction { - - /** - * Returns true when the UDF can return a nullable value. - * - * @since 3.4.0 - */ - def nullable: Boolean - - /** - * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the - * same input. - * - * @since 3.4.0 - */ - def deterministic: Boolean - - /** - * Returns an expression that invokes the UDF, using the given arguments. - * - * @since 3.4.0 - */ - @scala.annotation.varargs - def apply(exprs: Column*): Column - - /** - * Updates UserDefinedFunction with a given name. - * - * @since 3.4.0 - */ - def withName(name: String): UserDefinedFunction - - /** - * Updates UserDefinedFunction to non-nullable. - * - * @since 3.4.0 - */ - def asNonNullable(): UserDefinedFunction - - /** - * Updates UserDefinedFunction to nondeterministic. - * - * @since 3.4.0 - */ - def asNondeterministic(): UserDefinedFunction -} - -/** - * Holder class for a scala user-defined function and it's input/output encoder(s). - */ -case class ScalaUserDefinedFunction private[sql] ( - // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. - serializedUdfPacket: Array[Byte], - inputTypes: Seq[proto.DataType], - outputType: proto.DataType, - name: Option[String], - override val nullable: Boolean, - override val deterministic: Boolean, - aggregate: Boolean) - extends UserDefinedFunction { - - private[expressions] lazy val udf = { - val scalaUdfBuilder = proto.ScalarScalaUDF - .newBuilder() - .setPayload(ByteString.copyFrom(serializedUdfPacket)) - // Send the real inputs and return types to obtain the types without deser the udf bytes. 
- .addAllInputTypes(inputTypes.asJava) - .setOutputType(outputType) - .setNullable(nullable) - .setAggregate(aggregate) - - scalaUdfBuilder.build() - } - - @scala.annotation.varargs - override def apply(exprs: Column*): Column = Column { builder => - val udfBuilder = builder.getCommonInlineUserDefinedFunctionBuilder - udfBuilder - .setDeterministic(deterministic) - .setScalarScalaUdf(udf) - .addAllArguments(exprs.map(_.expr).asJava) - - name.foreach(udfBuilder.setFunctionName) - } - - override def withName(name: String): ScalaUserDefinedFunction = copy(name = Option(name)) - - override def asNonNullable(): ScalaUserDefinedFunction = copy(nullable = false) - - override def asNondeterministic(): ScalaUserDefinedFunction = copy(deterministic = false) - - def toProto: proto.CommonInlineUserDefinedFunction = { - val builder = proto.CommonInlineUserDefinedFunction.newBuilder() - builder - .setDeterministic(deterministic) - .setScalarScalaUdf(udf) - - name.foreach(builder.setFunctionName) - builder.build() - } -} - -object ScalaUserDefinedFunction { - private val LAMBDA_DESERIALIZATION_ERR_MSG: String = - "cannot assign instance of java.lang.invoke.SerializedLambda to field" - - private def checkDeserializable(bytes: Array[Byte]): Unit = { - try { - SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader) - } catch { - case e: ClassCastException if e.getMessage.contains(LAMBDA_DESERIALIZATION_ERR_MSG) => - throw new SparkException( - "UDF cannot be executed on a Spark cluster: it cannot be deserialized. " + - "This is very likely to be caused by the lambda function (the UDF) having a " + - "self-reference. This is not supported by java serialization.") - case NonFatal(e) => - throw new SparkException( - "UDF cannot be executed on a Spark cluster: it cannot be deserialized.", - e) - } - } - - private[sql] def apply( - function: AnyRef, - returnType: TypeTag[_], - parameterTypes: TypeTag[_]*): ScalaUserDefinedFunction = { - - ScalaUserDefinedFunction( - function = function, - // Input can be a row because the input data schema can be found from the plan. - inputEncoders = - parameterTypes.map(tag => ScalaReflection.encoderForWithRowEncoderSupport(tag)), - // Output cannot be a row as there is no good way to get the return data type. 
- outputEncoder = ScalaReflection.encoderFor(returnType)) - } - - private[sql] def apply( - function: AnyRef, - inputEncoders: Seq[AgnosticEncoder[_]], - outputEncoder: AgnosticEncoder[_], - aggregate: Boolean = false): ScalaUserDefinedFunction = { - SparkConnectClosureCleaner.clean(function) - val udfPacketBytes = - SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) - checkDeserializable(udfPacketBytes) - ScalaUserDefinedFunction( - serializedUdfPacket = udfPacketBytes, - inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType), - outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType), - name = None, - nullable = true, - deterministic = true, - aggregate = aggregate) - } - - private[sql] def apply(function: AnyRef, returnType: DataType): ScalaUserDefinedFunction = { - ScalaUserDefinedFunction( - function = function, - inputEncoders = Seq.empty[AgnosticEncoder[_]], - outputEncoder = RowEncoder.encoderForDataType(returnType, lenient = false)) - } -} - -private object SparkConnectClosureCleaner { - def clean(closure: AnyRef): Unit = { - ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty) - } -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala similarity index 67% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala rename to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala index f77dd512ef257..be1a13cb2fed2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala @@ -14,10 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import org.apache.spark.connect.proto.{ConfigRequest, ConfigResponse, KeyValue} import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connect.client.SparkConnectClient /** @@ -25,61 +26,31 @@ import org.apache.spark.sql.connect.client.SparkConnectClient * * @since 3.4.0 */ -class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { +class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) + extends RuntimeConfig + with Logging { - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def set(key: String, value: String): Unit = { executeConfigRequest { builder => builder.getSetBuilder.addPairsBuilder().setKey(key).setValue(value) } } - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ - def set(key: String, value: Boolean): Unit = set(key, String.valueOf(value)) - - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ - def set(key: String, value: Long): Unit = set(key, String.valueOf(value)) - - /** - * Returns the value of Spark runtime configuration property for the given key. 
- * - * @throws java.util.NoSuchElementException - * if the key is not set and does not have a default value - * @since 3.4.0 - */ - @throws[NoSuchElementException]("if the key is not set") + /** @inheritdoc */ + @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = getOption(key).getOrElse { throw new NoSuchElementException(key) } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def get(key: String, default: String): String = { executeConfigRequestSingleValue { builder => builder.getGetWithDefaultBuilder.addPairsBuilder().setKey(key).setValue(default) } } - /** - * Returns all properties set in this conf. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def getAll: Map[String, String] = { val response = executeConfigRequest { builder => builder.getGetAllBuilder @@ -92,11 +63,7 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { builder.result() } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def getOption(key: String): Option[String] = { val pair = executeConfigRequestSinglePair { builder => builder.getGetOptionBuilder.addKeys(key) @@ -108,27 +75,14 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { } } - /** - * Resets the configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def unset(key: String): Unit = { executeConfigRequest { builder => builder.getUnsetBuilder.addKeys(key) } } - /** - * Indicates whether the configuration property with the given key is modifiable in the current - * session. - * - * @return - * `true` if the configuration property is modifiable. For static SQL, Spark Core, invalid - * (not existing) and other non-modifiable configuration properties, the returned value is - * `false`. - * @since 3.4.0 - */ + /** @inheritdoc */ def isModifiable(key: String): Boolean = { val modifiable = executeConfigRequestSingleValue { builder => builder.getIsModifiableBuilder.addKeys(key) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala new file mode 100644 index 0000000000000..58fbfea48afec --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.annotation.Stable +import org.apache.spark.connect.proto +import org.apache.spark.sql.{DataFrameWriter, Dataset, SaveMode} + +/** + * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, key-value + * stores, etc). Use `Dataset.write` to access this. + * + * @since 3.4.0 + */ +@Stable +final class DataFrameWriterImpl[T] private[sql] (ds: Dataset[T]) extends DataFrameWriter[T] { + + /** @inheritdoc */ + override def mode(saveMode: SaveMode): this.type = super.mode(saveMode) + + /** @inheritdoc */ + override def mode(saveMode: String): this.type = super.mode(saveMode) + + /** @inheritdoc */ + override def format(source: String): this.type = super.format(source) + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = + super.options(options) + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = + super.options(options) + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionBy(colNames: String*): this.type = super.partitionBy(colNames: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def bucketBy(numBuckets: Int, colName: String, colNames: String*): this.type = + super.bucketBy(numBuckets, colName, colNames: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def sortBy(colName: String, colNames: String*): this.type = + super.sortBy(colName, colNames: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = + super.clusterBy(colName, colNames: _*) + + /** @inheritdoc */ + def save(path: String): Unit = { + saveInternal(Some(path)) + } + + /** @inheritdoc */ + def save(): Unit = saveInternal(None) + + private def saveInternal(path: Option[String]): Unit = { + executeWriteOperation(builder => path.foreach(builder.setPath)) + } + + private def executeWriteOperation(f: proto.WriteOperation.Builder => Unit): Unit = { + val builder = proto.WriteOperation.newBuilder() + + builder.setInput(ds.plan.getRoot) + + // Set path or table + f(builder) + + // Cannot both be set + require(!(builder.hasPath && builder.hasTable)) + + builder.setMode(mode match { + case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND + case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE + case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE + case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS + }) + + if (source.nonEmpty) { + builder.setSource(source) + } + sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava)) + partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava)) + clusteringColumns.foreach(cols => builder.addAllClusteringColumns(cols.asJava)) + + numBuckets.foreach(n => { + val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder() + bucketBuilder.setNumBuckets(n) + 
bucketColumnNames.foreach(names => bucketBuilder.addAllBucketColumnNames(names.asJava)) + builder.setBucketBy(bucketBuilder) + }) + + extraOptions.foreach { case (k, v) => + builder.putOptions(k, v) + } + + ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperation(builder).build()) + } + + /** @inheritdoc */ + def insertInto(tableName: String): Unit = { + executeWriteOperation(builder => { + builder.setTable( + proto.WriteOperation.SaveTable + .newBuilder() + .setTableName(tableName) + .setSaveMethod( + proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO)) + }) + } + + /** @inheritdoc */ + def saveAsTable(tableName: String): Unit = { + executeWriteOperation(builder => { + builder.setTable( + proto.WriteOperation.SaveTable + .newBuilder() + .setTableName(tableName) + .setSaveMethod( + proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE)) + }) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala new file mode 100644 index 0000000000000..4afa8b6d566c5 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset} + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 + * API. 
+ * + * @since 3.4.0 + */ +@Experimental +final class DataFrameWriterV2Impl[T] private[sql] (table: String, ds: Dataset[T]) + extends DataFrameWriterV2[T] { + import ds.sparkSession.RichColumn + + private val builder = proto.WriteOperationV2 + .newBuilder() + .setInput(ds.plan.getRoot) + .setTableName(table) + + /** @inheritdoc */ + override def using(provider: String): this.type = { + builder.setProvider(provider) + this + } + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = { + builder.putOptions(key, value) + this + } + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = { + builder.putAllOptions(options.asJava) + this + } + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { + builder.putAllOptions(options) + this + } + + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type = { + builder.putTableProperties(property, value) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): this.type = { + builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = { + builder.addAllClusteringColumns((colName +: colNames).asJava) + this + } + + /** @inheritdoc */ + override def create(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) + } + + /** @inheritdoc */ + override def replace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE) + } + + /** @inheritdoc */ + override def createOrReplace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) + } + + /** @inheritdoc */ + def append(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND) + } + + /** @inheritdoc */ + def overwrite(condition: Column): Unit = { + builder.setOverwriteCondition(condition.expr) + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) + } + + /** @inheritdoc */ + def overwritePartitions(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS) + } + + private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = { + val command = proto.Command + .newBuilder() + .setWriteOperationV2(builder.setMode(mode)) + .build() + ds.sparkSession.execute(command) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala new file mode 100644 index 0000000000000..fba3c6343558b --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCommand} +import org.apache.spark.connect.proto.MergeAction.ActionType._ +import org.apache.spark.sql.{Column, Dataset, MergeIntoWriter} +import org.apache.spark.sql.functions.expr + +/** + * `MergeIntoWriter` provides methods to define and execute merge actions based on specified + * conditions. + * + * @tparam T + * the type of data in the Dataset. + * @param table + * the name of the target table for the merge operation. + * @param ds + * the source Dataset to merge into the target table. + * @param on + * the merge condition. + * + * @since 4.0.0 + */ +@Experimental +class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column) + extends MergeIntoWriter[T] { + import ds.sparkSession.RichColumn + + private val builder = MergeIntoTableCommand + .newBuilder() + .setTargetTableName(table) + .setSourceTablePlan(ds.plan.getRoot) + .setMergeCondition(on.expr) + + /** + * Executes the merge operation. + */ + def merge(): Unit = { + if (builder.getMatchActionsCount == 0 && + builder.getNotMatchedActionsCount == 0 && + builder.getNotMatchedBySourceActionsCount == 0) { + throw new SparkRuntimeException( + errorClass = "NO_MERGE_ACTION_SPECIFIED", + messageParameters = Map.empty) + } + ds.sparkSession.execute( + proto.Command + .newBuilder() + .setMergeIntoTableCommand(builder.setWithSchemaEvolution(schemaEvolutionEnabled)) + .build()) + } + + override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = { + builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT_STAR, condition)) + this + } + + override protected[sql] def insert( + condition: Option[Column], + map: Map[String, Column]): MergeIntoWriter[T] = { + builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT, condition, map)) + this + } + + override protected[sql] def updateAll( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction( + buildMergeAction(ACTION_TYPE_UPDATE_STAR, condition), + notMatchedBySource) + } + + override protected[sql] def update( + condition: Option[Column], + map: Map[String, Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction( + buildMergeAction(ACTION_TYPE_UPDATE, condition, map), + notMatchedBySource) + } + + override protected[sql] def delete( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction(buildMergeAction(ACTION_TYPE_DELETE, condition), notMatchedBySource) + } + + private def appendUpdateDeleteAction( + action: Expression, + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + if (notMatchedBySource) { + builder.addNotMatchedBySourceActions(action) + } else { + builder.addMatchActions(action) + } + this + } + + private def buildMergeAction( + actionType: MergeAction.ActionType, + condition: Option[Column], + 
assignments: Map[String, Column] = Map.empty): Expression = { + val builder = proto.MergeAction.newBuilder().setActionType(actionType) + condition.foreach(c => builder.setCondition(c.expr)) + assignments.foreach { case (k, v) => + builder + .addAssignmentsBuilder() + .setKey(expr(k).expr) + .setValue(v.expr) + } + Expression + .newBuilder() + .setMergeAction(builder) + .build() + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala new file mode 100644 index 0000000000000..409c43f480b8e --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal + +import com.google.protobuf.ByteString + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor +import org.apache.spark.sql.connect.common.DataTypeProtoConverter.toConnectProtoType +import org.apache.spark.sql.connect.common.UdfPacket +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils} + +/** + * Utility for converting a `UserDefinedFunction` into a Connect Protobuf message. + */ +private[sql] object UdfToProtoUtils { + private val LAMBDA_DESERIALIZATION_ERR_MSG: String = + "cannot assign instance of java.lang.invoke.SerializedLambda to field" + + private def checkDeserializable(bytes: Array[Byte]): Unit = { + try { + SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader) + } catch { + case e: ClassCastException if e.getMessage.contains(LAMBDA_DESERIALIZATION_ERR_MSG) => + throw new SparkException( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized. " + + "This is very likely to be caused by the lambda function (the UDF) having a " + + "self-reference. 
This is not supported by java serialization.") + case NonFatal(e) => + throw new SparkException( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized.", + e) + } + } + + private[sql] def toUdfPacketBytes( + function: AnyRef, + inputEncoders: Seq[AgnosticEncoder[_]], + outputEncoder: AgnosticEncoder[_]): ByteString = { + ClosureCleaner.clean(function, cleanTransitively = true, mutable.Map.empty) + val bytes = SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) + checkDeserializable(bytes) + ByteString.copyFrom(bytes) + } + + /** + * Convert a [[UserDefinedFunction]] to a [[proto.CommonInlineUserDefinedFunction]]. + */ + private[sql] def toProto( + udf: UserDefinedFunction, + arguments: Seq[proto.Expression] = Nil): proto.CommonInlineUserDefinedFunction = { + val invokeUdf = proto.CommonInlineUserDefinedFunction + .newBuilder() + .setDeterministic(udf.deterministic) + .addAllArguments(arguments.asJava) + val protoUdf = invokeUdf.getScalarScalaUdfBuilder + .setNullable(udf.nullable) + udf match { + case f: SparkUserDefinedFunction => + val outputEncoder = f.outputEncoder + .map(e => agnosticEncoderFor(e)) + .getOrElse(RowEncoder.encoderForDataType(f.dataType, lenient = false)) + val inputEncoders = if (f.inputEncoders.forall(_.isEmpty)) { + Nil // Java UDFs have no bindings for their inputs. + } else { + f.inputEncoders.map(e => agnosticEncoderFor(e.get)) // TODO support Any and UnboundRow. + } + inputEncoders.foreach(e => protoUdf.addInputTypes(toConnectProtoType(e.dataType))) + protoUdf + .setPayload(toUdfPacketBytes(f.f, inputEncoders, outputEncoder)) + .setOutputType(toConnectProtoType(outputEncoder.dataType)) + .setAggregate(false) + f.givenName.foreach(invokeUdf.setFunctionName) + case f: UserDefinedAggregator[_, _, _] => + val outputEncoder = agnosticEncoderFor(f.aggregator.outputEncoder) + val inputEncoder = agnosticEncoderFor(f.inputEncoder) + protoUdf + .setPayload(toUdfPacketBytes(f.aggregator, inputEncoder :: Nil, outputEncoder)) + .addInputTypes(toConnectProtoType(inputEncoder.dataType)) + .setOutputType(toConnectProtoType(outputEncoder.dataType)) + .setAggregate(true) + f.givenName.foreach(invokeUdf.setFunctionName) + } + invokeUdf.build() + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala new file mode 100644 index 0000000000000..34a8a91a0ddf8 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Expression +import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST} +import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING} +import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType} +import org.apache.spark.sql.{Column, Encoder} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder +import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregator, UserDefinedFunction} + +/** + * Converter for [[ColumnNode]] to [[proto.Expression]] conversions. + */ +object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { + def toExpr(column: Column): proto.Expression = apply(column.node, None) + + def toTypedExpr[I](column: Column, encoder: Encoder[I]): proto.Expression = { + apply(column.node, Option(encoder)) + } + + override def apply(node: ColumnNode): Expression = apply(node, None) + + private def apply(node: ColumnNode, e: Option[Encoder[_]]): proto.Expression = { + val builder = proto.Expression.newBuilder() + // TODO(SPARK-49273) support Origin in Connect Scala Client. + node match { + case Literal(value, None, _) => + builder.setLiteral(toLiteralProtoBuilder(value)) + + case Literal(value, Some(dataType), _) => + builder.setLiteral(toLiteralProtoBuilder(value, dataType)) + + case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + val escapedName = u.sql + val b = builder.getUnresolvedAttributeBuilder + .setUnparsedIdentifier(escapedName) + if (isMetadataColumn) { + // We only set this field when it is needed. If we would always set it, + // too many of the verbatims we use for testing would have to be regenerated. + b.setIsMetadataColumn(true) + } + planId.foreach(b.setPlanId) + + case UnresolvedStar(unparsedTarget, planId, _) => + val b = builder.getUnresolvedStarBuilder + unparsedTarget.foreach(b.setUnparsedTarget) + planId.foreach(b.setPlanId) + + case UnresolvedRegex(regex, planId, _) => + val b = builder.getUnresolvedRegexBuilder + .setColName(regex) + planId.foreach(b.setPlanId) + + case UnresolvedFunction(functionName, arguments, isDistinct, isUserDefinedFunction, _, _) => + // TODO(SPARK-49087) use internal namespace. 
+ builder.getUnresolvedFunctionBuilder + .setFunctionName(functionName) + .setIsUserDefinedFunction(isUserDefinedFunction) + .setIsDistinct(isDistinct) + .addAllArguments(arguments.map(apply(_, e)).asJava) + + case Alias(child, name, metadata, _) => + val b = builder.getAliasBuilder.setExpr(apply(child, e)) + name.foreach(b.addName) + metadata.foreach(m => b.setMetadata(m.json)) + + case Cast(child, dataType, evalMode, _) => + val b = builder.getCastBuilder + .setExpr(apply(child, e)) + .setType(DataTypeProtoConverter.toConnectProtoType(dataType)) + evalMode.foreach { mode => + val convertedMode = mode match { + case Cast.Try => proto.Expression.Cast.EvalMode.EVAL_MODE_TRY + case Cast.Ansi => proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI + case Cast.Legacy => proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY + } + b.setEvalMode(convertedMode) + } + + case SqlExpression(expression, _) => + builder.getExpressionStringBuilder.setExpression(expression) + + case s: SortOrder => + builder.setSortOrder(convertSortOrder(s, e)) + + case Window(windowFunction, windowSpec, _) => + val b = builder.getWindowBuilder + .setWindowFunction(apply(windowFunction, e)) + .addAllPartitionSpec(windowSpec.partitionColumns.map(apply(_, e)).asJava) + .addAllOrderSpec(windowSpec.sortColumns.map(convertSortOrder(_, e)).asJava) + windowSpec.frame.foreach { frame => + b.getFrameSpecBuilder + .setFrameType(frame.frameType match { + case WindowFrame.Row => FrameType.FRAME_TYPE_ROW + case WindowFrame.Range => FrameType.FRAME_TYPE_RANGE + }) + .setLower(convertFrameBoundary(frame.lower, e)) + .setUpper(convertFrameBoundary(frame.upper, e)) + } + + case UnresolvedExtractValue(child, extraction, _) => + builder.getUnresolvedExtractValueBuilder + .setChild(apply(child, e)) + .setExtraction(apply(extraction, e)) + + case UpdateFields(structExpression, fieldName, valueExpression, _) => + val b = builder.getUpdateFieldsBuilder + .setStructExpression(apply(structExpression, e)) + .setFieldName(fieldName) + valueExpression.foreach(v => b.setValueExpression(apply(v, e))) + + case v: UnresolvedNamedLambdaVariable => + builder.setUnresolvedNamedLambdaVariable(convertNamedLambdaVariable(v)) + + case LambdaFunction(function, arguments, _) => + builder.getLambdaFunctionBuilder + .setFunction(apply(function, e)) + .addAllArguments(arguments.map(convertNamedLambdaVariable).asJava) + + case InvokeInlineUserDefinedFunction( + a: Aggregator[Any @unchecked, Any @unchecked, Any @unchecked], + Nil, + false, + _) => + // TODO we should probably 'just' detect this particular scenario + // in the planner instead of wrapping it in a separate method. 
+ val protoUdf = UdfToProtoUtils.toProto(UserDefinedAggregator(a, e.get)) + builder.getTypedAggregateExpressionBuilder.setScalarScalaUdf(protoUdf.getScalarScalaUdf) + + case InvokeInlineUserDefinedFunction(udf: UserDefinedFunction, args, false, _) => + builder.setCommonInlineUserDefinedFunction( + UdfToProtoUtils.toProto(udf, args.map(apply(_, e)))) + + case CaseWhenOtherwise(branches, otherwise, _) => + val b = builder.getUnresolvedFunctionBuilder + .setFunctionName("when") + branches.foreach { case (condition, value) => + b.addArguments(apply(condition, e)) + b.addArguments(apply(value, e)) + } + otherwise.foreach { value => + b.addArguments(apply(value, e)) + } + + case ProtoColumnNode(e, _) => + return e + + case node => + throw SparkException.internalError("Unsupported ColumnNode: " + node) + } + builder.build() + } + + private def convertSortOrder( + s: SortOrder, + e: Option[Encoder[_]]): proto.Expression.SortOrder = { + proto.Expression.SortOrder + .newBuilder() + .setChild(apply(s.child, e)) + .setDirection(s.sortDirection match { + case SortOrder.Ascending => SORT_DIRECTION_ASCENDING + case SortOrder.Descending => SORT_DIRECTION_DESCENDING + }) + .setNullOrdering(s.nullOrdering match { + case SortOrder.NullsFirst => SORT_NULLS_FIRST + case SortOrder.NullsLast => SORT_NULLS_LAST + }) + .build() + } + + private def convertFrameBoundary( + boundary: WindowFrame.FrameBoundary, + e: Option[Encoder[_]]): FrameBoundary = { + val builder = FrameBoundary.newBuilder() + boundary match { + case WindowFrame.UnboundedPreceding => builder.setUnbounded(true) + case WindowFrame.UnboundedFollowing => builder.setUnbounded(true) + case WindowFrame.CurrentRow => builder.setCurrentRow(true) + case WindowFrame.Value(value) => builder.setValue(apply(value, e)) + } + builder.build() + } + + private def convertNamedLambdaVariable( + v: UnresolvedNamedLambdaVariable): proto.Expression.UnresolvedNamedLambdaVariable = { + proto.Expression.UnresolvedNamedLambdaVariable.newBuilder().addNameParts(v.name).build() + } +} + +case class ProtoColumnNode( + expr: proto.Expression, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override def sql: String = expr.toString +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala index 556b472283a37..ada94b76fcbcd 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala @@ -17,12 +17,6 @@ package org.apache.spark -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder - package object sql { type DataFrame = Dataset[Row] - - private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = { - implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]] - } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 789425c9daea1..2ff34a6343644 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,11 +21,9 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Read.DataSource -import org.apache.spark.internal.Logging 
-import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dataset -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder +import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.connect.ConnectConversions._ +import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -35,101 +33,49 @@ import org.apache.spark.sql.types.StructType * @since 3.5.0 */ @Evolving -final class DataStreamReader private[sql] (sparkSession: SparkSession) extends Logging { +final class DataStreamReader private[sql] (sparkSession: SparkSession) + extends api.DataStreamReader { - /** - * Specifies the input data source format. - * - * @since 3.5.0 - */ - def format(source: String): DataStreamReader = { + private val sourceBuilder = DataSource.newBuilder() + + /** @inheritdoc */ + def format(source: String): this.type = { sourceBuilder.setFormat(source) this } - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can skip - * the schema inference step, and thus speed up data loading. - * - * @since 3.5.0 - */ - def schema(schema: StructType): DataStreamReader = { + /** @inheritdoc */ + def schema(schema: StructType): this.type = { if (schema != null) { sourceBuilder.setSchema(schema.json) // Use json. DDL does not retail all the attributes. } this } - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) - * can infer the input schema automatically from data. By specifying the schema here, the - * underlying data source can skip the schema inference step, and thus speed up data loading. - * - * @since 3.5.0 - */ - def schema(schemaString: String): DataStreamReader = { + /** @inheritdoc */ + override def schema(schemaString: String): this.type = { sourceBuilder.setSchema(schemaString) this } - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: String): DataStreamReader = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { sourceBuilder.putOptions(key, value) this } - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Long): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Double): DataStreamReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.options(options.asJava) - this } - /** - * (Java-specific) Adds input options for the underlying data source. 
- * - * @since 3.5.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { sourceBuilder.putAllOptions(options) this } - /** - * Loads input data stream in as a `DataFrame`, for data streams that don't require a path (e.g. - * external key-value stores). - * - * @since 3.5.0 - */ + /** @inheritdoc */ def load(): DataFrame = { sparkSession.newDataFrame { relationBuilder => relationBuilder.getReadBuilder @@ -138,120 +84,14 @@ final class DataStreamReader private[sql] (sparkSession: SparkSession) extends L } } - /** - * Loads input in as a `DataFrame`, for data streams that read from some path. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def load(path: String): DataFrame = { sourceBuilder.clearPaths() sourceBuilder.addPaths(path) load() } - /** - * Loads a JSON file stream and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be considered in every trigger.
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to be considered in every trigger.
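For reference, a minimal sketch of passing these rate-limit options when building a streaming read; the input path and `inputSchema` below are placeholders, not part of this change:
{{{
// Cap each micro-batch at 100 newly discovered files.
val jsonStream = spark.readStream
  .format("json")
  .schema(inputSchema)                     // a StructType known up front, to skip inference
  .option("maxFilesPerTrigger", "100")
  .load("/path/to/streaming/input")
}}}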
- * - * You can find the JSON-specific options for reading JSON file stream in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def json(path: String): DataFrame = { - format("json").load(path) - } - - /** - * Loads a CSV file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be considered in every trigger.
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to be considered in every trigger.
- * - * You can find the CSV-specific options for reading CSV file stream in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def csv(path: String): DataFrame = format("csv").load(path) - - /** - * Loads a XML file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be considered in every trigger.
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to be considered in every trigger.
- * - * You can find the XML-specific options for reading XML file stream in - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = format("xml").load(path) - - /** - * Loads a ORC file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be considered in every trigger.
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to be considered in every trigger.
- * - * ORC-specific option(s) for reading ORC file stream can be found in Data - * Source Option in the version you use. - * - * @since 3.5.0 - */ - def orc(path: String): DataFrame = format("orc").load(path) - - /** - * Loads a Parquet file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be considered in every trigger.
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to be considered in every trigger.
- * - * Parquet-specific option(s) for reading Parquet file stream can be found in Data - * Source Option in the version you use. - * - * @since 3.5.0 - */ - def parquet(path: String): DataFrame = format("parquet").load(path) - - /** - * Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should - * support streaming mode. - * @param tableName - * The name of the table - * @since 3.5.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { require(tableName != null, "The table name can't be null") sparkSession.newDataFrame { builder => @@ -263,59 +103,44 @@ final class DataStreamReader private[sql] (sparkSession: SparkSession) extends L } } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. The text files must be encoded - * as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.readStream.text("/path/to/directory/") - * - * // Java: - * spark.readStream().text("/path/to/directory/") - * }}} - * - * You can set the following option(s):
  • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be considered in every trigger.
  • `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to be considered in every trigger.
- * - * You can find the text-specific options for reading text files in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def text(path: String): DataFrame = format("text").load(path) - - /** - * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset - * contains a single string column named "value". The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text file is a new element in the resulting Dataset. For - * example: - * {{{ - * // Scala: - * spark.readStream.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.readStream().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataStreamReader.text`. - * - * @param path - * input path - * @since 3.5.0 - */ - def textFile(path: String): Dataset[String] = { - text(path).select("value").as[String](StringEncoder) + override protected def assertNoSpecifiedSchema(operation: String): Unit = { + if (sourceBuilder.hasSchema) { + throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) + } } - private val sourceBuilder = DataSource.newBuilder() + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant overrides. + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) + + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) + + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) + + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) + + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) + + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) + + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index fe68f3cb0b572..9fcc31e562682 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,9 +29,8 @@ import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.WriteStreamOperationStart -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, UdfUtils} +import org.apache.spark.sql.{api, Dataset, ForeachWriter} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket} import org.apache.spark.sql.execution.streaming.AvailableNowTrigger import org.apache.spark.sql.execution.streaming.ContinuousTrigger import org.apache.spark.sql.execution.streaming.OneTimeTrigger @@ -47,63 +46,23 @@ import org.apache.spark.util.SparkSerDeUtils * @since 3.5.0 */ @Evolving -final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { +final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataStreamWriter[T] { + override type DS[U] = Dataset[U] - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
  • `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be written to the sink.
  • `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written to the sink every time there are some updates.
  • `OutputMode.Update()`: only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to `OutputMode.Append()` mode.
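A short sketch of picking one of these modes on a streaming write (the `counts` aggregate, sink format, and checkpoint path are placeholders):
{{{
import org.apache.spark.sql.streaming.{OutputMode, Trigger}

val query = counts.writeStream
  .outputMode(OutputMode.Complete())                 // re-emit the full aggregate every trigger
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .format("console")
  .option("checkpointLocation", "/tmp/checkpoints/word-counts")
  .start()
}}}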
- * - * @since 3.5.0 - */ - def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: OutputMode): this.type = { sinkBuilder.setOutputMode(outputMode.toString.toLowerCase(Locale.ROOT)) this } - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
  • `append`: only the new rows in the streaming DataFrame/Dataset will be written to the sink.
  • `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink every time there are some updates.
  • `update`: only the rows that were updated in the streaming DataFrame/Dataset will be written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to `append` mode.
- * - * @since 3.5.0 - */ - def outputMode(outputMode: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: String): this.type = { sinkBuilder.setOutputMode(outputMode) this } - /** - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will - * run the query as fast as possible. - * - * Scala Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.writeStream().trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - def trigger(trigger: Trigger): DataStreamWriter[T] = { + /** @inheritdoc */ + def trigger(trigger: Trigger): this.type = { trigger match { case ProcessingTimeTrigger(intervalMs) => sinkBuilder.setProcessingTimeInterval(s"$intervalMs milliseconds") @@ -117,106 +76,54 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. This name - * must be unique among all the currently active queries in the associated SQLContext. - * - * @since 3.5.0 - */ - def queryName(queryName: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def queryName(queryName: String): this.type = { sinkBuilder.setQueryName(queryName) this } - /** - * Specifies the underlying output data source. - * - * @since 3.5.0 - */ - def format(source: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def format(source: String): this.type = { sinkBuilder.setFormat(source) this } - /** - * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - * - *
  • year=2016/month=01/
  • year=2016/month=02/
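A sketch of a streaming write that produces the layout shown above (the `events` stream and the paths are placeholders):
{{{
events.writeStream
  .format("parquet")
  .partitionBy("year", "month")        // yields year=2016/month=01/, year=2016/month=02/, ...
  .option("checkpointLocation", "/tmp/checkpoints/events")
  .start("/data/events")
}}}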
- * - * Partitioning is one of the most widely used techniques to optimize physical data layout. It - * provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number of - * distinct values in each column should typically be less than tens of thousands. - * - * @since 3.5.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter[T] = { + def partitionBy(colNames: String*): this.type = { sinkBuilder.clearPartitioningColumnNames() sinkBuilder.addAllPartitioningColumnNames(colNames.asJava) this } - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: String): DataStreamWriter[T] = { - sinkBuilder.putOptions(key, value) + /** @inheritdoc */ + @scala.annotation.varargs + def clusterBy(colNames: String*): this.type = { + sinkBuilder.clearClusteringColumnNames() + sinkBuilder.addAllClusteringColumnNames(colNames.asJava) this } - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) + /** @inheritdoc */ + def option(key: String, value: String): this.type = { + sinkBuilder.putOptions(key, value) + this + } - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.options(options.asJava) this } - /** - * Adds output options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: java.util.Map[String, String]): this.type = { sinkBuilder.putAllOptions(options) this } - /** - * Sets the output of the streaming query to be processed using the provided writer object. - * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and - * semantics. - * @since 3.5.0 - */ - def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + /** @inheritdoc */ + def foreach(writer: ForeachWriter[T]): this.type = { val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder)) val scalaWriterBuilder = proto.ScalarScalaUDF .newBuilder() @@ -225,21 +132,9 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * :: Experimental :: - * - * (Scala-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The - * batchId can be used to deduplicate and transactionally write the output (that is, the - * provided Dataset) to external systems. 
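A sketch of the batch-id based deduplication pattern described here (the `streamingDF` source and the `events_staging` table are placeholders):
{{{
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

streamingDF.writeStream
  .foreachBatch { (batch: DataFrame, batchId: Long) =>
    // Persist the batch id alongside the data so replayed batches can be detected downstream.
    batch.withColumn("batch_id", lit(batchId))
      .write
      .mode("append")
      .saveAsTable("events_staging")
  }
  .option("checkpointLocation", "/tmp/checkpoints/staging")
  .start()
}}}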
The output Dataset is guaranteed to be exactly the - * same for the same batchId (assuming all operations are deterministic in the query). - * - * @since 3.5.0 - */ + /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { val serializedFn = SparkSerDeUtils.serialize(function) sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder .setPayload(ByteString.copyFrom(serializedFn)) @@ -248,48 +143,13 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * :: Experimental :: - * - * (Java-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The - * batchId can be used to deduplicate and transactionally write the output (that is, the - * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the - * same for the same batchId (assuming all operations are deterministic in the query). - * - * @since 3.5.0 - */ - @Evolving - def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { - foreachBatch(UdfUtils.foreachBatchFuncToScalaFunc(function)) - } - - /** - * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def start(path: String): StreamingQuery = { sinkBuilder.setPath(path) start() } - /** - * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. Throws a `TimeoutException` if the following conditions are met: - * - Another run of the same streaming query, that is a streaming query sharing the same - * checkpoint location, is already active on the same Spark Driver - * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` is enabled - * - The active run cannot be stopped within the timeout controlled by the SQL configuration - * `spark.sql.streaming.stopTimeout` - * - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[TimeoutException] def start(): StreamingQuery = { val startCmd = Command @@ -306,22 +166,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp) } - /** - * Starts the execution of the streaming query, which will continually output results to the - * given table as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. - * - * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the - * table exists or not. A new table will be created if the table not exists. - * - * For v2 table, `partitionBy` will be ignored if the table already exists. `partitionBy` will - * be respected only if the v2 table does not exist. Besides, the v2 table created by this API - * lacks some functionalities (e.g., customized properties, options, and serde info). 
If you - * need them, please create the v2 table manually before the execution to avoid creating a table - * with incomplete information. - * - * @since 3.5.0 - */ + /** @inheritdoc */ @Evolving @throws[TimeoutException] def toTable(tableName: String): StreamingQuery = { @@ -329,6 +174,24 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { start() } + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant Overrides + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + @Evolving + override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + super.foreachBatch(function) + private val sinkBuilder = WriteStreamOperationStart .newBuilder() .setInput(ds.plan.getRoot) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 13a26fa79085e..29fbcc443deb9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -18,166 +18,21 @@ package org.apache.spark.sql.streaming import java.util.UUID -import java.util.concurrent.TimeoutException import scala.jdk.CollectionConverters._ -import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.StreamingQueryCommand import org.apache.spark.connect.proto.StreamingQueryCommandResult import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{api, SparkSession} -/** - * A handle to a query that is executing continuously in the background as new data arrives. All - * these methods are thread-safe. - * @since 3.5.0 - */ -@Evolving -trait StreamingQuery { - // This is a copy of StreamingQuery in sql/core/.../streaming/StreamingQuery.scala - - /** - * Returns the user-specified name of the query, or null if not specified. This name can be - * specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as - * `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across - * all active queries. - * - * @since 3.5.0 - */ - def name: String - - /** - * Returns the unique id of this query that persists across restarts from checkpoint data. That - * is, this id is generated when a query is started for the first time, and will be the same - * every time it is restarted from checkpoint data. Also see [[runId]]. - * - * @since 3.5.0 - */ - def id: UUID - - /** - * Returns the unique id of this run of the query. That is, every start/restart of a query will - * generate a unique runId. Therefore, every time a query is restarted from checkpoint, it will - * have the same [[id]] but different [[runId]]s. - */ - def runId: UUID - - /** - * Returns the `SparkSession` associated with `this`. 
- * - * @since 3.5.0 - */ - def sparkSession: SparkSession - - /** - * Returns `true` if this query is actively running. - * - * @since 3.5.0 - */ - def isActive: Boolean - - /** - * Returns the [[StreamingQueryException]] if the query was terminated by an exception. - * @since 3.5.0 - */ - def exception: Option[StreamingQueryException] - - /** - * Returns the current status of the query. - * - * @since 3.5.0 - */ - def status: StreamingQueryStatus - - /** - * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. The - * number of progress updates retained for each stream is configured by Spark session - * configuration `spark.sql.streaming.numRecentProgressUpdates`. - * - * @since 3.5.0 - */ - def recentProgress: Array[StreamingQueryProgress] - - /** - * Returns the most recent [[StreamingQueryProgress]] update of this streaming query. - * - * @since 3.5.0 - */ - def lastProgress: StreamingQueryProgress - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If - * the query has terminated with an exception, then the exception will be thrown. - * - * If the query has terminated, then all subsequent calls to this method will either return - * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if - * the query has terminated with exception). - * - * @throws StreamingQueryException - * if the query has terminated with an exception. - * @since 3.5.0 - */ - @throws[StreamingQueryException] - def awaitTermination(): Unit - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If - * the query has terminated with an exception, then the exception will be thrown. Otherwise, it - * returns whether the query has terminated or not within the `timeoutMs` milliseconds. - * - * If the query has terminated, then all subsequent calls to this method will either return - * `true` immediately (if the query was terminated by `stop()`), or throw the exception - * immediately (if the query has terminated with exception). - * - * @throws StreamingQueryException - * if the query has terminated with an exception - * @since 3.5.0 - */ - @throws[StreamingQueryException] - def awaitTermination(timeoutMs: Long): Boolean - - /** - * Blocks until all available data in the source has been processed and committed to the sink. - * This method is intended for testing. Note that in the case of continually arriving data, this - * method may block forever. Additionally, this method is only guaranteed to block until data - * that has been synchronously appended data to a - * `org.apache.spark.sql.execution.streaming.Source` prior to invocation. (i.e. `getOffset` must - * immediately reflect the addition). - * @since 3.5.0 - */ - def processAllAvailable(): Unit - - /** - * Stops the execution of this query if it is running. This waits until the termination of the - * query execution threads or until a timeout is hit. - * - * By default stop will block indefinitely. You can configure a timeout by the configuration - * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block - * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the - * issue persists, it is advisable to kill the Spark application. - * - * @since 3.5.0 - */ - @throws[TimeoutException] - def stop(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. 
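A sketch of inspecting a running query through this handle (`query` stands for whatever `start()` returned):
{{{
println(query.id)                    // stable across restarts from the same checkpoint
println(query.runId)                 // unique per start/restart
println(query.status.prettyJson)
Option(query.lastProgress).foreach(p => println(p.prettyJson))
query.awaitTermination(30000)        // wait up to 30 seconds; rethrows if the query failed
}}}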
- * @since 3.5.0 - */ - def explain(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. - * - * @param extended - * whether to do extended explain or not - * @since 3.5.0 - */ - def explain(extended: Boolean): Unit +/** @inheritdoc */ +trait StreamingQuery extends api.StreamingQuery { + + /** @inheritdoc */ + override def sparkSession: SparkSession } class RemoteStreamingQuery( diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7efced227d6d1..647d29c714dbb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.StreamingQueryManagerCommand import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{api, SparkSession} import org.apache.spark.sql.connect.common.InvalidPlanInput /** @@ -36,7 +36,9 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput * @since 3.5.0 */ @Evolving -class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) + extends api.StreamingQueryManager + with Logging { // Mapping from id to StreamingQueryListener. There's another mapping from id to // StreamingQueryListener on server side. This is used by removeListener() to find the id @@ -53,29 +55,17 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo streamingQueryListenerBus.close() } - /** - * Returns a list of active queries associated with this SQLContext - * - * @since 3.5.0 - */ + /** @inheritdoc */ def active: Array[StreamingQuery] = { executeManagerCmd(_.setActive(true)).getActive.getActiveQueriesList.asScala.map { q => RemoteStreamingQuery.fromStreamingQueryInstanceResponse(sparkSession, q) }.toArray } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def get(id: UUID): StreamingQuery = get(id.toString) - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def get(id: String): StreamingQuery = { val response = executeManagerCmd(_.setGetQuery(id)) if (response.hasQuery) { @@ -85,52 +75,13 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the creation - * of the context, or since `resetTerminated()` was called. If any query was terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return immediately (if the query was terminated by `query.stop()`), or throw the exception - * immediately (if the query was terminated with exception). Use `resetTerminated()` to clear - * past terminations and wait for new terminations. 
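A sketch of the resetTerminated()/awaitAnyTermination() cycle described here (the two source streams are placeholders):
{{{
spark.streams.resetTerminated()      // forget queries that have already terminated
val q1 = stream1.writeStream.format("console").start()
val q2 = stream2.writeStream.format("console").start()
spark.streams.awaitAnyTermination()  // returns (or throws) once either query terminates
}}}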
- * - * In the case where multiple queries have terminated since `resetTermination()` was called, if - * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the - * exception. For correctly documenting exceptions across multiple queries, users need to stop - * all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException - * if any query has terminated with an exception - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(): Unit = { executeManagerCmd(_.getAwaitAnyTerminationBuilder.build()) } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the creation - * of the context, or since `resetTerminated()` was called. Returns whether any query has - * terminated or not (multiple may have terminated). If any query has terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return `true` immediately (if the query was terminated by `query.stop()`), or throw the - * exception immediately (if the query was terminated with exception). Use `resetTerminated()` - * to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, if - * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the - * exception. For correctly documenting exceptions across multiple queries, users need to stop - * all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException - * if any query has terminated with an exception - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(timeoutMs: Long): Boolean = { require(timeoutMs > 0, "Timeout has to be positive") @@ -139,40 +90,22 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo timeoutMs)).getAwaitAnyTermination.getTerminated } - /** - * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to - * wait for new terminations. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def resetTerminated(): Unit = { executeManagerCmd(_.setResetTerminated(true)) } - /** - * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of - * [[StreamingQuery]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def addListener(listener: StreamingQueryListener): Unit = { streamingQueryListenerBus.append(listener) } - /** - * Deregister a [[StreamingQueryListener]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def removeListener(listener: StreamingQueryListener): Unit = { streamingQueryListenerBus.remove(listener) } - /** - * List all [[StreamingQueryListener]]s attached to this [[StreamingQueryManager]]. 
- * - * @since 3.5.0 - */ + /** @inheritdoc */ def listListeners(): Array[StreamingQueryListener] = { streamingQueryListenerBus.list() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 0e3a683d2701d..ce552bdd4f0f0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -69,7 +69,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val exception = intercept[SparkException] { spark.catalog.setCurrentCatalog("notExists") } - assert(exception.getErrorClass == "CATALOG_NOT_FOUND") + assert(exception.getCondition == "CATALOG_NOT_FOUND") spark.catalog.setCurrentCatalog("testcat") assert(spark.catalog.currentCatalog().equals("testcat")) val catalogsAfterChange = spark.catalog.listCatalogs().collect() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala index e57b051890f56..0d9685d9c710f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.storage.StorageLevel class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { @@ -50,12 +51,20 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHe checkFragments(captureStdOut(block), fragmentsToCheck) } - test("checkpoint") { + test("localCheckpoint") { val df = spark.range(100).localCheckpoint() testCapturedStdOut(df.explain(), "ExistingRDD") } - test("checkpoint gc") { + test("localCheckpoint with StorageLevel") { + // We don't have a way to reach into the server and assert the storage level server side, but + // this test should cover for unexpected errors in the API. + val df = + spark.range(100).localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY) + df.collect() + } + + test("localCheckpoint gc") { val df = spark.range(100).localCheckpoint(eager = true) val encoder = df.agnosticEncoder val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId @@ -77,7 +86,7 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHe // This test is flaky because cannot guarantee GC // You can locally run this to verify the behavior. 
- ignore("checkpoint gc derived DataFrame") { + ignore("localCheckpoint gc derived DataFrame") { var df1 = spark.range(100).localCheckpoint(eager = true) var derived = df1.repartition(10) val encoder = df1.agnosticEncoder diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 88281352f2479..84ed624a95214 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -251,16 +251,16 @@ class ClientDataFrameStatSuite extends ConnectFunSuite with RemoteSparkSession { val error1 = intercept[AnalysisException] { df.stat.bloomFilter("id", -1000, 100) } - assert(error1.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error1.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") val error2 = intercept[AnalysisException] { df.stat.bloomFilter("id", 1000, -100) } - assert(error2.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error2.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") val error3 = intercept[AnalysisException] { df.stat.bloomFilter("id", 1000, -1.0) } - assert(error3.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error3.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index 9d6f07cf603aa..04367d3b95f14 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -71,33 +71,46 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { test("write") { val df = ss.newDataFrame(_ => ()).limit(10) - val builder = proto.WriteOperation.newBuilder() - builder + def toPlan(builder: proto.WriteOperation.Builder): proto.Plan = { + proto.Plan + .newBuilder() + .setCommand(proto.Command.newBuilder().setWriteOperation(builder)) + .build() + } + + val builder = proto.WriteOperation + .newBuilder() .setInput(df.plan.getRoot) .setPath("my/test/path") .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS) .setSource("parquet") - .addSortColumnNames("col1") - .addPartitioningColumns("col99") - .setBucketBy( - proto.WriteOperation.BucketBy - .newBuilder() - .setNumBuckets(2) - .addBucketColumnNames("col1") - .addBucketColumnNames("col2")) - val expectedPlan = proto.Plan - .newBuilder() - .setCommand(proto.Command.newBuilder().setWriteOperation(builder)) - .build() + val partitionedPlan = toPlan( + builder + .clone() + .addSortColumnNames("col1") + .addPartitioningColumns("col99") + .setBucketBy( + proto.WriteOperation.BucketBy + .newBuilder() + .setNumBuckets(2) + .addBucketColumnNames("col1") + .addBucketColumnNames("col2"))) df.write .sortBy("col1") .partitionBy("col99") .bucketBy(2, "col1", "col2") .parquet("my/test/path") - val actualPlan = service.getAndClearLatestInputPlan() - assert(actualPlan.equals(expectedPlan)) + val actualPartionedPlan = service.getAndClearLatestInputPlan() + assert(actualPartionedPlan.equals(partitionedPlan)) + + val clusteredPlan = toPlan(builder.clone().addClusteringColumns("col3")) + df.write + .clusterBy("col3") + 
.parquet("my/test/path") + val actualClusteredPlan = service.getAndClearLatestInputPlan() + assert(actualClusteredPlan.equals(clusteredPlan)) } test("write jdbc") { @@ -134,8 +147,9 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { builder .setInput(df.plan.getRoot) .setTableName("t1") - .addPartitioningColumns(col("col99").expr) + .addPartitioningColumns(toExpr(col("col99"))) .setProvider("json") + .addClusteringColumns("col3") .putTableProperties("key", "value") .putOptions("key2", "value2") .setMode(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) @@ -147,6 +161,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { df.writeTo("t1") .partitionedBy(col("col99")) + .clusterBy("col3") .using("json") .tableProperty("key", "value") .options(Map("key2" -> "value2")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index a66f790e08740..0371981b728d1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} @@ -95,7 +95,7 @@ class ClientE2ETestSuite .collect() } assert( - ex.getErrorClass === + ex.getCondition === "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER") assert( ex.getMessageParameters.asScala == Map( @@ -122,12 +122,12 @@ class ClientE2ETestSuite Seq("1").toDS().withColumn("udf_val", throwException($"value")).collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.getMessageParameters.isEmpty) assert(ex.getCause.isInstanceOf[SparkException]) val cause = ex.getCause.asInstanceOf[SparkException] - assert(cause.getErrorClass == null) + assert(cause.getCondition == null) assert(cause.getMessageParameters.isEmpty) assert(cause.getMessage.contains("test" * 10000)) } @@ -141,7 +141,7 @@ class ClientE2ETestSuite val ex = intercept[AnalysisException] { spark.sql("select x").collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.messageParameters.isEmpty) assert(ex.getSqlState != null) assert(!ex.isInternalError) @@ -169,14 +169,14 @@ class ClientE2ETestSuite val ex = intercept[NoSuchNamespaceException] { spark.sql("use database123") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } test("table not found for spark.catalog.getTable") { val ex = intercept[AnalysisException] { 
spark.catalog.getTable("test_table") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } test("throw NamespaceAlreadyExistsException") { @@ -185,7 +185,7 @@ class ClientE2ETestSuite val ex = intercept[NamespaceAlreadyExistsException] { spark.sql("create database test_db") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } finally { spark.sql("drop database test_db") } @@ -197,7 +197,7 @@ class ClientE2ETestSuite val ex = intercept[TempTableAlreadyExistsException] { spark.sql("create temporary view test_view as select 1") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } finally { spark.sql("drop view test_view") } @@ -209,7 +209,7 @@ class ClientE2ETestSuite val ex = intercept[TableAlreadyExistsException] { spark.sql(s"create table testcat.test_table (id int)") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } } @@ -217,7 +217,7 @@ class ClientE2ETestSuite val ex = intercept[ParseException] { spark.sql("selet 1").collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.messageParameters.isEmpty) assert(ex.getSqlState != null) assert(!ex.isInternalError) @@ -308,6 +308,7 @@ class ClientE2ETestSuite val testDataPath = java.nio.file.Paths .get( IntegrationTestUtils.sparkHome, + "sql", "connect", "common", "src", @@ -347,6 +348,7 @@ class ClientE2ETestSuite val testDataPath = java.nio.file.Paths .get( IntegrationTestUtils.sparkHome, + "sql", "connect", "common", "src", @@ -377,6 +379,7 @@ class ClientE2ETestSuite val testDataPath = java.nio.file.Paths .get( IntegrationTestUtils.sparkHome, + "sql", "connect", "common", "src", @@ -1563,6 +1566,25 @@ class ClientE2ETestSuite val result = df.select(trim(col("col"), " ").as("trimmed_col")).collect() assert(result sameElements Array(Row("a"), Row("b"), Row("c"))) } + + test("SPARK-49673: new batch size, multiple batches") { + val maxBatchSize = spark.conf.get("spark.connect.grpc.arrow.maxBatchSize").dropRight(1).toInt + // Adjust client grpcMaxMessageSize to maxBatchSize (10MiB; set in RemoteSparkSession config) + val sparkWithLowerMaxMessageSize = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .grpcMaxMessageSize(maxBatchSize) + .retryPolicy(RetryPolicy + .defaultPolicy() + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) + .build()) + .create() + assert(sparkWithLowerMaxMessageSize.range(maxBatchSize).collect().length == maxBatchSize) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index 0fb6894e457ae..86c7a20136851 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -170,11 +170,11 @@ class ColumnTestSuite extends ConnectFunSuite { val x = fn.col("a") + fn.col("b") val explain1 = captureStdOut(x.explain(false)) val explain2 = captureStdOut(x.explain(true)) - assert(explain1 == explain2) - val expectedFragments = Seq("unresolved_function", "function_name: \"+\"", "arguments") - expectedFragments.foreach { fragment => - assert(explain1.contains(fragment)) - } + assert(explain1 != explain2) + assert(explain1.strip() == "+(a, b)") + assert(explain2.contains("UnresolvedFunction(+")) + 
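A condensed sketch of the client setup used in the SPARK-49673 test above; `port` stands for the Connect server port exposed by the test harness, and every builder option shown appears in the test itself:

    import scala.concurrent.duration.FiniteDuration
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient}

    // Cap the gRPC message size so that large results are forced into multiple Arrow batches.
    val client = SparkConnectClient
      .builder()
      .userId("test")
      .port(port)
      .grpcMaxMessageSize(10 * 1024 * 1024) // matches the 10 MiB maxBatchSize used by the test server
      .retryPolicy(
        RetryPolicy.defaultPolicy().copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
      .build()
    val session = SparkSession.builder().client(client).create()
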
assert(explain2.contains("UnresolvedAttribute(List(a")) + assert(explain2.contains("UnresolvedAttribute(List(b")) } private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala new file mode 100644 index 0000000000000..4c0357a3ed984 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession} + +class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSession { + + test("explode") { + val actual1 = spark.tvf.explode(array(lit(1), lit(2))) + val expected1 = spark.sql("SELECT * FROM explode(array(1, 2))") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.explode(map(lit("a"), lit(1), lit("b"), lit(2))) + val expected2 = spark.sql("SELECT * FROM explode(map('a', 1, 'b', 2))") + checkAnswer(actual2, expected2) + + // empty + val actual3 = spark.tvf.explode(array()) + val expected3 = spark.sql("SELECT * FROM explode(array())") + checkAnswer(actual3, expected3) + + val actual4 = spark.tvf.explode(map()) + val expected4 = spark.sql("SELECT * FROM explode(map())") + checkAnswer(actual4, expected4) + + // null + val actual5 = spark.tvf.explode(lit(null).cast("array")) + val expected5 = spark.sql("SELECT * FROM explode(null :: array)") + checkAnswer(actual5, expected5) + + val actual6 = spark.tvf.explode(lit(null).cast("map")) + val expected6 = spark.sql("SELECT * FROM explode(null :: map)") + checkAnswer(actual6, expected6) + } + + test("explode_outer") { + val actual1 = spark.tvf.explode_outer(array(lit(1), lit(2))) + val expected1 = spark.sql("SELECT * FROM explode_outer(array(1, 2))") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.explode_outer(map(lit("a"), lit(1), lit("b"), lit(2))) + val expected2 = spark.sql("SELECT * FROM explode_outer(map('a', 1, 'b', 2))") + checkAnswer(actual2, expected2) + + // empty + val actual3 = spark.tvf.explode_outer(array()) + val expected3 = spark.sql("SELECT * FROM explode_outer(array())") + checkAnswer(actual3, expected3) + + val actual4 = spark.tvf.explode_outer(map()) + val expected4 = spark.sql("SELECT * FROM explode_outer(map())") + checkAnswer(actual4, expected4) + + // null + val actual5 = spark.tvf.explode_outer(lit(null).cast("array")) + val expected5 = spark.sql("SELECT * FROM explode_outer(null :: array)") + checkAnswer(actual5, expected5) + + 
val actual6 = spark.tvf.explode_outer(lit(null).cast("map")) + val expected6 = spark.sql("SELECT * FROM explode_outer(null :: map)") + checkAnswer(actual6, expected6) + } + + test("inline") { + val actual1 = spark.tvf.inline(array(struct(lit(1), lit("a")), struct(lit(2), lit("b")))) + val expected1 = spark.sql("SELECT * FROM inline(array(struct(1, 'a'), struct(2, 'b')))") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.inline(array().cast("array>")) + val expected2 = spark.sql("SELECT * FROM inline(array() :: array>)") + checkAnswer(actual2, expected2) + + val actual3 = spark.tvf.inline( + array( + named_struct(lit("a"), lit(1), lit("b"), lit(2)), + lit(null), + named_struct(lit("a"), lit(3), lit("b"), lit(4)))) + val expected3 = spark.sql( + "SELECT * FROM " + + "inline(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3, 'b', 4)))") + checkAnswer(actual3, expected3) + } + + test("inline_outer") { + val actual1 = + spark.tvf.inline_outer(array(struct(lit(1), lit("a")), struct(lit(2), lit("b")))) + val expected1 = spark.sql("SELECT * FROM inline_outer(array(struct(1, 'a'), struct(2, 'b')))") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.inline_outer(array().cast("array>")) + val expected2 = spark.sql("SELECT * FROM inline_outer(array() :: array>)") + checkAnswer(actual2, expected2) + + val actual3 = spark.tvf.inline_outer( + array( + named_struct(lit("a"), lit(1), lit("b"), lit(2)), + lit(null), + named_struct(lit("a"), lit(3), lit("b"), lit(4)))) + val expected3 = spark.sql( + "SELECT * FROM " + + "inline_outer(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3, 'b', 4)))") + checkAnswer(actual3, expected3) + } + + test("json_tuple") { + val actual = spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"), lit("b")) + val expected = spark.sql("""SELECT * FROM json_tuple('{"a":1,"b":2}', 'a', 'b')""") + checkAnswer(actual, expected) + + val ex = intercept[AnalysisException] { + spark.tvf.json_tuple(lit("""{"a":1,"b":2}""")).collect() + } + assert(ex.errorClass.get == "WRONG_NUM_ARGS.WITHOUT_SUGGESTION") + assert(ex.messageParameters("functionName") == "`json_tuple`") + } + + test("posexplode") { + val actual1 = spark.tvf.posexplode(array(lit(1), lit(2))) + val expected1 = spark.sql("SELECT * FROM posexplode(array(1, 2))") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.posexplode(map(lit("a"), lit(1), lit("b"), lit(2))) + val expected2 = spark.sql("SELECT * FROM posexplode(map('a', 1, 'b', 2))") + checkAnswer(actual2, expected2) + + // empty + val actual3 = spark.tvf.posexplode(array()) + val expected3 = spark.sql("SELECT * FROM posexplode(array())") + checkAnswer(actual3, expected3) + + val actual4 = spark.tvf.posexplode(map()) + val expected4 = spark.sql("SELECT * FROM posexplode(map())") + checkAnswer(actual4, expected4) + + // null + val actual5 = spark.tvf.posexplode(lit(null).cast("array")) + val expected5 = spark.sql("SELECT * FROM posexplode(null :: array)") + checkAnswer(actual5, expected5) + + val actual6 = spark.tvf.posexplode(lit(null).cast("map")) + val expected6 = spark.sql("SELECT * FROM posexplode(null :: map)") + checkAnswer(actual6, expected6) + } + + test("posexplode_outer") { + val actual1 = spark.tvf.posexplode_outer(array(lit(1), lit(2))) + val expected1 = spark.sql("SELECT * FROM posexplode_outer(array(1, 2))") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.posexplode_outer(map(lit("a"), lit(1), lit("b"), lit(2))) + val expected2 = spark.sql("SELECT * FROM posexplode_outer(map('a', 1, 
'b', 2))") + checkAnswer(actual2, expected2) + + // empty + val actual3 = spark.tvf.posexplode_outer(array()) + val expected3 = spark.sql("SELECT * FROM posexplode_outer(array())") + checkAnswer(actual3, expected3) + + val actual4 = spark.tvf.posexplode_outer(map()) + val expected4 = spark.sql("SELECT * FROM posexplode_outer(map())") + checkAnswer(actual4, expected4) + + // null + val actual5 = spark.tvf.posexplode_outer(lit(null).cast("array")) + val expected5 = spark.sql("SELECT * FROM posexplode_outer(null :: array)") + checkAnswer(actual5, expected5) + + val actual6 = spark.tvf.posexplode_outer(lit(null).cast("map")) + val expected6 = spark.sql("SELECT * FROM posexplode_outer(null :: map)") + checkAnswer(actual6, expected6) + } + + test("stack") { + val actual = spark.tvf.stack(lit(2), lit(1), lit(2), lit(3)) + val expected = spark.sql("SELECT * FROM stack(2, 1, 2, 3)") + checkAnswer(actual, expected) + } + + test("collations") { + val actual = spark.tvf.collations() + val expected = spark.sql("SELECT * FROM collations()") + checkAnswer(actual, expected) + } + + test("sql_keywords") { + val actual = spark.tvf.sql_keywords() + val expected = spark.sql("SELECT * FROM sql_keywords()") + checkAnswer(actual, expected) + } + + // TODO(SPARK-50063): Support VARIANT in Spark Connect Scala client + ignore("variant_explode") { + val actual1 = spark.tvf.variant_explode(parse_json(lit("""["hello", "world"]"""))) + val expected1 = + spark.sql("""SELECT * FROM variant_explode(parse_json('["hello", "world"]'))""") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.variant_explode(parse_json(lit("""{"a": true, "b": 3.14}"""))) + val expected2 = + spark.sql("""SELECT * FROM variant_explode(parse_json('{"a": true, "b": 3.14}'))""") + checkAnswer(actual2, expected2) + + // empty + val actual3 = spark.tvf.variant_explode(parse_json(lit("[]"))) + val expected3 = spark.sql("SELECT * FROM variant_explode(parse_json('[]'))") + checkAnswer(actual3, expected3) + + val actual4 = spark.tvf.variant_explode(parse_json(lit("{}"))) + val expected4 = spark.sql("SELECT * FROM variant_explode(parse_json('{}'))") + checkAnswer(actual4, expected4) + + // null + val actual5 = spark.tvf.variant_explode(lit(null).cast("variant")) + val expected5 = spark.sql("SELECT * FROM variant_explode(null :: variant)") + checkAnswer(actual5, expected5) + + // not a variant object/array + val actual6 = spark.tvf.variant_explode(parse_json(lit("1"))) + val expected6 = spark.sql("SELECT * FROM variant_explode(parse_json('1'))") + checkAnswer(actual6, expected6) + } + + // TODO(SPARK-50063): Support VARIANT in Spark Connect Scala client + ignore("variant_explode_outer") { + val actual1 = spark.tvf.variant_explode_outer(parse_json(lit("""["hello", "world"]"""))) + val expected1 = + spark.sql("""SELECT * FROM variant_explode_outer(parse_json('["hello", "world"]'))""") + checkAnswer(actual1, expected1) + + val actual2 = spark.tvf.variant_explode_outer(parse_json(lit("""{"a": true, "b": 3.14}"""))) + val expected2 = + spark.sql("""SELECT * FROM variant_explode_outer(parse_json('{"a": true, "b": 3.14}'))""") + checkAnswer(actual2, expected2) + + // empty + val actual3 = spark.tvf.variant_explode_outer(parse_json(lit("[]"))) + val expected3 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('[]'))") + checkAnswer(actual3, expected3) + + val actual4 = spark.tvf.variant_explode_outer(parse_json(lit("{}"))) + val expected4 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('{}'))") + checkAnswer(actual4, expected4) + + // 
null + val actual5 = spark.tvf.variant_explode_outer(lit(null).cast("variant")) + val expected5 = spark.sql("SELECT * FROM variant_explode_outer(null :: variant)") + checkAnswer(actual5, expected5) + + // not a variant object/array + val actual6 = spark.tvf.variant_explode_outer(parse_json(lit("1"))) + val expected6 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('1'))") + checkAnswer(actual6, expected6) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 748843ec9910c..40b66bcb8358d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -22,7 +22,6 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.protobuf.{functions => pbFn} import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.{DataType, StructType} @@ -196,20 +195,18 @@ class FunctionTestSuite extends ConnectFunSuite { lead("a", 2, null), lead(a, 2, null, false)) testEquals( - "aggregate", - aggregate(a, lit(0), (l, r) => l + r), - aggregate(a, lit(0), (l, r) => l + r, id => id)) - testEquals( - "from_json", + "from_json with sql schema", from_json(a, schema.asInstanceOf[DataType]), from_json(a, schema), - from_json(a, lit(schema.json)), - from_json(a, schema.json, Map.empty[String, String]), - from_json(a, schema.json, Collections.emptyMap[String, String]), from_json(a, schema.asInstanceOf[DataType], Map.empty[String, String]), from_json(a, schema.asInstanceOf[DataType], Collections.emptyMap[String, String]), from_json(a, schema, Map.empty[String, String]), - from_json(a, schema, Collections.emptyMap[String, String]), + from_json(a, schema, Collections.emptyMap[String, String])) + testEquals( + "from_json with json schema", + from_json(a, lit(schema.json)), + from_json(a, schema.json, Map.empty[String, String]), + from_json(a, schema.json, Collections.emptyMap[String, String]), from_json(a, lit(schema.json), Collections.emptyMap[String, String])) testEquals("schema_of_json", schema_of_json(lit("x,y")), schema_of_json("x,y")) testEquals( @@ -229,13 +226,15 @@ class FunctionTestSuite extends ConnectFunSuite { schema_of_csv(lit("x,y"), Collections.emptyMap())) testEquals("to_csv", to_csv(a), to_csv(a, Collections.emptyMap[String, String])) testEquals( - "from_xml", + "from_xml with sql schema", from_xml(a, schema), + from_xml(a, schema, Map.empty[String, String].asJava), + from_xml(a, schema, Collections.emptyMap[String, String])) + testEquals( + "from_xml with json schema", from_xml(a, lit(schema.json)), from_xml(a, schema.json, Collections.emptyMap[String, String]), from_xml(a, schema.json, Map.empty[String, String].asJava), - from_xml(a, schema, Map.empty[String, String].asJava), - from_xml(a, schema, Collections.emptyMap[String, String]), from_xml(a, lit(schema.json), Collections.emptyMap[String, String])) testEquals( "schema_of_xml", @@ -251,28 +250,16 @@ class FunctionTestSuite extends ConnectFunSuite { a, """{"type": "int", "name": "id"}""", Collections.emptyMap[String, String])) - testEquals( - "from_protobuf", - pbFn.from_protobuf( - a, - "FakeMessage", - "fakeBytes".getBytes(), - Map.empty[String, String].asJava), - pbFn.from_protobuf(a, "FakeMessage", "fakeBytes".getBytes())) - 
testEquals( - "to_protobuf", - pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes(), Map.empty[String, String].asJava), - pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes())) testEquals("call_udf", callUDF("bob", lit(1)), call_udf("bob", lit(1))) test("assert_true no message") { - val e = assert_true(a).expr + val e = toExpr(assert_true(a)) assert(e.hasUnresolvedFunction) val fn = e.getUnresolvedFunction assert(fn.getFunctionName == "assert_true") assert(fn.getArgumentsCount == 1) - assert(fn.getArguments(0) == a.expr) + assert(fn.getArguments(0) == toExpr(a)) } test("json_tuple zero args") { @@ -280,7 +267,7 @@ class FunctionTestSuite extends ConnectFunSuite { } test("rand no seed") { - val e = rand().expr + val e = toExpr(rand()) assert(e.hasUnresolvedFunction) val fn = e.getUnresolvedFunction assert(fn.getFunctionName == "rand") @@ -288,7 +275,7 @@ class FunctionTestSuite extends ConnectFunSuite { } test("randn no seed") { - val e = randn().expr + val e = toExpr(randn()) assert(e.hasUnresolvedFunction) val fn = e.getUnresolvedFunction assert(fn.getFunctionName == "randn") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 07b03c4564bcc..1b973ac8ea344 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -71,7 +72,7 @@ import org.apache.spark.util.SparkFileUtils * compatibility. * * Note that the plan protos are used as the input for the `ProtoToParsedPlanTestSuite` in the - * `connect/server` module + * `sql/connect/server` module */ // scalastyle:on class PlanGenerationTestSuite @@ -88,7 +89,7 @@ class PlanGenerationTestSuite protected val queryFilePath: Path = commonResourcePath.resolve("query-tests/queries") - // A relative path to /connect/server, used by `ProtoToParsedPlanTestSuite` to run + // A relative path to /sql/connect/server, used by `ProtoToParsedPlanTestSuite` to run // with the datasource. 
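A short sketch of the pattern these function tests now follow: a Column is converted to its proto form through the `toExpr` helper used across the suites (assumed to be in scope here) instead of the removed `Column.expr` accessor, and the resulting expression is inspected directly:

    import org.apache.spark.sql.functions.rand

    val e = toExpr(rand())                 // proto.Expression for an unseeded rand()
    assert(e.hasUnresolvedFunction)
    assert(e.getUnresolvedFunction.getFunctionName == "rand")
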
protected val testDataPath: Path = java.nio.file.Paths.get( "../", @@ -118,6 +119,7 @@ class PlanGenerationTestSuite override protected def beforeEach(): Unit = { session.resetPlanIdGenerator() + internal.UnresolvedNamedLambdaVariable.resetIdGenerator() } override protected def afterAll(): Unit = { @@ -551,6 +553,14 @@ class PlanGenerationTestSuite valueColumnName = "value") } + test("transpose index_column") { + simple.transpose(indexColumn = fn.col("id")) + } + + test("transpose no_index_column") { + simple.transpose() + } + test("offset") { simple.offset(1000) } @@ -1800,7 +1810,11 @@ class PlanGenerationTestSuite fn.sentences(fn.col("g")) } - functionTest("sentences with locale") { + functionTest("sentences with language") { + fn.sentences(fn.col("g"), lit("en")) + } + + functionTest("sentences with language and country") { fn.sentences(fn.col("g"), lit("en"), lit("US")) } @@ -1911,6 +1925,47 @@ class PlanGenerationTestSuite fn.make_interval() } + functionTest("try_make_interval years months weeks days hours mins secs") { + fn.try_make_interval( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + + functionTest("try_make_interval years months weeks days hours mins") { + fn.try_make_interval( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a")) + } + + functionTest("try_make_interval years months weeks days hours") { + fn.try_make_interval(fn.col("a"), fn.col("a"), fn.col("a"), fn.col("a"), fn.col("a")) + } + + functionTest("try_make_interval years months weeks days") { + fn.try_make_interval(fn.col("a"), fn.col("a"), fn.col("a"), fn.col("a")) + } + + functionTest("try_make_interval years months weeks") { + fn.try_make_interval(fn.col("a"), fn.col("a"), fn.col("a")) + } + + functionTest("try_make_interval years months") { + fn.try_make_interval(fn.col("a"), fn.col("a")) + } + + functionTest("try_make_interval years") { + fn.try_make_interval(fn.col("a")) + } + functionTest("make_timestamp with timezone") { fn.make_timestamp( fn.col("a"), @@ -1963,6 +2018,58 @@ class PlanGenerationTestSuite fn.col("b")) } + functionTest("try_make_timestamp with timezone") { + fn.try_make_timestamp( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b"), + fn.col("g")) + } + + functionTest("try_make_timestamp without timezone") { + fn.try_make_timestamp( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + + functionTest("try_make_timestamp_ltz with timezone") { + fn.try_make_timestamp_ltz( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b"), + fn.col("g")) + } + + functionTest("try_make_timestamp_ltz without timezone") { + fn.try_make_timestamp_ltz( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + + functionTest("try_make_timestamp_ntz") { + fn.try_make_timestamp_ntz( + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("a"), + fn.col("b")) + } + functionTest("make_ym_interval years months") { fn.make_ym_interval(fn.col("a"), fn.col("a")) } @@ -2433,6 +2540,10 @@ class PlanGenerationTestSuite fn.aggregate(fn.col("e"), lit(0), (x, y) => x + y) } + functionTest("aggregate with finish lambda") { + fn.aggregate(fn.col("e"), lit(0), (x, y) => x + y, x => x + lit(2)) + } + functionTest("reduce") { fn.reduce(fn.col("e"), lit(0), (x, y) => x + y) } @@ -2493,6 +2604,10 @@ class PlanGenerationTestSuite fn.from_json(fn.col("g"), simpleSchema) } + 
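For illustration, a minimal sketch of the two from_json schema forms that now produce separate golden plans (typed schema versus JSON-string schema), using only public functions:

    import org.apache.spark.sql.functions.{col, from_json, lit}
    import org.apache.spark.sql.types.{LongType, StructType}

    val schema = new StructType().add("id", LongType)

    // Form 1: the schema is passed as a typed StructType/DataType.
    val fromTypedSchema = from_json(col("g"), schema)

    // Form 2: the schema is passed as its JSON representation, wrapped in a literal column.
    val fromJsonSchema = from_json(col("g"), lit(schema.json))
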
functionTest("from_json with json schema") { + fn.from_json(fn.col("g"), fn.lit(simpleSchema.json)) + } + functionTest("schema_of_json") { fn.schema_of_json(lit("""[{"col":01}]""")) } @@ -2669,6 +2784,14 @@ class PlanGenerationTestSuite fn.parse_url(fn.col("g"), fn.col("g"), fn.col("g")) } + functionTest("try_parse_url") { + fn.try_parse_url(fn.col("g"), fn.col("g")) + } + + functionTest("try_parse_url with key") { + fn.try_parse_url(fn.col("g"), fn.col("g"), fn.col("g")) + } + functionTest("printf") { fn.printf(fn.col("g"), fn.col("a"), fn.col("g")) } @@ -2968,6 +3091,14 @@ class PlanGenerationTestSuite fn.call_function("lower", fn.col("g")) } + functionTest("from_xml") { + fn.from_xml(fn.col("g"), simpleSchema) + } + + functionTest("from_xml with json schema") { + fn.from_xml(fn.col("g"), fn.lit(simpleSchema.json)) + } + test("hll_sketch_agg with column lgConfigK") { binary.select(fn.hll_sketch_agg(fn.col("bytes"), lit(0))) } @@ -3264,7 +3395,7 @@ class PlanGenerationTestSuite .setUnparsedIdentifier("id"))) .setCustomField("abc") .build() - simple.select(Column(_.setExtension(com.google.protobuf.Any.pack(extension)))) + simple.select(column(_.setExtension(com.google.protobuf.Any.pack(extension)))) } test("crosstab") { @@ -3325,10 +3456,10 @@ class PlanGenerationTestSuite /* Protobuf functions */ // scalastyle:off line.size.limit // If `common.desc` needs to be updated, execute the following command to regenerate it: - // 1. cd connect/common/src/main/protobuf/spark/connect + // 1. cd sql/connect/common/src/main/protobuf/spark/connect // 2. protoc --include_imports --descriptor_set_out=../../../../test/resources/protobuf-tests/common.desc common.proto // scalastyle:on line.size.limit - private val testDescFilePath: String = s"${IntegrationTestUtils.sparkHome}/connect/" + + private val testDescFilePath: String = s"${IntegrationTestUtils.sparkHome}/sql/connect/" + "common/src/test/resources/protobuf-tests/common.desc" // TODO(SPARK-45030): Re-enable this test when all Maven test scenarios succeed and there diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 57342e12fcb51..b3b8020b1e4c7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} import org.apache.spark.sql.test.ConnectFunSuite @@ -55,7 +56,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { import org.apache.spark.util.ArrayImplicits._ import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { - val encoder = encoderFor[T] + val encoder = agnosticEncoderFor[T] val allocator = new RootAllocator() try { val batch = ArrowSerializer.serialize( diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala similarity index 52% rename from 
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala index 66f591bf1fb99..ed930882ac2fd 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala @@ -16,26 +16,18 @@ */ package org.apache.spark.sql +import org.apache.spark.sql.api.SparkSessionBuilder +import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} + /** - * A container for a [[Dataset]], used for implicit conversions in Scala. - * - * To use this, import implicit conversions in SQL: - * {{{ - * val spark: SparkSession = ... - * import spark.implicits._ - * }}} - * - * @since 3.4.0 + * Make sure the api.SparkSessionBuilder binds to Connect implementation. */ -case class DatasetHolder[T] private[sql] (private val ds: Dataset[T]) { - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. - def toDS(): Dataset[T] = ds - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = ds.toDF() - - def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*) +class SparkSessionBuilderImplementationBindingSuite + extends ConnectFunSuite + with api.SparkSessionBuilderImplementationBindingSuite + with RemoteSparkSession { + override protected def configure(builder: SparkSessionBuilder): builder.type = { + // We need to set this configuration because the port used by the server is random. 
+ builder.remote(s"sc://localhost:$serverPort") + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 8abc41639fdd2..dec56554d143e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -22,6 +22,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} +import org.apache.spark.SparkException import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils @@ -113,7 +114,7 @@ class SparkSessionSuite extends ConnectFunSuite { SparkSession.clearActiveSession() assert(SparkSession.getDefaultSession.isEmpty) assert(SparkSession.getActiveSession.isEmpty) - intercept[IllegalStateException](SparkSession.active) + intercept[SparkException](SparkSession.active) // Create a session val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala index 1d8d164c9541c..c1e44b6fb11b2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala @@ -21,8 +21,10 @@ import java.nio.file.{Files, Paths} import scala.util.Properties +import com.google.protobuf.ByteString + +import org.apache.spark.connect.proto import org.apache.spark.sql.connect.common.ProtoDataTypes -import org.apache.spark.sql.expressions.ScalaUserDefinedFunction import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession { @@ -39,15 +41,17 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession { new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL private def registerUdf(session: SparkSession): Unit = { - val udf = ScalaUserDefinedFunction( - serializedUdfPacket = udfByteArray, - inputTypes = Seq(ProtoDataTypes.IntegerType), - outputType = ProtoDataTypes.IntegerType, - name = Some("dummyUdf"), - nullable = true, - deterministic = true, - aggregate = false) - session.registerUdf(udf.toProto) + val builder = proto.CommonInlineUserDefinedFunction + .newBuilder() + .setDeterministic(true) + .setFunctionName("dummyUdf") + builder.getScalarScalaUdfBuilder + .setPayload(ByteString.copyFrom(udfByteArray)) + .addInputTypes(ProtoDataTypes.IntegerType) + .setOutputType(ProtoDataTypes.IntegerType) + .setNullable(true) + .setAggregate(false) + session.registerUdf(builder.build()) } test("update class loader after stubbing: new session") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala new file mode 100644 index 0000000000000..6a26cf581751d --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UnsupportedFeaturesSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.types.StructType + +/** + * Test suite that test the errors thrown when using unsupported features. + */ +class UnsupportedFeaturesSuite extends ConnectFunSuite { + private def session = SparkSession.builder().getOrCreate() + + private def testUnsupportedFeature(name: String, errorCode: String)( + f: SparkSession => Any): Unit = { + test(name) { + val e = intercept[SparkUnsupportedOperationException](f(session)) + assert(e.getCondition == "UNSUPPORTED_CONNECT_FEATURE." + errorCode) + } + } + + testUnsupportedFeature("SparkSession.createDataFrame(RDD)", "RDD") { session => + session.createDataFrame(new RDD[(Int, Int)]) + } + + testUnsupportedFeature("SparkSession.createDataFrame(RDD, StructType)", "RDD") { session => + val schema = new StructType().add("_1", "int").add("_2", "int") + session.createDataFrame(new RDD[Row], schema) + } + + testUnsupportedFeature("SparkSession.createDataFrame(JavaRDD, StructType)", "RDD") { session => + val schema = new StructType().add("_1", "int").add("_2", "int") + session.createDataFrame(new JavaRDD[Row], schema) + } + + testUnsupportedFeature("SparkSession.createDataFrame(RDD, Class)", "RDD") { session => + session.createDataFrame(new RDD[Int], classOf[Int]) + } + + testUnsupportedFeature("SparkSession.createDataFrame(JavaRDD, Class)", "RDD") { session => + session.createDataFrame(new JavaRDD[Int], classOf[Int]) + } + + testUnsupportedFeature("SparkSession.createDataset(RDD)", "RDD") { session => + session.createDataset(new RDD[Int])(Encoders.scalaInt) + } + + testUnsupportedFeature("SparkSession.experimental", "SESSION_EXPERIMENTAL_METHODS") { + _.experimental + } + + testUnsupportedFeature("SparkSession.sessionState", "SESSION_SESSION_STATE") { + _.sessionState + } + + testUnsupportedFeature("SparkSession.sharedState", "SESSION_SHARED_STATE") { + _.sharedState + } + + testUnsupportedFeature("SparkSession.listenerManager", "SESSION_LISTENER_MANAGER") { + _.listenerManager + } + + testUnsupportedFeature("SparkSession.sqlContext", "SESSION_SQL_CONTEXT") { + _.sqlContext + } + + testUnsupportedFeature( + "SparkSession.baseRelationToDataFrame", + "SESSION_BASE_RELATION_TO_DATAFRAME") { + _.baseRelationToDataFrame(new BaseRelation) + } + + testUnsupportedFeature("SparkSession.executeCommand", "SESSION_EXECUTE_COMMAND") { + _.executeCommand("ds", "exec", Map.empty) + } + + testUnsupportedFeature("Dataset.queryExecution", "DATASET_QUERY_EXECUTION") { + _.range(1).queryExecution + } + + testUnsupportedFeature("Dataset.rdd", "RDD") { + _.range(1).rdd + } + + 
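Inside a ScalaTest suite, the pattern above boils down to the following sketch; the condition string is the one asserted by the new tests:

    import org.apache.spark.SparkUnsupportedOperationException

    // RDD-based APIs are not available over Spark Connect and surface a dedicated error condition.
    val e = intercept[SparkUnsupportedOperationException](spark.range(1).rdd)
    assert(e.getCondition == "UNSUPPORTED_CONNECT_FEATURE.RDD")
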
testUnsupportedFeature("Dataset.javaRDD", "RDD") { + _.range(1).javaRDD + } + + testUnsupportedFeature("Dataset.toJavaRDD", "RDD") { + _.range(1).toJavaRDD + } + + testUnsupportedFeature("DataFrameReader.json(RDD)", "RDD") { + _.read.json(new RDD[String]) + } + + testUnsupportedFeature("DataFrameReader.json(JavaRDD)", "RDD") { + _.read.json(new JavaRDD[String]) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index 4aec0e6348c09..ca754c7b542f7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -315,13 +315,20 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession "c") } - test("(deprecated) scala UDF with dataType") { + // TODO re-enable this after we hooked SqlApiConf into the session confs. + ignore("(deprecated) scala UDF with dataType") { val session: SparkSession = spark import session.implicits._ val fn = udf(((i: Long) => (i + 1).toInt), IntegerType) checkDataset(session.range(2).select(fn($"id")).as[Int], 1, 2) } + test("(deprecated) scala UDF with dataType should fail") { + intercept[AnalysisException] { + udf(((i: Long) => (i + 1).toInt), IntegerType) + } + } + test("java UDF") { val session: SparkSession = spark import session.implicits._ @@ -380,23 +387,23 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession } test("UDAF custom Aggregator - toColumn") { - val session: SparkSession = spark - import session.implicits._ + val encoder = Encoders.product[UdafTestInput] val aggCol = new CompleteUdafTestInputAggregator().toColumn - val ds = spark.range(10).withColumn("extra", col("id") * 2).as[UdafTestInput] - + val ds = spark.range(10).withColumn("extra", col("id") * 2).as(encoder) assert(ds.select(aggCol).head() == 135) // 45 + 90 - assert(ds.agg(aggCol).head().getLong(0) == 135) // 45 + 90 } test("UDAF custom Aggregator - multiple extends - toColumn") { - val session: SparkSession = spark - import session.implicits._ + val encoder = Encoders.product[UdafTestInput] val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn - val ds = spark.range(10).withColumn("extra", col("id") * 2).as[UdafTestInput] - + val ds = spark.range(10).withColumn("extra", col("id") * 2).as(encoder) assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4 - assert(ds.agg(aggCol).head().getLong(0) == 540) // (45 + 90) * 4 + } + + test("UDAF custom aggregator - with rows - toColumn") { + val ds = spark.range(10).withColumn("extra", col("id") * 2) + assert(ds.select(RowAggregator.toColumn).head() == 405) + assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405) } } @@ -441,3 +448,17 @@ final class CompleteGrandChildUdafTestInputAggregator (b._1 + a.id, b._2 + a.extra) override def finish(reduction: (Long, Long)): Long = (reduction._1 + reduction._2) * 4 } + +object RowAggregator extends Aggregator[Row, (Long, Long), Long] { + override def zero: (Long, Long) = (0, 0) + override def reduce(b: (Long, Long), a: Row): (Long, Long) = { + (b._1 + a.getLong(0), b._2 + a.getLong(1)) + } + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + override def finish(r: (Long, Long)): Long = (r._1 + r._2) * 3 + override def bufferEncoder: 
Encoder[(Long, Long)] = + Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong) + override def outputEncoder: Encoder[Long] = Encoders.scalaLong +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala index 923aa5af75ba8..403d019f01e2a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala @@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.common.UdfPacket -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils @@ -33,13 +33,13 @@ class UserDefinedFunctionSuite extends ConnectFunSuite { val myUdf = udf(func _) val colWithUdf = myUdf(Column("dummy")) - val udfExpr = colWithUdf.expr.getCommonInlineUserDefinedFunction + val udfExpr = toExpr(colWithUdf).getCommonInlineUserDefinedFunction assert(udfExpr.getDeterministic) assert(udfExpr.getArgumentsCount == 1) - assert(udfExpr.getArguments(0) == Column("dummy").expr) + assert(udfExpr.getArguments(0) == toExpr(Column("dummy"))) val udfObj = udfExpr.getScalarScalaUdf - assert(udfObj.getNullable) + assert(!udfObj.getNullable) val deSer = SparkSerDeUtils.deserialize[UdfPacket](udfObj.getPayload.toByteArray) @@ -49,7 +49,7 @@ class UserDefinedFunctionSuite extends ConnectFunSuite { } private def testNonDeserializable(f: Int => Int): Unit = { - val e = intercept[SparkException](udf(f)) + val e = intercept[SparkException](toExpr(udf(f).apply(lit(1)))) assert( e.getMessage.contains( "UDF cannot be executed on a Spark cluster: it cannot be deserialized.")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index d7977fbeb108f..bb7d1b25738c1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -64,6 +64,7 @@ class ReplE2ESuite extends ConnectFunSuite with RemoteSparkSession with BeforeAn val args = Array("--port", serverPort.toString) val task = new Runnable { override def run(): Unit = { + System.setProperty("spark.sql.abc", "abc") ConnectRepl.doMain( args = args, semaphore = Some(semaphore), @@ -555,4 +556,13 @@ class ReplE2ESuite extends ConnectFunSuite with RemoteSparkSession with BeforeAn val output = runCommandsInShell(input) assertContains(": Long = 1045", output) } + + test("Simple configuration set in startup") { + val input = + """ + |spark.conf.get("spark.sql.abc") + """.stripMargin + val output = runCommandsInShell(input) + assertContains("abc", output) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index bbc396a937c3e..66a2c943af5f6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -30,6 +30,7 @@ import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto.AddArtifactsRequest +import org.apache.spark.sql.Artifact import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.IvyTestUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index a6be271b5c7d1..d9ff8d9122ead 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -21,6 +21,8 @@ import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import java.util.regex.Pattern +import scala.collection.mutable.{Set => MutableSet} + import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.lib.MiMaLib @@ -143,21 +145,6 @@ object CheckConnectJvmClientCompatibility { checkMiMaCompatibility(clientJar, protobufJar, includedRules, excludeRules) } - private lazy val mergeIntoWriterExcludeRules: Seq[ProblemFilter] = { - // Exclude some auto-generated methods in [[MergeIntoWriter]] classes. - // The incompatible changes are due to the uses of [[proto.Expression]] instead - // of [[catalyst.Expression]] in the method signature. - val classNames = Seq("WhenMatched", "WhenNotMatched", "WhenNotMatchedBySource") - val methodNames = Seq("apply", "condition", "copy", "copy$*", "unapply") - - classNames.flatMap { className => - methodNames.map { methodName => - ProblemFilters.exclude[IncompatibleSignatureProblem]( - s"org.apache.spark.sql.$className.$methodName") - } - } - } - private def checkMiMaCompatibilityWithSqlModule( clientJar: File, sqlJar: File): List[Problem] = { @@ -166,20 +153,23 @@ object CheckConnectJvmClientCompatibility { // Filter unsupported rules: // Note when muting errors for a method, checks on all overloading methods are also muted. 
+ // Skip any avro files + ProblemFilters.exclude[Problem]("org.apache.spark.sql.avro.*"), + // Skip unsupported packages ProblemFilters.exclude[Problem]("org.apache.spark.sql.api.*"), // Java, Python, R ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.expressions.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.scripting.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.sources.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.streaming.ui.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.test.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.util.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.scripting.*"), // Skip private[sql] constructors ProblemFilters.exclude[Problem]("org.apache.spark.sql.*.this"), @@ -194,12 +184,8 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.ExtendedExplainGenerator"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataSourceRegistration"), - // DataFrame Reader & Writer - ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"), // rdd - // DataFrameNaFunctions ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"), @@ -207,65 +193,23 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.Dataset$" // private[sql] ), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.ofRows"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_TAG"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"), + // TODO (SPARK-49096): + // Mima check might complain the following Dataset rules does not filter any problem. + // This is due to a potential bug in Mima that all methods in `class Dataset` are not being + // checked for problems due to the presence of a private[sql] companion object. + // Further investigation is needed. 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"), - - // functions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), - - // KeyValueGroupedDataset - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"), // RelationalGroupedDataset ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.RelationalGroupedDataset$*" // private[sql] ), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sqlContext"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.dataSource"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"), - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"), - - // SparkSession#implicits - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#implicits._sqlContext"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#implicits.session"), - - // SparkSession#Builder - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.config"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.withExtensions"), - - // RuntimeConfig - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"), + "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.RuntimeConfig.sqlConf"), + "org.apache.spark.sql.SparkSession.canUseSession"), // DataStreamWriter ProblemFilters.exclude[MissingClassProblem]( @@ -279,44 +223,52 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.TestGroupState"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.TestGroupState$"), + + // Artifact Manager, client has a totally different implementation. 
+ ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.artifact.ArtifactManager"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.PythonStreamingQueryListener"), + "org.apache.spark.sql.artifact.ArtifactManager$"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.PythonStreamingQueryListenerWrapper"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$Event"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryIdleEvent"), + "org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"), + + // ColumnNode conversions + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.Converter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Converter$"), + + // UDFRegistration + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.register"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.UDFRegistration"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.log*"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener#QueryIdleEvent.logEvent"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"), + "org.apache.spark.sql.UDFRegistration.LogStringContext"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgressEvent.logEvent"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"), + "org.apache.spark.sql.UDFRegistration.withLogContext"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.logEvent"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"), + "org.apache.spark.sql.UDFRegistration.isTraceEnabled"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminatedEvent.logEvent"), + "org.apache.spark.sql.UDFRegistration.initializeForcefully"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary$default$2"), - // SQLImplicits - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"), + // Protected DataFrameReader methods... 
+ ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateSingleVariantColumn"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateJsonSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateXmlSchema"), - // Artifact Manager - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.ArtifactManager"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.ArtifactManager$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.util.ArtifactUtils"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.util.ArtifactUtils$")) ++ - mergeIntoWriterExcludeRules + // Protected DataStreamReader methods... + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.DataStreamReader.validateJsonSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.DataStreamReader.validateXmlSchema")) checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules) } @@ -330,9 +282,9 @@ object CheckConnectJvmClientCompatibility { val includedRules = Seq(IncludeByName("org.apache.spark.sql.*")) val excludeRules = Seq( // Skipped packages - ProblemFilters.exclude[Problem]("org.apache.spark.sql.avro.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.application.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.protobuf.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), // private[sql] ProblemFilters.exclude[Problem]("org.apache.spark.sql.*.this"), @@ -341,45 +293,14 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.KeyValueGroupedDatasetImpl"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.KeyValueGroupedDatasetImpl$"), - ProblemFilters.exclude[ReversedMissingMethodProblem]( - "org.apache.spark.sql.SQLImplicits._sqlContext" // protected - ), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.SessionCleaner"), - - // private - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.CleanupTask"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.internal.CleanupTaskWeakReference"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.internal.CleanupCachedRemoteRelation"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.internal.CleanupCachedRemoteRelation$"), - // Catalyst Refactoring - ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils$"), + // ColumnNode conversions + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.sql.SparkSession#RichColumn.expr"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession#RichColumn.typedExpr"), // New public APIs added in the client - // ScalaUserDefinedFunction - ProblemFilters - .exclude[MissingClassProblem]( - "org.apache.spark.sql.expressions.ScalaUserDefinedFunction"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.expressions.ScalaUserDefinedFunction$"), - - // New private API added in the client - ProblemFilters - .exclude[MissingClassProblem]( - 
"org.apache.spark.sql.expressions.SparkConnectClosureCleaner"), - ProblemFilters - .exclude[MissingClassProblem]( - "org.apache.spark.sql.expressions.SparkConnectClosureCleaner$"), - - // Column - // developer API - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Column.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Column.expr"), - // Dataset ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.Dataset.plan" @@ -387,27 +308,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.Dataset.collectResult"), - // RuntimeConfig - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.RuntimeConfig" // Client version extends Logging - ), - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.RuntimeConfig.*" // Mute missing Logging methods - ), - // ConnectRepl - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.application.ConnectRepl" // developer API - ), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.application.ConnectRepl$" // developer API - ), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.application.ExtendedCodeClassWrapper" // developer API - ), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.application.ExtendedCodeClassWrapper$" // developer API - ), - // SparkSession // developer API ProblemFilters.exclude[DirectMissingMethodProblem]( @@ -417,37 +317,16 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.execute"), // Experimental - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addArtifact"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addArtifacts"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.registerClassFinder"), - // public - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptAll"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptOperation"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.removeTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.getTags"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.clearTags"), + ProblemFilters.exclude[IncompatibleSignatureProblem]( + "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), // SparkSession#Builder - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.remote"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.client"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.build" // deprecated ), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.create"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.interceptor"), @@ -455,24 +334,13 @@ object CheckConnectJvmClientCompatibility { 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"), // Steaming API - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.DataStreamWriter" // Client version extends Logging - ), - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.streaming.DataStreamWriter.*" // Mute missing Logging methods - ), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.RemoteStreamingQuery"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.RemoteStreamingQuery$"), // Skip client side listener specific class ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListenerBus"), - - // Encoders are in the wrong JAR - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) ++ - mergeIntoWriterExcludeRules + "org.apache.spark.sql.streaming.StreamingQueryListenerBus")) checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules) } @@ -490,13 +358,24 @@ object CheckConnectJvmClientCompatibility { excludeRules: Seq[ProblemFilter]): List[Problem] = { val mima = new MiMaLib(Seq(newJar, oldJar)) val allProblems = mima.collectProblems(oldJar, newJar, List.empty) + + val effectiveExcludeRules = MutableSet.empty[ProblemFilter] val problems = allProblems .filter { p => includedRules.exists(rule => rule(p)) } .filter { p => - excludeRules.forall(rule => rule(p)) + excludeRules.forall { rule => + val passedRule = rule(p) + if (!passedRule) { + effectiveExcludeRules += rule + } + passedRule + } } + excludeRules.filterNot(effectiveExcludeRules.contains).foreach { rule => + println(s"Warning: $rule did not filter out any problems.") + } problems } @@ -511,11 +390,14 @@ object CheckConnectJvmClientCompatibility { resultWriter.write( s"ERROR: Comparing Client jar: $clientModule and $targetName jar: $targetModule \n") resultWriter.write(s"problems with $targetName module: \n") - resultWriter.write(s"${problems.map(p => p.description(description)).mkString("\n")}") - resultWriter.write("\n") - resultWriter.write( - "Exceptions to binary compatibility can be added in " + - s"'CheckConnectJvmClientCompatibility#checkMiMaCompatibilityWith${targetName}Module'\n") + val problemDescriptions = + problems.map(p => s"${p.getClass.getSimpleName}: ${p.description(description)}") + resultWriter.write(problemDescriptions.mkString("\n")) + resultWriter.write("\n\n") + resultWriter.write("Exceptions to binary compatibility can be added in " + + s"'CheckConnectJvmClientCompatibility#checkMiMaCompatibilityWith${targetName}Module':\n") + resultWriter.write(problems.flatMap(_.howToFilter).distinct.mkString(",\n")) + resultWriter.write("\n\n") } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 46aeaeff43d2f..ac56600392aa3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -224,7 +224,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { val error = constructor(testParams).asInstanceOf[Throwable with SparkThrowable] assert(error.getMessage.contains(testParams.message)) 
assert(error.getCause == null) - assert(error.getErrorClass == testParams.errorClass.get) + assert(error.getCondition == testParams.errorClass.get) assert(error.getMessageParameters.asScala == testParams.messageParameters) assert(error.getQueryContext.isEmpty) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 709e2cf0e84ea..10e4c11c406fe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,11 +30,11 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.sql.{AnalysisException, Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.test.ConnectFunSuite -import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType} +import 
org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType} /** * Tests for encoding external data to and from arrow. @@ -769,6 +769,34 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } + test("java serialization") { + val encoder = agnosticEncoderFor(Encoders.javaSerialization[(Int, String)]) + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10)(i => (i, "itr_" + i)) + } + } + + test("kryo serialization") { + val e = intercept[SparkRuntimeException] { + val encoder = agnosticEncoderFor(Encoders.kryo[(Int, String)]) + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10)(i => (i, "itr_" + i)) + } + } + assert(e.getCondition == "CANNOT_USE_KRYO") + } + + test("transforming encoder") { + val schema = new StructType() + .add("key", IntegerType) + .add("value", StringType) + val encoder = + TransformingEncoder(classTag[(Int, String)], toRowEncoder(schema), () => new TestCodec) + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10)(i => (i, "v" + i)) + } + } + /* ******************************************************************** * * Arrow deserialization upcasting * ******************************************************************** */ @@ -1136,3 +1164,8 @@ class UDTNotSupported extends UserDefinedType[UDTNotSupportedClass] { case i: Int => UDTNotSupportedClass(i) } } + +class TestCodec extends Codec[(Int, String), Row] { + override def encode(in: (Int, String)): Row = Row(in._1, in._2) + override def decode(out: Row): (Int, String) = (out.getInt(0), out.getString(1)) +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala new file mode 100644 index 0000000000000..2efd396735191 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToProtoConverterSuite.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.internal + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Expression.Window.WindowFrame.FrameBoundary +import org.apache.spark.sql.{Column, Encoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoDataTypes} +import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator} +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.types.{BinaryType, DataType, DoubleType, LongType, MetadataBuilder, ShortType, StringType, StructType} + +/** + * Test suite for [[ColumnNode]] to [[proto.Expression]] conversions. + */ +class ColumnNodeToProtoConverterSuite extends ConnectFunSuite { + private def testConversion( + node: => ColumnNode, + expected: proto.Expression): proto.Expression = { + val expression = ColumnNodeToProtoConverter(node) + assert(expression == expected) + expression + } + + private def expr(f: proto.Expression.Builder => Unit): proto.Expression = { + val builder = proto.Expression.newBuilder() + f(builder) + builder.build() + } + + private def attribute(name: String): proto.Expression = + expr(_.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)) + + private def structField( + name: String, + dataType: proto.DataType, + nullable: Boolean = true): proto.DataType.StructField = { + proto.DataType.StructField + .newBuilder() + .setName(name) + .setDataType(dataType) + .setNullable(nullable) + .build() + } + + test("literal") { + testConversion(Literal(1), expr(_.getLiteralBuilder.setInteger(1).build())) + testConversion( + Literal("foo", Option(StringType)), + expr(_.getLiteralBuilder.setString("foo").build())) + val dataType = new StructType() + .add("_1", DoubleType) + .add("_2", StringType) + .add("_3", DoubleType) + .add("_4", StringType) + val stringTypeWithCollation = proto.DataType + .newBuilder() + .setString(proto.DataType.String.newBuilder().setCollation("UTF8_BINARY")) + .build() + testConversion( + Literal((12.0, "north", 60.0, "west"), Option(dataType)), + expr { b => + val builder = b.getLiteralBuilder.getStructBuilder + builder.getStructTypeBuilder.getStructBuilder + .addFields(structField("_1", ProtoDataTypes.DoubleType)) + .addFields(structField("_2", stringTypeWithCollation)) + .addFields(structField("_3", ProtoDataTypes.DoubleType)) + .addFields(structField("_4", stringTypeWithCollation)) + builder.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0)) + builder.addElements(proto.Expression.Literal.newBuilder().setString("north")) + builder.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0)) + builder.addElements(proto.Expression.Literal.newBuilder().setString("west")) + }) + } + + test("attribute") { + testConversion(UnresolvedAttribute("x"), attribute("x")) + testConversion( + UnresolvedAttribute("y", Option(44L), isMetadataColumn = true), + expr( + _.getUnresolvedAttributeBuilder + .setUnparsedIdentifier("y") + .setPlanId(44L) + .setIsMetadataColumn(true))) + } + + test("star") { + testConversion(UnresolvedStar(None), expr(_.getUnresolvedStarBuilder)) + testConversion( + UnresolvedStar(Option("x.y.z.*")), + expr(_.getUnresolvedStarBuilder.setUnparsedTarget("x.y.z.*"))) + testConversion( + UnresolvedStar(None, Option(10L)), + 
expr(_.getUnresolvedStarBuilder.setPlanId(10L))) + } + + test("regex") { + testConversion( + UnresolvedRegex("`(_1)?+.+`"), + expr(_.getUnresolvedRegexBuilder.setColName("`(_1)?+.+`"))) + testConversion( + UnresolvedRegex("a", planId = Option(11L)), + expr(_.getUnresolvedRegexBuilder.setColName("a").setPlanId(11L))) + } + + test("function") { + testConversion( + UnresolvedFunction("+", Seq(UnresolvedAttribute("a"), Literal(1))), + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("+") + .setIsDistinct(false) + .addArguments(attribute("a")) + .addArguments(expr(_.getLiteralBuilder.setInteger(1))))) + testConversion( + UnresolvedFunction( + "db1.myAgg", + Seq(UnresolvedAttribute("a")), + isDistinct = true, + isUserDefinedFunction = true), + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("db1.myAgg") + .setIsDistinct(true) + .setIsUserDefinedFunction(true) + .addArguments(attribute("a")))) + } + + test("alias") { + testConversion( + Alias(Literal("qwe"), "newA" :: Nil), + expr( + _.getAliasBuilder + .setExpr(expr(_.getLiteralBuilder.setString("qwe"))) + .addName("newA"))) + val metadata = new MetadataBuilder().putLong("q", 10).build() + testConversion( + Alias(UnresolvedAttribute("a"), "b" :: Nil, Option(metadata)), + expr( + _.getAliasBuilder + .setExpr(attribute("a")) + .addName("b") + .setMetadata("""{"q":10}"""))) + testConversion( + Alias(UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil), + expr( + _.getAliasBuilder + .setExpr(attribute("complex")) + .addName("newA") + .addName("newB"))) + } + + private def testCast( + dataType: DataType, + colEvalMode: Cast.EvalMode, + catEvalMode: proto.Expression.Cast.EvalMode): Unit = { + testConversion( + Cast(UnresolvedAttribute("attr"), dataType, Option(colEvalMode)), + expr( + _.getCastBuilder + .setExpr(attribute("attr")) + .setType(DataTypeProtoConverter.toConnectProtoType(dataType)) + .setEvalMode(catEvalMode))) + } + + test("cast") { + testConversion( + Cast(UnresolvedAttribute("str"), DoubleType), + expr( + _.getCastBuilder + .setExpr(attribute("str")) + .setType(ProtoDataTypes.DoubleType))) + + testCast(LongType, Cast.Legacy, proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY) + testCast(BinaryType, Cast.Try, proto.Expression.Cast.EvalMode.EVAL_MODE_TRY) + testCast(ShortType, Cast.Ansi, proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI) + } + + private def testSortOrder( + colDirection: SortOrder.SortDirection, + colNullOrdering: SortOrder.NullOrdering, + catDirection: proto.Expression.SortOrder.SortDirection, + catNullOrdering: proto.Expression.SortOrder.NullOrdering): Unit = { + testConversion( + SortOrder(UnresolvedAttribute("unsorted"), colDirection, colNullOrdering), + expr( + _.getSortOrderBuilder + .setChild(attribute("unsorted")) + .setNullOrdering(catNullOrdering) + .setDirection(catDirection))) + } + + test("sortOrder") { + testSortOrder( + SortOrder.Ascending, + SortOrder.NullsFirst, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST) + testSortOrder( + SortOrder.Ascending, + SortOrder.NullsLast, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST) + testSortOrder( + SortOrder.Descending, + SortOrder.NullsFirst, + proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST) + testSortOrder( + SortOrder.Descending, + SortOrder.NullsLast, + 
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING, + proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST) + } + + private def testWindowFrame( + colFrameType: WindowFrame.FrameType, + colLower: WindowFrame.FrameBoundary, + colUpper: WindowFrame.FrameBoundary, + catFrameType: proto.Expression.Window.WindowFrame.FrameType, + catLower: proto.Expression.Window.WindowFrame.FrameBoundary, + catUpper: proto.Expression.Window.WindowFrame.FrameBoundary): Unit = { + testConversion( + Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))), + WindowSpec( + Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Seq(SortOrder(UnresolvedAttribute("d"), SortOrder.Descending, SortOrder.NullsLast)), + Option(WindowFrame(colFrameType, colLower, colUpper)))), + expr( + _.getWindowBuilder + .setWindowFunction( + expr(_.getUnresolvedFunctionBuilder + .setFunctionName("sum") + .setIsDistinct(false) + .addArguments(attribute("a")))) + .addPartitionSpec(attribute("b")) + .addPartitionSpec(attribute("c")) + .addOrderSpec(proto.Expression.SortOrder + .newBuilder() + .setChild(attribute("d")) + .setDirection(proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING) + .setNullOrdering(proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST)) + .getFrameSpecBuilder + .setFrameType(catFrameType) + .setLower(catLower) + .setUpper(catUpper))) + } + + test("window") { + testConversion( + Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))), + WindowSpec(Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), Nil, None)), + expr( + _.getWindowBuilder + .setWindowFunction( + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("sum") + .setIsDistinct(false) + .addArguments(attribute("a")))) + .addPartitionSpec(attribute("b")) + .addPartitionSpec(attribute("c")))) + testWindowFrame( + WindowFrame.Row, + WindowFrame.Value(Literal(-10)), + WindowFrame.UnboundedFollowing, + proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_ROW, + FrameBoundary.newBuilder().setValue(expr(_.getLiteralBuilder.setInteger(-10))).build(), + FrameBoundary.newBuilder().setUnbounded(true).build()) + testWindowFrame( + WindowFrame.Range, + WindowFrame.UnboundedPreceding, + WindowFrame.CurrentRow, + proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_RANGE, + FrameBoundary.newBuilder().setUnbounded(true).build(), + FrameBoundary.newBuilder().setCurrentRow(true).build()) + } + + test("lambda") { + val colX = UnresolvedNamedLambdaVariable("x") + val catX = proto.Expression.UnresolvedNamedLambdaVariable + .newBuilder() + .addNameParts(colX.name) + .build() + testConversion( + LambdaFunction(UnresolvedFunction("+", Seq(colX, UnresolvedAttribute("y"))), Seq(colX)), + expr( + _.getLambdaFunctionBuilder + .setFunction( + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("+") + .addArguments(expr(_.setUnresolvedNamedLambdaVariable(catX))) + .addArguments(attribute("y")))) + .addArguments(catX))) + } + + test("sql") { + testConversion( + SqlExpression("1 + 1"), + expr(_.getExpressionStringBuilder.setExpression("1 + 1"))) + } + + test("caseWhen") { + testConversion( + CaseWhenOtherwise( + Seq(UnresolvedAttribute("c1") -> Literal("r1")), + Option(Literal("fallback"))), + expr( + _.getUnresolvedFunctionBuilder + .setFunctionName("when") + .addArguments(attribute("c1")) + .addArguments(expr(_.getLiteralBuilder.setString("r1"))) + .addArguments(expr(_.getLiteralBuilder.setString("fallback"))))) + } + + test("extract field") { + testConversion( + 
UnresolvedExtractValue(UnresolvedAttribute("struct"), Literal("cl_a")), + expr( + _.getUnresolvedExtractValueBuilder + .setChild(attribute("struct")) + .setExtraction(expr(_.getLiteralBuilder.setString("cl_a"))))) + } + + test("update field") { + testConversion( + UpdateFields(UnresolvedAttribute("struct"), "col_b", Option(Literal("cl_a"))), + expr( + _.getUpdateFieldsBuilder + .setStructExpression(attribute("struct")) + .setFieldName("col_b") + .setValueExpression(expr(_.getLiteralBuilder.setString("cl_a"))))) + + testConversion( + UpdateFields(UnresolvedAttribute("struct"), "col_c", None), + expr( + _.getUpdateFieldsBuilder + .setStructExpression(attribute("struct")) + .setFieldName("col_c"))) + } + + test("udf") { + val fn = (i: Int) => i + 1 + val udf = SparkUserDefinedFunction(fn, PrimitiveIntEncoder :: Nil, PrimitiveIntEncoder) + val named = udf.withName("boo").asNondeterministic() + testConversion( + InvokeInlineUserDefinedFunction(named, Seq(UnresolvedAttribute("a"))), + expr( + _.getCommonInlineUserDefinedFunctionBuilder + .setFunctionName("boo") + .setDeterministic(false) + .addArguments(attribute("a")) + .getScalarScalaUdfBuilder + .setPayload( + UdfToProtoUtils.toUdfPacketBytes(fn, PrimitiveIntEncoder :: Nil, PrimitiveIntEncoder)) + .addInputTypes(ProtoDataTypes.IntegerType) + .setOutputType(ProtoDataTypes.IntegerType) + .setNullable(false) + .setAggregate(false))) + + val aggregator = new Aggregator[Long, Long, Long] { + override def zero: Long = 0 + override def reduce(b: Long, a: Long): Long = a + b + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction + override def bufferEncoder: Encoder[Long] = PrimitiveLongEncoder + override def outputEncoder: Encoder[Long] = PrimitiveLongEncoder + } + val uda = UserDefinedAggregator(aggregator, PrimitiveLongEncoder) + .withName("lsum") + .asNonNullable() + testConversion( + InvokeInlineUserDefinedFunction(uda, Seq(UnresolvedAttribute(("a")))), + expr( + _.getCommonInlineUserDefinedFunctionBuilder + .setFunctionName("lsum") + .setDeterministic(true) + .addArguments(attribute("a")) + .getScalarScalaUdfBuilder + .setPayload(UdfToProtoUtils + .toUdfPacketBytes(aggregator, PrimitiveLongEncoder :: Nil, PrimitiveLongEncoder)) + .addInputTypes(ProtoDataTypes.LongType) + .setOutputType(ProtoDataTypes.LongType) + .setNullable(false) + .setAggregate(true))) + + val result = ColumnNodeToProtoConverter.toTypedExpr( + Column(InvokeInlineUserDefinedFunction(aggregator, Nil)), + PrimitiveLongEncoder) + val expected = expr( + _.getTypedAggregateExpressionBuilder.getScalarScalaUdfBuilder + .setPayload(UdfToProtoUtils + .toUdfPacketBytes(aggregator, PrimitiveLongEncoder :: Nil, PrimitiveLongEncoder)) + .addInputTypes(ProtoDataTypes.LongType) + .setOutputType(ProtoDataTypes.LongType) + .setNullable(true) + .setAggregate(true)) + assert(result == expected) + } + + test("extension") { + val e = attribute("name") + testConversion(ProtoColumnNode(e), e) + } + + test("unsupported") { + intercept[SparkException](ColumnNodeToProtoConverter(Nope())) + } +} + +private[internal] case class Nope(override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override def sql: String = "nope" +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index fbb2b4af1b128..b1a7d81916e92 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -42,6 +42,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L private val testDataPath = Paths .get( IntegrationTestUtils.sparkHome, + "sql", "connect", "common", "src", @@ -268,6 +269,42 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L } } + test("clusterBy") { + withSQLConf( + "spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers. + ) { + spark.sql("DROP TABLE IF EXISTS my_table").collect() + + withTempPath { ckpt => + val q1 = spark.readStream + .format("rate") + .load() + .writeStream + .clusterBy("value") + .option("checkpointLocation", ckpt.getCanonicalPath) + .toTable("my_table") + + try { + q1.processAllAvailable() + eventually(timeout(30.seconds)) { + checkAnswer( + spark.sql("DESCRIBE my_table"), + Seq( + Row("timestamp", "timestamp", null), + Row("value", "bigint", null), + Row("# Clustering Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("value", "bigint", null))) + assert(spark.table("my_sink").count() > 0) + } + } finally { + q1.stop() + spark.sql("DROP TABLE my_table") + } + } + } + } + test("throw exception in streaming") { try { val session = spark @@ -294,11 +331,9 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L query.awaitTermination() } - assert(exception.getErrorClass != null) + assert(exception.getCondition != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) - assert(!exception.getMessageParameters().get("startOffset").isEmpty) - assert(!exception.getMessageParameters().get("endOffset").isEmpty) assert(exception.getCause.isInstanceOf[SparkException]) assert(exception.getCause.getCause.isInstanceOf[SparkException]) assert( @@ -334,11 +369,9 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L spark.streams.awaitAnyTermination() } - assert(exception.getErrorClass != null) + assert(exception.getCondition != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) - assert(!exception.getMessageParameters().get("startOffset").isEmpty) - assert(!exception.getMessageParameters().get("endOffset").isEmpty) assert(exception.getCause.isInstanceOf[SparkException]) assert(exception.getCause.getCause.isInstanceOf[SparkException]) assert( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala index 608d9b477148c..f46b98646ae4f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala @@ -20,6 +20,10 @@ import java.nio.file.Path import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite +import org.apache.spark.connect.proto +import org.apache.spark.sql.Column +import org.apache.spark.sql.internal.ColumnNodeToProtoConverter + /** * The basic testsuite the client tests should extend from. 
*/ @@ -35,10 +39,19 @@ trait ConnectFunSuite extends AnyFunSuite { // scalastyle:ignore funsuite } protected def baseResourcePath: Path = { - getWorkspaceFilePath("connect", "client", "jvm", "src", "test", "resources").toAbsolutePath + getWorkspaceFilePath( + "sql", + "connect", + "client", + "jvm", + "src", + "test", + "resources").toAbsolutePath } protected def commonResourcePath: Path = { - getWorkspaceFilePath("connect", "common", "src", "test", "resources").toAbsolutePath + getWorkspaceFilePath("sql", "connect", "common", "src", "test", "resources").toAbsolutePath } + + protected def toExpr(c: Column): proto.Expression = ColumnNodeToProtoConverter.toExpr(c) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala index 95326ce7e8cec..36aaa2cc7fbf6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -24,6 +24,9 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.FiniteDuration import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.timeout +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkBuildInfo import org.apache.spark.sql.SparkSession @@ -60,7 +63,7 @@ object SparkConnectServerUtils { private lazy val sparkConnect: java.lang.Process = { debug("Starting the Spark Connect Server...") val connectJar = - findJar("connect/server", "spark-connect-assembly", "spark-connect").getCanonicalPath + findJar("sql/connect/server", "spark-connect-assembly", "spark-connect").getCanonicalPath val command = Seq.newBuilder[String] command += "bin/spark-submit" @@ -121,6 +124,8 @@ object SparkConnectServerUtils { // to make the tests exercise reattach. "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Testing SPARK-49673, setting maxBatchSize to 10MiB + s"spark.connect.grpc.arrow.maxBatchSize=${10 * 1024 * 1024}", // Disable UI "spark.ui.enabled=false") Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) @@ -184,12 +189,14 @@ object SparkConnectServerUtils { .port(port) .retryPolicy(RetryPolicy .defaultPolicy() - .copy(maxRetries = Some(7), maxBackoff = Some(FiniteDuration(10, "s")))) + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) .build()) .create() // Execute an RPC which will get retried until the server is up. - assert(spark.version == SparkBuildInfo.spark_version) + eventually(timeout(1.minute)) { + assert(spark.version == SparkBuildInfo.spark_version) + } // Auto-sync dependencies. SparkConnectServerUtils.syncTestDependencies(spark) diff --git a/connector/docker-integration-tests/README.md b/connector/docker-integration-tests/README.md index 03d3fe706a606..50000ae196107 100644 --- a/connector/docker-integration-tests/README.md +++ b/connector/docker-integration-tests/README.md @@ -45,7 +45,7 @@ the container bootstrapping. To run an individual Docker integration test, use t Besides the default Docker images, the integration tests can be run with custom Docker images. 
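(Editor's note, illustrative only: on the test side these image overrides are read from the environment with a hard-coded fallback, mirroring the sys.env.getOrElse pattern visible in the OracleDatabaseOnDocker and PostgresDatabaseOnDocker changes later in this patch; the class name below is hypothetical.)

    // minimal sketch of how a DatabaseOnDocker-style class picks up the override
    class ExampleOracleDatabaseOnDocker {
      val imageName: String =
        sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-free:23.5-slim")
    }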
For example, - ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-free:23.4-slim-faststart ./build/sbt -Pdocker-integration-tests "docker-integration-tests/testOnly *OracleIntegrationSuite" + ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-free:23.5-slim-faststart ./build/sbt -Pdocker-integration-tests "docker-integration-tests/testOnly *OracleIntegrationSuite" The following environment variables can be used to specify the custom Docker images for different databases: diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index 9003c2190be22..91d0c68c73158 100644 --- a/connector/docker-integration-tests/pom.xml +++ b/connector/docker-integration-tests/pom.xml @@ -98,7 +98,7 @@ com.oracle.database.jdbc - ojdbc11 + ojdbc17 test diff --git a/connector/docker-integration-tests/src/test/resources/db2_krb_setup.sh b/connector/docker-integration-tests/src/test/resources/db2-krb-setup.sh similarity index 100% rename from connector/docker-integration-tests/src/test/resources/db2_krb_setup.sh rename to connector/docker-integration-tests/src/test/resources/db2-krb-setup.sh diff --git a/connector/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh b/connector/docker-integration-tests/src/test/resources/mariadb-docker-entrypoint.sh similarity index 100% rename from connector/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh rename to connector/docker-integration-tests/src/test/resources/mariadb-docker-entrypoint.sh diff --git a/connector/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh b/connector/docker-integration-tests/src/test/resources/mariadb-krb-setup.sh similarity index 100% rename from connector/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh rename to connector/docker-integration-tests/src/test/resources/mariadb-krb-setup.sh diff --git a/connector/docker-integration-tests/src/test/resources/postgres_krb_setup.sh b/connector/docker-integration-tests/src/test/resources/postgres-krb-setup.sh similarity index 100% rename from connector/docker-integration-tests/src/test/resources/postgres_krb_setup.sh rename to connector/docker-integration-tests/src/test/resources/postgres-krb-setup.sh diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2DatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2DatabaseOnDocker.scala index 5cfd998b20a75..34db51ab0fd72 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2DatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2DatabaseOnDocker.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.jdbc class DB2DatabaseOnDocker extends DatabaseOnDocker { - override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.8.0") + override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", + "icr.io/db2_community/db2:11.5.9.0") override val env = Map( "DB2INST1_PASSWORD" -> "rootpass", "LICENSE" -> "accept", diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 72b2ac8074f4a..1d33acfdee013 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.types.{ByteType, ShortType, StructType} import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.8.0): + * To run this test suite for a specific version (e.g., icr.io/db2_community/db2:11.5.9.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.8.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=icr.io/db2_community/db2:11.5.9.0 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.DB2IntegrationSuite" * }}} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala index 4899de2b2a14c..c7d8fc43393ed 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.{DB2Connection import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.8.0): + * To run this test suite for a specific version (e.g., icr.io/db2_community/db2:11.5.9.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.8.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=icr.io/db2_community/db2:11.5.9.0 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *DB2KrbIntegrationSuite" * }}} @@ -57,7 +57,7 @@ class DB2KrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override def beforeContainerStart( hostConfigBuilder: HostConfig, containerConfigBuilder: ContainerConfig): Unit = { - copyExecutableResource("db2_krb_setup.sh", initDbDir, replaceIp) + copyExecutableResource("db2-krb-setup.sh", initDbDir, replaceIp) val newBind = new Bind( initDbDir.getAbsolutePath, diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index 8d17e0b4e36e6..1df01bd3bfb62 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -115,7 +115,7 @@ abstract class DockerJDBCIntegrationSuite protected val startContainerTimeout: Long = timeStringAsSeconds(sys.props.getOrElse("spark.test.docker.startContainerTimeout", "5min")) protected val connectionTimeout: PatienceConfiguration.Timeout = { - val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "5min") + val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "10min") timeout(timeStringAsSeconds(timeoutStr).seconds) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala index efb2fa09f6a3f..32c552eb8c7eb 100644 --- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mariadb:10.5.25): + * To run this test suite for a specific version (e.g., mariadb:10.6.19): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MARIADB_DOCKER_IMAGE_NAME=mariadb:10.5.25 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MARIADB_DOCKER_IMAGE_NAME=mariadb:10.6.19 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MariaDBKrbIntegrationSuite" * }}} @@ -38,7 +38,7 @@ class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "mariadb.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MARIADB_DOCKER_IMAGE_NAME", "mariadb:10.5.25") + override val imageName = sys.env.getOrElse("MARIADB_DOCKER_IMAGE_NAME", "mariadb:10.6.19") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) @@ -49,13 +49,13 @@ class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { s"jdbc:mysql://$ip:$port/mysql?user=$principal" override def getEntryPoint: Option[String] = - Some("/docker-entrypoint/mariadb_docker_entrypoint.sh") + Some("/docker-entrypoint/mariadb-docker-entrypoint.sh") override def beforeContainerStart( hostConfigBuilder: HostConfig, containerConfigBuilder: ContainerConfig): Unit = { - copyExecutableResource("mariadb_docker_entrypoint.sh", entryPointDir, replaceIp) - copyExecutableResource("mariadb_krb_setup.sh", initDbDir, replaceIp) + copyExecutableResource("mariadb-docker-entrypoint.sh", entryPointDir, replaceIp) + copyExecutableResource("mariadb-krb-setup.sh", initDbDir, replaceIp) val binds = Seq(entryPointDir -> "/docker-entrypoint", initDbDir -> "/docker-entrypoint-initdb.d") diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala index 61530f713eb86..6bd33356cab3d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc class MsSQLServerDatabaseOnDocker extends DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2022-CU12-GDR1-ubuntu-22.04") + "mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04") override val env = Map( "SA_PASSWORD" -> "Sapass123", "ACCEPT_EULA" -> "Y" diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 623f404339e9e..62f088ebc2b6d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -31,10 +31,10 @@ import 
org.apache.spark.sql.types.{BinaryType, DecimalType} import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2019-CU13-ubuntu-20.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} @@ -490,7 +490,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { .option("query", "SELECT @myvariant1 as variant1, @myvariant2 as variant2") .load() }, - errorClass = "UNRECOGNIZED_SQL_TYPE", + condition = "UNRECOGNIZED_SQL_TYPE", parameters = Map("typeName" -> "sql_variant", "jdbcType" -> "-156")) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala index 97f076eccf941..d5b61d479af1a 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLDatabaseOnDocker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc class MySQLDatabaseOnDocker extends DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:9.0.0") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:9.1.0") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index b0ba86a9f6c4f..deeb52267a145 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.types.ShortType import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mysql:9.0.0): + * To run this test suite for a specific version (e.g., mysql:9.1.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.0.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.1.0 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MySQLIntegrationSuite" * }}} @@ -351,13 +351,22 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { val df = spark.read.jdbc(jdbcUrl, "smallint_round_trip", new Properties) assert(df.schema.fields.head.dataType === ShortType) } + + test("SPARK-44638: Char/Varchar in Custom Schema") { + val df = spark.read.option("url", jdbcUrl) + .option("query", "SELECT c, d from strings") + .option("customSchema", "c CHAR(10), d VARCHAR(10)") + .format("jdbc") + .load() + checkAnswer(df, Row("brown ", "fox")) + } } /** - * To run this test suite for a specific version (e.g., mysql:9.0.0): + * To run this test suite for a specific version (e.g., mysql:9.1.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.0.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 
MYSQL_DOCKER_IMAGE_NAME=mysql:9.1.0 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *MySQLOverMariaConnectorIntegrationSuite" * }}} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala index dd6bbf0af8a33..87e83ff59a641 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala @@ -21,7 +21,7 @@ import org.apache.spark.internal.Logging class OracleDatabaseOnDocker extends DatabaseOnDocker with Logging { lazy override val imageName = - sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-free:23.4-slim") + sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-free:23.5-slim") val oracle_password = "Th1s1sThe0racle#Pass" override val env = Map( "ORACLE_PWD" -> oracle_password, // oracle images uses this diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 2b2596289548c..cf547b93aa0ba 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.execution.{RowDataSourceScanExec, WholeStageCodegenExec} -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.LogicalRelationWithTable import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -456,7 +456,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark .load() df1.logicalPlan match { - case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + case LogicalRelationWithTable(JDBCRelation(_, parts, _), _) => val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet assert(whereClauses === Set( """"D" < '2018-07-11' or "D" is null""", @@ -479,7 +479,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark .load() df2.logicalPlan match { - case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + case LogicalRelationWithTable(JDBCRelation(_, parts, _), _) => val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet assert(whereClauses === Set( """"T" < '2018-07-15 20:50:32.5' or "T" is null""", diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresDatabaseOnDocker.scala new file mode 100644 index 0000000000000..db2495ad3c698 --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresDatabaseOnDocker.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.internal.Logging + +class PostgresDatabaseOnDocker extends DatabaseOnDocker with Logging { + lazy override val imageName: String = + sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.2-alpine") + private val postgres_user = "postgres" + private val postgres_password = "rootpass" + override val env: Map[String, String] = Map( + "POSTGRES_PASSWORD" -> postgres_password + ) + override val usesIpc = false + override val jdbcPort: Int = 5432 + + override def getJdbcUrl(ip: String, port: Int): String = { + s"jdbc:postgresql://$ip:$port/postgres?user=$postgres_user&password=$postgres_password" + } +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 12a71dbd7c7f8..5c985da226b06 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -24,33 +24,24 @@ import java.time.LocalDateTime import java.util.Properties import org.apache.spark.SparkException -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.3-alpine): + * To run this test suite for a specific version (e.g., postgres:17.2-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.2-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} */ @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") - override val env = Map( - "POSTGRES_PASSWORD" -> "rootpass" - ) - override val usesIpc = false - override val jdbcPort = 5432 - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" - } + override val db = new PostgresDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() @@ -303,7 +294,7 @@ class 
PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { ArrayType(DecimalType(2, 2), true)) // Test write null values. df.select(df.queryExecution.analyzed.output.map { a => - Column(Literal.create(null, a.dataType)).as(a.name) + lit(null).cast(a.dataType).as(a.name) }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index af1cd464ad5fe..b3cfe8bd77e2b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.3-alpine): + * To run this test suite for a specific version (e.g., postgres:17.2-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.2-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *PostgresKrbIntegrationSuite" * }}} @@ -37,21 +37,14 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val userName = s"postgres/$dockerIp" override protected val keytabFileName = "postgres.keytab" - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") - override val env = Map( - "POSTGRES_PASSWORD" -> "rootpass" - ) - override val usesIpc = false - override val jdbcPort = 5432 - + override val db = new PostgresDatabaseOnDocker { override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:postgresql://$ip:$port/postgres?user=$principal&gsslib=gssapi" override def beforeContainerStart( hostConfigBuilder: HostConfig, containerConfigBuilder: ContainerConfig): Unit = { - copyExecutableResource("postgres_krb_setup.sh", initDbDir, replaceIp) + copyExecutableResource("postgres-krb-setup.sh", initDbDir, replaceIp) val newBind = new Bind( initDbDir.getAbsolutePath, new Volume("/docker-entrypoint-initdb.d"), diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala index b526599482daf..3a1d5e18b7e5a 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.tags.DockerTest /** * This suite is used to generate subqueries, and test Spark against Postgres. 
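The new `PostgresDatabaseOnDocker` helper above centralizes the Postgres image name, credentials, port, and JDBC URL that the individual suites previously redeclared as anonymous `DatabaseOnDocker` instances. A minimal sketch of how an additional suite could reuse it, overriding only what differs (the suite name and the extra URL parameter are hypothetical, not part of this change):

```scala
// Hypothetical suite; mirrors how PostgresKrbIntegrationSuite reuses the shared helper above.
@DockerTest
class ExamplePostgresSuite extends DockerJDBCIntegrationSuite {
  // Image, env, usesIpc and jdbcPort all come from PostgresDatabaseOnDocker.
  override val db = new PostgresDatabaseOnDocker {
    // Override only the JDBC URL; the default already embeds user and password.
    override def getJdbcUrl(ip: String, port: Int): String =
      s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass&ApplicationName=ci"
  }

  override def dataPreparation(conn: java.sql.Connection): Unit = {
    conn.prepareStatement("CREATE TABLE example (id INT)").executeUpdate()
  }
}
```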
- * To run this test suite for a specific version (e.g., postgres:16.3-alpine): + * To run this test suite for a specific version (e.g., postgres:17.2-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.2-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.GeneratedSubquerySuite" * }}} @@ -38,16 +38,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class GeneratedSubquerySuite extends DockerJDBCIntegrationSuite with QueryGeneratorHelper { - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") - override val env = Map( - "POSTGRES_PASSWORD" -> "rootpass" - ) - override val usesIpc = false - override val jdbcPort = 5432 - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" - } + override val db = new PostgresDatabaseOnDocker private val FIRST_COLUMN = "a" private val SECOND_COLUMN = "b" @@ -145,10 +136,6 @@ class GeneratedSubquerySuite extends DockerJDBCIntegrationSuite with QueryGenera None } - // SPARK-46446: offset operator in correlated subquery is not supported - // as it creates incorrect results for now. - val requireNoOffsetInCorrelatedSubquery = correlationConditions.nonEmpty - // For the Limit clause, consider whether the subquery needs to return 1 row, or whether the // operator to be included is a Limit. val limitAndOffsetClause = if (requiresExactlyOneRowOutput) { @@ -156,11 +143,10 @@ class GeneratedSubquerySuite extends DockerJDBCIntegrationSuite with QueryGenera } else { operatorInSubquery match { case lo: LimitAndOffset => - val offsetValue = if (requireNoOffsetInCorrelatedSubquery) 0 else lo.offsetValue - if (offsetValue == 0 && lo.limitValue == 0) { + if (lo.offsetValue == 0 && lo.limitValue == 0) { None } else { - Some(LimitAndOffset(lo.limitValue, offsetValue)) + Some(lo) } case _ => None } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgresSQLQueryTestSuite.scala similarity index 82% rename from connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala rename to connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgresSQLQueryTestSuite.scala index de28e16b325ce..28320a9e0a949 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgresSQLQueryTestSuite.scala @@ -30,31 +30,21 @@ import org.apache.spark.tags.DockerTest * confidence, and you won't have to manually verify the golden files generated with your test. * 2. 
Add this line to your .sql file: --ONLY_IF spark * - * Note: To run this test suite for a specific version (e.g., postgres:16.3-alpine): + * Note: To run this test suite for a specific version (e.g., postgres:17.2-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.2-alpine * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgreSQLQueryTestSuite" * }}} */ @DockerTest -class PostgreSQLQueryTestSuite extends CrossDbmsQueryTestSuite { +class PostgresSQLQueryTestSuite extends CrossDbmsQueryTestSuite { val DATABASE_NAME = CrossDbmsQueryTestSuite.POSTGRES // Scope to only subquery directory for now. protected val customInputFilePath: String = new File(inputFilePath, "subquery").getAbsolutePath - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") - override val env = Map( - "POSTGRES_PASSWORD" -> "rootpass" - ) - override val usesIpc = false - override val jdbcPort = 5432 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" - } + override val db = new PostgresDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement( diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 57129e9d846f6..91a82075a3607 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.8.0): + * To run this test suite for a specific version (e.g., icr.io/db2_community/db2:11.5.9.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.8.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=icr.io/db2_community/db2:11.5.9.0 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.DB2IntegrationSuite" * }}} */ @@ -87,7 +87,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"DOUBLE\"", "newType" -> "\"VARCHAR(10)\"", diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala index 528b26e61e761..385039fb6bd51 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.8.0): + * To run this test suite for a specific version (e.g., icr.io/db2_community/db2:11.5.9.0): * {{{ - * 
ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.8.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=icr.io/db2_community/db2:11.5.9.0 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.DB2NamespaceSuite" * }}} */ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 9ddd79fb257d8..fd7efb1efb764 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -20,23 +20,38 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2019-CU13-ubuntu-20.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MsSqlServerIntegrationSuite" * }}} */ @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + def getExternalEngineQuery(executedPlan: SparkPlan): String = { + getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery + } + + def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = { + val queryNode = executedPlan.collect { case r: RowDataSourceScanExec => + r + }.head + queryNode.rdd + } + override def excluded: Seq[String] = Seq( "simple scan with OFFSET", "simple scan with LIMIT and OFFSET", @@ -97,7 +112,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -115,7 +130,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - errorClass = "_LEGACY_ERROR_TEMP_2271") + condition = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY") } test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { @@ -146,4 +161,68 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |""".stripMargin) assert(df.collect().length == 2) } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END + 
|""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END + | ELSE (name = 'Sauron') END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END + | ELSE 'Sauron' END = name + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """ + ) + // scalastyle:on + df.collect() + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index e010a0caf13fa..724c394a4f052 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2019-CU13-ubuntu-20.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" * }}} */ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala 
index c48505ff153fd..9c07ce939fe3e 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mysql:9.0.0): + * To run this test suite for a specific version (e.g., mysql:9.1.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.0.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.1.0 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLIntegrationSuite" * }}} */ @@ -77,8 +77,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest s"""CREATE TABLE pattern_testing_table ( |pattern_testing_col LONGTEXT |) - """.stripMargin + |""".stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -98,7 +109,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -131,7 +142,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - errorClass = "_LEGACY_ERROR_TEMP_2271") + condition = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY") } override def testCreateTableWithProperty(tbl: String): Unit = { @@ -157,12 +168,85 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 65536) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name 
FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = df8.collect() + assert(rows8.length === 1) + assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // MySQL does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } /** - * To run this test suite for a specific version (e.g., mysql:9.0.0): + * To run this test suite for a specific version (e.g., mysql:9.1.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.0.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.1.0 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *MySQLOverMariaConnectorIntegrationSuite" * }}} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index b49f6901034b1..5db1b9f863adb 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mysql:9.0.0): + * To run this test suite for a specific version (e.g., mysql:9.1.0): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.0.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:9.1.0 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLNamespaceSuite" * }}} */ @@ -62,7 +62,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) }, - errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) assert(catalog.namespaceExists(Array("foo")) === false) @@ -74,7 +74,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac Array("foo"), NamespaceChange.setProperty("comment", "comment for foo")) }, - errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", 
+ condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) @@ -82,7 +82,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) }, - errorClass = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", + condition = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", parameters = Map("namespace" -> "`foo`") ) @@ -90,7 +90,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.dropNamespace(Array("foo"), cascade = false) }, - errorClass = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) catalog.dropNamespace(Array("foo"), cascade = true) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 342fb4bb38e60..2c97a588670a8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -118,7 +118,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"DECIMAL(19,0)\"", "newType" -> "\"INT\"", @@ -139,7 +139,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes exception = intercept[SparkRuntimeException] { sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") }, - errorClass = "EXCEED_LIMIT_LENGTH", + condition = "EXCEED_LIMIT_LENGTH", parameters = Map("limit" -> "255") ) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 7c439d449d86f..ef52aebd723a6 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -19,34 +19,25 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkSQLException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.DatabaseOnDocker +import org.apache.spark.sql.jdbc.PostgresDatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.3-alpine) + * To run this test suite for a specific version (e.g., postgres:17.2-alpine) * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.2-alpine * ./build/sbt -Pdocker-integration-tests "testOnly 
*v2.PostgresIntegrationSuite" * }}} */ @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") - override val env = Map( - "POSTGRES_PASSWORD" -> "rootpass" - ) - override val usesIpc = false - override val jdbcPort = 5432 - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" - } + override val db = new PostgresDatabaseOnDocker override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) @@ -65,6 +56,134 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT |) """.stripMargin ).executeUpdate() + + connection.prepareStatement("CREATE TABLE array_test_table (int_array int[]," + + "float_array FLOAT8[], timestamp_array TIMESTAMP[], string_array TEXT[]," + + "datetime_array TIMESTAMPTZ[], array_of_int_arrays INT[][])").executeUpdate() + + val query = + """ + INSERT INTO array_test_table + (int_array, float_array, timestamp_array, string_array, + datetime_array, array_of_int_arrays) + VALUES + ( + ARRAY[1, 2, 3], -- Array of integers + ARRAY[1.1, 2.2, 3.3], -- Array of floats + ARRAY['2023-01-01 12:00'::timestamp, '2023-06-01 08:30'::timestamp], + ARRAY['hello', 'world'], -- Array of strings + ARRAY['2023-10-04 12:00:00+00'::timestamptz, + '2023-12-01 14:15:00+00'::timestamptz], + ARRAY[ARRAY[1, 2]] -- Array of arrays of integers + ), + ( + ARRAY[10, 20, 30], -- Another set of data + ARRAY[10.5, 20.5, 30.5], + ARRAY['2022-01-01 09:15'::timestamp, '2022-03-15 07:45'::timestamp], + ARRAY['postgres', 'arrays'], + ARRAY['2022-11-22 09:00:00+00'::timestamptz, + '2022-12-31 23:59:59+00'::timestamptz], + ARRAY[ARRAY[10, 20]] + ); + """ + connection.prepareStatement(query).executeUpdate() + + connection.prepareStatement("CREATE TABLE array_int (col int[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_bigint(col bigint[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_smallint (col smallint[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_boolean (col boolean[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_float (col real[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_double (col float8[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_timestamp (col timestamp[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_timestamptz (col timestamptz[])") + .executeUpdate() + + connection.prepareStatement("INSERT INTO array_int VALUES (array[10]), (array[array[10]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_bigint VALUES (array[10]), " + + "(array[array[10]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_smallint VALUES (array[10]), " + + "(array[array[10]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_boolean VALUES (array[true]), " + + "(array[array[true]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_float VALUES (array[10.5]), " + + "(array[array[10.5]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_double VALUES (array[10.1]), " + + "(array[array[10.1]])").executeUpdate() + 
connection.prepareStatement("INSERT INTO array_timestamp VALUES " + + "(array['2022-01-01 09:15'::timestamp]), " + + "(array[array['2022-01-01 09:15'::timestamp]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_timestamptz VALUES " + + "(array['2022-01-01 09:15'::timestamptz]), " + + "(array[array['2022-01-01 09:15'::timestamptz]])").executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + + connection.prepareStatement("CREATE TABLE array_of_int (col int[])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_of_int " + + "VALUES (array[1])").executeUpdate() + connection.prepareStatement("CREATE TABLE ctas_array_of_int " + + "AS SELECT * FROM array_of_int").executeUpdate() + + connection.prepareStatement("CREATE TABLE array_of_array_of_int (col int[][])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_of_array_of_int " + + "VALUES (array[array[1],array[2]])").executeUpdate() + connection.prepareStatement("CREATE TABLE ctas_array_of_array_of_int " + + "AS SELECT * FROM array_of_array_of_int").executeUpdate() + + connection.prepareStatement("CREATE TABLE unsupported_array_of_array_of_int (col int[][])") + .executeUpdate() + connection.prepareStatement("INSERT INTO unsupported_array_of_array_of_int " + + "VALUES (array[array[1],array[2]]), (array[3])").executeUpdate() + } + + test("Test multi-dimensional column types") { + // This test is used to verify that the multi-dimensional + // column types are supported by the JDBC V2 data source. + // We do not verify any result output + // + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "array_test_table") + .load() + df.collect() + + val array_tables = Array( + ("array_int", "\"ARRAY\""), + ("array_bigint", "\"ARRAY\""), + ("array_smallint", "\"ARRAY\""), + ("array_boolean", "\"ARRAY\""), + ("array_float", "\"ARRAY\""), + ("array_double", "\"ARRAY\""), + ("array_timestamp", "\"ARRAY\""), + ("array_timestamptz", "\"ARRAY\"") + ) + + array_tables.foreach { case (dbtable, arrayType) => + checkError( + exception = intercept[SparkSQLException] { + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", dbtable) + .load() + df.collect() + }, + condition = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + parameters = Map("pos" -> "0", "type" -> arrayType), + sqlState = Some("0A000") + ) + } + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -84,7 +203,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -118,9 +237,112 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT sql(s"CREATE TABLE $t2(c int)") checkError( exception = intercept[TableAlreadyExistsException](sql(s"ALTER TABLE $t1 RENAME TO t2")), - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = 
Map("relationName" -> "`t2`") ) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6, false) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = df8.collect() + assert(rows8.length === 1) + assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // Postgres does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } + + test("Test reading 2d array from table created via CTAS command - positive test") { + val dfNoCTASTable = sql(s"SELECT * FROM $catalogName.array_of_int") + val dfWithCTASTable = sql(s"SELECT * FROM $catalogName.ctas_array_of_int") + + checkAnswer(dfWithCTASTable, dfNoCTASTable.collect()) + } + + test("Test reading 2d array from table created via CTAS command - negative test") { + val dfNoCTASTable = sql(s"SELECT * FROM $catalogName.array_of_int") + + checkError( + exception = intercept[org.apache.spark.SparkSQLException] { + // This should fail as only 1D CTAS tables are supported + sql(s"SELECT * FROM $catalogName.ctas_array_of_array_of_int").collect() + }, + condition = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + parameters = Map("pos" -> "0", "type" -> "\"ARRAY\"") + ) + } + + test("Test reading multiple dimension array from table created via CTAS command") { + checkError( + exception = 
intercept[org.apache.spark.SparkSQLException] { + sql(s"SELECT * FROM $catalogName.unsupported_array_of_array_of_int").collect() + }, + condition = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + parameters = Map("pos" -> "0", "type" -> "\"ARRAY>\"") + ) + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index 8a2d0ded84381..f84bdb46850f2 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -21,29 +21,20 @@ import java.sql.Connection import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, PostgresDatabaseOnDocker} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.3-alpine): + * To run this test suite for a specific version (e.g., postgres:17.2-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.3-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.2-alpine * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.3-alpine") - override val env = Map( - "POSTGRES_PASSWORD" -> "rootpass" - ) - override val usesIpc = false - override val jdbcPort = 5432 - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" - } + override val db = new PostgresDatabaseOnDocker val map = new CaseInsensitiveStringMap( Map("url" -> db.getJdbcUrl(dockerIp, externalPort), diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index e4cc88cec0f5e..3b1a457214be7 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -92,7 +92,7 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte catalog.listNamespaces(Array("foo")) } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`foo`")) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 88ba00a8a1aea..54635f69f8b65 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -66,10 +66,17 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu expectedSchema = new StructType().add("ID", StringType, 
true, defaultMetadata()) assert(t.schema === expectedSchema) // Update nullability of not existing column - val msg = intercept[AnalysisException] { - sql(s"ALTER TABLE $catalogName.alt_table ALTER COLUMN bad_column DROP NOT NULL") - }.getMessage - assert(msg.contains("Missing field bad_column")) + val sqlText = s"ALTER TABLE $catalogName.alt_table ALTER COLUMN bad_column DROP NOT NULL" + checkError( + exception = intercept[AnalysisException] { + sql(sqlText) + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map( + "objectName" -> "`bad_column`", + "proposal" -> "`ID`"), + context = ExpectedContext(fragment = sqlText, start = 0, stop = sqlText.length -1)) } def testRenameColumn(tbl: String): Unit = { @@ -85,11 +92,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu private def checkErrorFailedJDBC( e: AnalysisException, - errorClass: String, + condition: String, tbl: String): Unit = { checkErrorMatchPVals( exception = e, - errorClass = errorClass, + condition = condition, parameters = Map( "url" -> "jdbc:.*", "tableName" -> s"`$tbl`") @@ -119,7 +126,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 DOUBLE)") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "add", "fieldNames" -> "`C3`", @@ -147,10 +154,17 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu .add("C2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Drop not existing column - val msg = intercept[AnalysisException] { - sql(s"ALTER TABLE $catalogName.alt_table DROP COLUMN bad_column") - }.getMessage - assert(msg.contains(s"Missing field bad_column in table $catalogName.alt_table")) + val sqlText = s"ALTER TABLE $catalogName.alt_table DROP COLUMN bad_column" + checkError( + exception = intercept[AnalysisException] { + sql(sqlText) + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map( + "objectName" -> "`bad_column`", + "proposal" -> "`C2`"), + context = ExpectedContext(fragment = sqlText, start = 0, stop = sqlText.length -1)) } // Drop a column from a not existing table val e = intercept[AnalysisException] { @@ -163,10 +177,17 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu withTable(s"$catalogName.alt_table") { testUpdateColumnType(s"$catalogName.alt_table") // Update not existing column - val msg2 = intercept[AnalysisException] { - sql(s"ALTER TABLE $catalogName.alt_table ALTER COLUMN bad_column TYPE DOUBLE") - }.getMessage - assert(msg2.contains("Missing field bad_column")) + val sqlText = s"ALTER TABLE $catalogName.alt_table ALTER COLUMN bad_column TYPE DOUBLE" + checkError( + exception = intercept[AnalysisException] { + sql(sqlText) + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map( + "objectName" -> "`bad_column`", + "proposal" -> "`ID`"), + context = ExpectedContext(fragment = sqlText, start = 0, stop = sqlText.length -1)) } // Update column type in not existing table val e = intercept[AnalysisException] { @@ -185,7 +206,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table RENAME COLUMN ID1 TO ID2") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = 
"FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> "`ID2`", @@ -287,7 +308,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[IndexAlreadyExistsException] { sql(s"CREATE index i1 ON $catalogName.new_table (col1)") }, - errorClass = "INDEX_ALREADY_EXISTS", + condition = "INDEX_ALREADY_EXISTS", parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`") ) @@ -312,7 +333,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[NoSuchIndexException] { sql(s"DROP index i1 ON $catalogName.new_table") }, - errorClass = "INDEX_NOT_FOUND", + condition = "INDEX_NOT_FOUND", parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`") ) } @@ -332,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } @@ -944,4 +965,25 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(2).getDouble(0) === 0.0) } } + + test("SPARK-48618: Renaming the table to the name of an existing table") { + withTable(s"$catalogName.tbl1", s"$catalogName.tbl2") { + sql(s"CREATE TABLE $catalogName.tbl1 (col1 INT, col2 INT)") + sql(s"CREATE TABLE $catalogName.tbl2 (col3 INT, col4 INT)") + + checkError( + exception = intercept[AnalysisException] { + sql(s"ALTER TABLE $catalogName.tbl2 RENAME TO tbl1") + }, + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", + parameters = Map("relationName" -> "`tbl1`") + ) + } + } + + def testDatetime(tbl: String): Unit = {} + + test("scan with filter push-down with date time functions") { + testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}") + } } diff --git a/connector/docker/README.md b/connector/docker/README.md deleted file mode 100644 index 40ba9c3065946..0000000000000 --- a/connector/docker/README.md +++ /dev/null @@ -1,7 +0,0 @@ -Spark docker files -=========== - -Drawn from Matt Massie's docker files (https://github.com/massie/dockerfiles), -as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). - -Tested with Docker version 0.8.1. diff --git a/connector/docker/spark-test/README.md b/connector/docker/spark-test/README.md deleted file mode 100644 index ec0baf6e6d419..0000000000000 --- a/connector/docker/spark-test/README.md +++ /dev/null @@ -1,11 +0,0 @@ -Spark Docker files usable for testing and development purposes. - -These images are intended to be run like so: - - docker run -v $SPARK_HOME:/opt/spark spark-test-master - docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://:7077 - -Using this configuration, the containers will have their Spark directories -mounted to your actual `SPARK_HOME`, allowing you to modify and recompile -your Spark source and have them immediately usable in the docker images -(without rebuilding them). 
diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index 35f58134f1a85..66e1c24e821c8 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -148,6 +148,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala index 13a68e72269f0..c4adb6b3f26e1 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala @@ -184,5 +184,5 @@ private[kafka010] class KafkaIllegalStateException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index 4c0620a35cc21..ae3c50f82e2d5 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -21,22 +21,26 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.util.CaseInsensitiveStringMap - /** - * Class to calculate offset ranges to process based on the from and until offsets, and - * the configured `minPartitions`. + * Class to calculate offset ranges to process based on the from and until offsets, and the + * configured `minPartitions` and `maxRecordsPerPartition`. */ -private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { +private[kafka010] class KafkaOffsetRangeCalculator( + val minPartitions: Option[Int], + val maxRecordsPerPartition: Option[Long]) { require(minPartitions.isEmpty || minPartitions.get > 0) + require(maxRecordsPerPartition.isEmpty || maxRecordsPerPartition.get > 0) /** - * Calculate the offset ranges that we are going to process this batch. If `minPartitions` - * is not set or is set less than or equal the number of `topicPartitions` that we're going to - * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If - * `minPartitions` is set higher than the number of our `topicPartitions`, then we will split up - * the read tasks of the skewed partitions to multiple Spark tasks. - * The number of Spark tasks will be *approximately* `minPartitions`. It can be less or more - * depending on rounding errors or Kafka partitions that didn't receive any new data. + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` is + * not set or is set less than or equal the number of `topicPartitions` that we're going to + * consume and, `maxRecordsPerPartition` is not set then we fall back to a 1-1 mapping of Spark + * tasks to Kafka partitions. If `maxRecordsPerPartition` is set, then we will split up read + * task to multiple tasks as per `maxRecordsPerPartition` value. 
If `minPartitions` is set + * higher than the number of our `topicPartitions`, then we will split up the read tasks of the + * skewed partitions to multiple Spark tasks. The number of Spark tasks will be *approximately* + * max of `(recordsPerPartition/maxRecordsPerPartition)` and `minPartitions`. It can be less or + * more depending on rounding errors or Kafka partitions that didn't receive any new data. * * Empty (`KafkaOffsetRange.size == 0`) or invalid (`KafkaOffsetRange.size < 0`) ranges will be * dropped. @@ -47,51 +51,81 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int val offsetRanges = ranges.filter(_.size > 0) // If minPartitions not set or there are enough partitions to satisfy minPartitions - if (minPartitions.isEmpty || offsetRanges.size >= minPartitions.get) { + // and maxRecordsPerPartition is empty + if ((minPartitions.isEmpty || offsetRanges.size >= minPartitions.get) + && maxRecordsPerPartition.isEmpty) { // Assign preferred executor locations to each range such that the same topic-partition is // preferentially read from the same executor and the KafkaConsumer can be reused. offsetRanges.map { range => range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) } } else { + val dividedOffsetRanges = if (maxRecordsPerPartition.isDefined) { + val maxRecords = maxRecordsPerPartition.get + offsetRanges + .flatMap { range => + val size = range.size + // number of partitions to divvy up this topic partition to + val parts = math.ceil(size.toDouble / maxRecords).toInt + getDividedPartition(parts, range) + } + .filter(_.size > 0) + } else { + offsetRanges + } - // Splits offset ranges with relatively large amount of data to smaller ones. - val totalSize = offsetRanges.map(_.size).sum + if (minPartitions.isDefined && minPartitions.get > dividedOffsetRanges.size) { + // Splits offset ranges with relatively large amount of data to smaller ones. + val totalSize = dividedOffsetRanges.map(_.size).sum + + // First distinguish between any small (i.e. unsplit) ranges and large (i.e. split) ranges, + // in order to exclude the contents of unsplit ranges from the proportional math applied to + // split ranges + val unsplitRanges = dividedOffsetRanges.filter { range => + getPartCount(range.size, totalSize, minPartitions.get) == 1 + } - // First distinguish between any small (i.e. unsplit) ranges and large (i.e. 
split) ranges, - // in order to exclude the contents of unsplit ranges from the proportional math applied to - // split ranges - val unsplitRanges = offsetRanges.filter { range => - getPartCount(range.size, totalSize, minPartitions.get) == 1 + val unsplitRangeTotalSize = unsplitRanges.map(_.size).sum + val splitRangeTotalSize = totalSize - unsplitRangeTotalSize + val unsplitRangeTopicPartitions = unsplitRanges.map(_.topicPartition).toSet + val splitRangeMinPartitions = math.max(minPartitions.get - unsplitRanges.size, 1) + + // Now we can apply the main calculation logic + dividedOffsetRanges + .flatMap { range => + val tp = range.topicPartition + val size = range.size + // number of partitions to divvy up this topic partition to + val parts = if (unsplitRangeTopicPartitions.contains(tp)) { + 1 + } else { + getPartCount(size, splitRangeTotalSize, splitRangeMinPartitions) + } + getDividedPartition(parts, range) + } + .filter(_.size > 0) + } else { + dividedOffsetRanges } + } + } - val unsplitRangeTotalSize = unsplitRanges.map(_.size).sum - val splitRangeTotalSize = totalSize - unsplitRangeTotalSize - val unsplitRangeTopicPartitions = unsplitRanges.map(_.topicPartition).toSet - val splitRangeMinPartitions = math.max(minPartitions.get - unsplitRanges.size, 1) - - // Now we can apply the main calculation logic - offsetRanges.flatMap { range => - val tp = range.topicPartition - val size = range.size - // number of partitions to divvy up this topic partition to - val parts = if (unsplitRangeTopicPartitions.contains(tp)) { - 1 - } else { - getPartCount(size, splitRangeTotalSize, splitRangeMinPartitions) - } - var remaining = size - var startOffset = range.fromOffset - (0 until parts).map { part => - // Fine to do integer division. Last partition will consume all the round off errors - val thisPartition = remaining / (parts - part) - remaining -= thisPartition - val endOffset = math.min(startOffset + thisPartition, range.untilOffset) - val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None) - startOffset = endOffset - offsetRange - } - }.filter(_.size > 0) + private def getDividedPartition( + parts: Int, + offsetRange: KafkaOffsetRange): IndexedSeq[KafkaOffsetRange] = { + var remaining = offsetRange.size + var startOffset = offsetRange.fromOffset + val tp = offsetRange.topicPartition + val untilOffset = offsetRange.untilOffset + + (0 until parts).map { part => + // Fine to do integer division. 
Last partition will consume all the round off errors + val thisPartition = remaining / (parts - part) + remaining -= thisPartition + val endOffset = math.min(startOffset + thisPartition, untilOffset) + val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None) + startOffset = endOffset + offsetRange } } @@ -114,9 +148,12 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int private[kafka010] object KafkaOffsetRangeCalculator { def apply(options: CaseInsensitiveStringMap): KafkaOffsetRangeCalculator = { - val optionalValue = Option(options.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) + val minPartition = Option(options.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) .map(_.toInt) - new KafkaOffsetRangeCalculator(optionalValue) + val maxRecordsPerPartition = + Option(options.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY)) + .map(_.toLong) + new KafkaOffsetRangeCalculator(minPartition, maxRecordsPerPartition) } } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala index bb4f14686f976..0bdd931028aef 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala @@ -99,14 +99,18 @@ private[kafka010] class KafkaOffsetReaderAdmin( */ private val minPartitions = readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt) + private val maxRecordsPerPartition = + readerOptions.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY).map(_.toLong) - private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions) + private val rangeCalculator = + new KafkaOffsetRangeCalculator(minPartitions, maxRecordsPerPartition) /** * Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks. 
*/ - private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = { - minPartitions.map(_ > numTopicPartitions).getOrElse(false) + private def shouldDivvyUpLargePartitions(offsetRanges: Seq[KafkaOffsetRange]): Boolean = { + minPartitions.map(_ > offsetRanges.size).getOrElse(false) || + offsetRanges.exists(_.size > maxRecordsPerPartition.getOrElse(Long.MaxValue)) } override def toString(): String = consumerStrategy.toString @@ -397,7 +401,7 @@ private[kafka010] class KafkaOffsetReaderAdmin( KafkaOffsetRange(tp, fromOffset, untilOffset, None) }.toSeq - if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) { + if (shouldDivvyUpLargePartitions(offsetRangesBase)) { val fromOffsetsMap = offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap val untilOffsetsMap = diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala index fa53d6373176e..f7530dcba6b85 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala @@ -98,8 +98,11 @@ private[kafka010] class KafkaOffsetReaderConsumer( */ private val minPartitions = readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt) + private val maxRecordsPerPartition = + readerOptions.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY).map(_.toLong) - private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions) + private val rangeCalculator = + new KafkaOffsetRangeCalculator(minPartitions, maxRecordsPerPartition) private[kafka010] val offsetFetchAttemptIntervalMs = readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong @@ -107,8 +110,9 @@ private[kafka010] class KafkaOffsetReaderConsumer( /** * Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks. 
*/ - private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = { - minPartitions.map(_ > numTopicPartitions).getOrElse(false) + private def shouldDivvyUpLargePartitions(offsetRanges: Seq[KafkaOffsetRange]): Boolean = { + minPartitions.map(_ > offsetRanges.size).getOrElse(false) || + offsetRanges.exists(_.size > maxRecordsPerPartition.getOrElse(Long.MaxValue)) } private def nextGroupId(): String = { @@ -446,7 +450,7 @@ private[kafka010] class KafkaOffsetReaderConsumer( KafkaOffsetRange(tp, fromOffset, untilOffset, None) }.toSeq - if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) { + if (shouldDivvyUpLargePartitions(offsetRangesBase)) { val fromOffsetsMap = offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap val untilOffsetsMap = diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala index 56456f9b1f776..8d0bcc5816775 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala @@ -50,7 +50,7 @@ private[kafka010] class KafkaRecordToRowConverter { new GenericArrayData(cr.headers.iterator().asScala .map(header => InternalRow(UTF8String.fromString(header.key()), header.value()) - ).toArray) + ).toArray[Any]) } else { null } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index e1fdbfb183c39..4cb9fa8df8052 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -271,6 +271,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") } + if (params.contains(MAX_RECORDS_PER_PARTITION_OPTION_KEY)) { + val p = params(MAX_RECORDS_PER_PARTITION_OPTION_KEY).toLong + if (p <= 0) { + throw new IllegalArgumentException( + s"$MAX_RECORDS_PER_PARTITION_OPTION_KEY must be positive") + } + } + // Validate user-specified Kafka options if (params.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -557,6 +565,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val ENDING_TIMESTAMP_OPTION_KEY = "endingtimestamp" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions" + private[kafka010] val MAX_RECORDS_PER_PARTITION_OPTION_KEY = "maxrecordsperpartition" private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger" private[kafka010] val MIN_OFFSET_PER_TRIGGER = "minoffsetspertrigger" private[kafka010] val MAX_TRIGGER_DELAY = "maxtriggerdelay" diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 9ae6a9290f80a..1d119de43970f 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ 
-1156,7 +1156,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with test("allow group.id prefix") { // Group ID prefix is only supported by consumer based offset reader - if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { + if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { testGroupId("groupIdPrefix", (expected, actual) => { assert(actual.exists(_.startsWith(expected)) && !actual.exists(_ === expected), "Valid consumer groups don't contain the expected group id - " + @@ -1167,7 +1167,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with test("allow group.id override") { // Group ID override is only supported by consumer based offset reader - if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { + if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { testGroupId("kafka.group.id", (expected, actual) => { assert(actual.exists(_ === expected), "Valid consumer groups don't " + s"contain the expected group id - Valid consumer groups: $actual / " + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala index 89ab0902f4d6f..516aee6ad537d 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -34,6 +34,30 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { } } + def testWithMaxRecordsPerPartition(name: String, maxRecordsPerPartition: Long)( + f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new CaseInsensitiveStringMap( + Map("maxRecordsPerPartition" -> maxRecordsPerPartition.toString).asJava) + test(s"with maxRecordsPerPartition = $maxRecordsPerPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + def testWithMinPartitionsAndMaxRecordsPerPartition( + name: String, + minPartitions: Int, + maxRecordsPerPartition: Long)(f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new CaseInsensitiveStringMap( + Map( + "minPartitions" -> minPartitions.toString, + "maxRecordsPerPartition" -> maxRecordsPerPartition.toString).asJava) + test( + s"with minPartitions = $minPartitions " + + s"and maxRecordsPerPartition = $maxRecordsPerPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + test("with no minPartition: N TopicPartitions to N offset ranges") { val calc = KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty()) assert( @@ -253,6 +277,59 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { KafkaOffsetRange(tp3, 7500, 10000, None))) } + testWithMaxRecordsPerPartition("SPARK-49259: 1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 5))) == Seq(KafkaOffsetRange(tp1, 1, 5, None))) + + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 2))) == Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 6)), executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 3, None), KafkaOffsetRange(tp1, 3, 6, None)) + ) // location pref not set when maxRecordsPerPartition is set + } + + testWithMaxRecordsPerPartition("SPARK-49259: N TopicPartition to N offset ranges", 20) { calc => + assert( + calc.getRanges( + Seq( + KafkaOffsetRange(tp1, 
1, 40), + KafkaOffsetRange(tp2, 1, 50), + KafkaOffsetRange(tp3, 1, 60))) == + Seq( + KafkaOffsetRange(tp1, 1, 20, None), + KafkaOffsetRange(tp1, 20, 40, None), + KafkaOffsetRange(tp2, 1, 17, None), + KafkaOffsetRange(tp2, 17, 33, None), + KafkaOffsetRange(tp2, 33, 50, None), + KafkaOffsetRange(tp3, 1, 20, None), + KafkaOffsetRange(tp3, 20, 40, None), + KafkaOffsetRange(tp3, 40, 60, None))) + } + + testWithMinPartitionsAndMaxRecordsPerPartition( + "SPARK-49259: 1 TopicPartition with low minPartitions value", + 1, + 20) { calc => + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 40))) == + Seq(KafkaOffsetRange(tp1, 1, 20, None), KafkaOffsetRange(tp1, 20, 40, None))) + } + + testWithMinPartitionsAndMaxRecordsPerPartition( + "SPARK-49259: 1 TopicPartition with high minPartitions value", + 4, + 20) { calc => + assert( + calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 40))) == + Seq( + KafkaOffsetRange(tp1, 1, 10, None), + KafkaOffsetRange(tp1, 10, 20, None), + KafkaOffsetRange(tp1, 20, 30, None), + KafkaOffsetRange(tp1, 30, 40, None))) + } + private val tp1 = new TopicPartition("t1", 1) private val tp2 = new TopicPartition("t2", 1) private val tp3 = new TopicPartition("t3", 1) diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala index 320485a79e59d..6fc22e7ac5e03 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala @@ -153,7 +153,7 @@ class KafkaOffsetReaderSuite extends QueryTest with SharedSparkSession with Kafk } checkError( exception = ex, - errorClass = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED", + condition = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED", parameters = Map( "specifiedPartitions" -> "Set\\(.*,.*\\)", "assignedPartitions" -> "Set\\(.*,.*,.*\\)"), diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 8d4e3e5c1d364..6087447fa3045 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -638,7 +638,7 @@ class KafkaRelationSuiteV1 extends KafkaRelationSuiteBase { val topic = newTopic() val df = createDF(topic) assert(df.logicalPlan.collect { - case LogicalRelation(_, _, _, _) => true + case _: LogicalRelation => true }.nonEmpty) } } diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 0737658e65256..9e06b6c6ff4a2 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -28,6 +28,7 @@ import scala.io.Source import scala.jdk.CollectionConverters._ import com.google.common.io.Files +import kafka.log.LogManager import kafka.server.{HostedPartition, KafkaConfig, KafkaServer} import kafka.server.checkpoints.OffsetCheckpointFile import kafka.zk.KafkaZkClient @@ -42,11 +43,12 @@ import org.apache.kafka.common.network.ListenerName import 
org.apache.kafka.common.requests.FetchRequest import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT} import org.apache.kafka.common.serialization.StringSerializer -import org.apache.kafka.common.utils.SystemTime +import org.apache.kafka.common.utils.Time import org.apache.zookeeper.client.ZKClientConfig import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.apache.zookeeper.server.auth.SASLAuthenticationProvider import org.scalatest.Assertions._ +import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -65,7 +67,7 @@ import org.apache.spark.util.ArrayImplicits._ */ class KafkaTestUtils( withBrokerProps: Map[String, Object] = Map.empty, - secure: Boolean = false) extends Logging { + secure: Boolean = false) extends PrivateMethodTester with Logging { private val JAVA_AUTH_CONFIG = "java.security.auth.login.config" @@ -174,7 +176,7 @@ class KafkaTestUtils( } kdc.getKrb5conf.delete() - Files.write(krb5confStr, kdc.getKrb5conf, StandardCharsets.UTF_8) + Files.asCharSink(kdc.getKrb5conf, StandardCharsets.UTF_8).write(krb5confStr) logDebug(s"krb5.conf file content: $krb5confStr") } @@ -238,7 +240,7 @@ class KafkaTestUtils( | principal="$kafkaServerUser@$realm"; |}; """.stripMargin.trim - Files.write(content, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(content) logDebug(s"Created JAAS file: ${file.getPath}") logDebug(s"JAAS file content: $content") file.getAbsolutePath() @@ -251,7 +253,7 @@ class KafkaTestUtils( // Get the actual zookeeper binding port zkPort = zookeeper.actualPort zkClient = KafkaZkClient(s"$zkHost:$zkPort", isSecure = false, zkSessionTimeout, - zkConnectionTimeout, 1, new SystemTime(), "test", new ZKClientConfig) + zkConnectionTimeout, 1, Time.SYSTEM, "test", new ZKClientConfig) zkReady = true } @@ -447,8 +449,9 @@ class KafkaTestUtils( sendMessages(msgs.toImmutableArraySeq) } + private val cleanupLogsPrivateMethod = PrivateMethod[LogManager](Symbol("cleanupLogs")) def cleanupLogs(): Unit = { - server.logManager.cleanupLogs() + server.logManager.invokePrivate(cleanupLogsPrivateMethod()) } private def getOffsets(topics: Set[String], offsetSpec: OffsetSpec): Map[TopicPartition, Long] = { diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index 2b2707b9da320..3cbfc34e7d806 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -64,6 +64,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.hadoop hadoop-client-runtime diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index 1b26839a371ce..a42410e6ce885 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -119,6 +119,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index a97ee71ef4fe0..212693f6e02cc 100644 --- a/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/connector/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -28,8 +28,9 @@ import 
kafka.log.{LogCleaner, UnifiedLog} import kafka.server.BrokerTopicStats import kafka.utils.Pool import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.config.TopicConfig -import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} import org.apache.kafka.common.serialization.StringDeserializer import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig} import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} @@ -117,7 +118,7 @@ class KafkaRDDSuite extends SparkFunSuite { ) messages.foreach { case (k, v) => val record = new SimpleRecord(k.getBytes, v.getBytes) - log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, record), 0); + log.appendAsLeader(MemoryRecords.withRecords(Compression.NONE, record), 0); } log.roll() logs.put(topicPartition, log) diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index 9a7f40443bbc9..7eba26ffdff74 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -81,6 +81,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 8c52576c3531f..e85481ef9e1c8 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -146,7 +146,7 @@ src/test/resources/protobuf - true + direct java diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala deleted file mode 100644 index 91e87dee50482..0000000000000 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.protobuf - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.Column -import org.apache.spark.sql.protobuf.utils.ProtobufUtils - -// scalastyle:off: object.name -object functions { -// scalastyle:on: object.name - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. The - * Protobuf definition is provided through Protobuf descriptor file. - * - * @param data - * the binary column. - * @param messageName - * the protobuf message name to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. 
This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.4.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageName: String, - descFilePath: String, - options: java.util.Map[String, String]): Column = { - val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - from_protobuf(data, messageName, descriptorFileContent, options) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value.The - * Protobuf definition is provided through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.5.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String]): Column = { - new Column( - ProtobufDataToCatalyst( - data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap - ) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. The - * Protobuf definition is provided through Protobuf descriptor file. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @since 3.4.0 - */ - @Experimental - def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent))) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value.The - * Protobuf definition is provided through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @since 3.5.0 - */ - @Experimental - def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) - : Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet))) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. 
- * @since 3.4.0 - */ - @Experimental - def from_protobuf(data: Column, messageClassName: String): Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageClassName)) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @param options - * @since 3.4.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageClassName: String, - options: java.util.Map[String, String]): Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap)) - } - - /** - * Converts a column into binary of protobuf format. The Protobuf definition is provided - * through Protobuf descriptor file. - * - * @param data - * the data column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - to_protobuf(data, messageName, descFilePath, Map.empty[String, String].asJava) - } - - /** - * Converts a column into binary of protobuf format.The Protobuf definition is provided - * through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * - * @since 3.5.0 - */ - @Experimental - def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) - : Column = { - new Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet))) - } - /** - * Converts a column into binary of protobuf format. The Protobuf definition is provided - * through Protobuf descriptor file. - * - * @param data - * the data column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * the protobuf descriptor file. - * @param options - * @since 3.4.0 - */ - @Experimental - def to_protobuf( - data: Column, - messageName: String, - descFilePath: String, - options: java.util.Map[String, String]): Column = { - val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - new Column( - CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap) - ) - } - - /** - * Converts a column into binary of protobuf format.The Protobuf definition is provided - * through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). 
Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.5.0 - */ - @Experimental - def to_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String] - ): Column = { - new Column( - CatalystDataToProtobuf( - data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap - ) - ) - } - - /** - * Converts a column into binary of protobuf format. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the data column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageClassName: String): Column = { - new Column(CatalystDataToProtobuf(data.expr, messageClassName)) - } - - /** - * Converts a column into binary of protobuf format. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the data column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @param options - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) - : Column = { - new Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap)) - } -} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala index 82cdc6b9c5816..046b8acdb5e5c 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala @@ -17,5 +17,4 @@ package org.apache.spark.sql package object protobuf { - protected[protobuf] object ScalaReflectionLock } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala index 6644bce98293b..e85097a272f24 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -43,8 +43,8 @@ private[sql] class ProtobufOptions( /** * Adds support for recursive fields. If this option is is not specified, recursive fields are - * not permitted. Setting it to 0 drops the recursive fields, 1 allows it to be recursed once, - * and 2 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not + * not permitted. 
Setting it to 1 drops the recursive fields, 2 allows it to be recursed once, + * and 3 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not * allowed in order avoid inadvertently creating very large schemas. If a Protobuf message * has depth beyond this limit, the Spark struct returned is truncated after the recursion limit. * * @@ -52,8 +52,8 @@ private[sql] class ProtobufOptions( * `message Person { string name = 1; Person friend = 2; }` * The following lists the schema with different values for this setting. * 1: `struct` - * 2: `struct>` - * 3: `struct>>` + * 2: `struct>` + * 3: `struct>>` * and so on. */ val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt @@ -181,7 +181,7 @@ private[sql] class ProtobufOptions( val upcastUnsignedInts: Boolean = parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean - // Whether to unwrap the struct representation for well known primitve wrapper types when + // Whether to unwrap the struct representation for well known primitive wrapper types when // deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value, // google.protobuf.Int64Value, etc.) will get deserialized as structs. We allow the option to // deserialize them as their respective primitives. @@ -221,7 +221,7 @@ private[sql] class ProtobufOptions( // By default, in the spark schema field a will be dropped, which result in schema // b struct // If retain.empty.message.types=true, field a will be retained by inserting a dummy column. - // b struct, name: string> + // b struct, name: string> val retainEmptyMessage: Boolean = parameters.getOrElse("retain.empty.message.types", false.toString).toBoolean } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala index fee1bcdc9670f..3d7bba7a82e88 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.protobuf.utils -import java.io.File -import java.io.FileNotFoundException -import java.nio.file.NoSuchFileException import java.util.Locale import scala.jdk.CollectionConverters._ -import scala.util.control.NonFatal import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message} import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet} import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import com.google.protobuf.TypeRegistry -import org.apache.commons.io.FileUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.errors.QueryCompilationErrors @@ -228,18 +223,6 @@ private[sql] object ProtobufUtils extends Logging { } } - def readDescriptorFileContent(filePath: String): Array[Byte] = { - try { - FileUtils.readFileToByteArray(new File(filePath)) - } catch { - case ex: FileNotFoundException => - throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex) - case ex: NoSuchFileException => - throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex) - case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex) - } - } - private def parseFileDescriptorSet(bytes: Array[Byte]): List[Descriptors.FileDescriptor] = { var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
try { diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index feb5aed03451a..56c1f81850614 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -176,16 +176,16 @@ object SchemaConverters extends Logging { } case MESSAGE => // If the `recursive.fields.max.depth` value is not specified, it will default to -1, - // and recursive fields are not permitted. Setting it to 0 drops all recursive fields, - // 1 allows it to be recursed once, and 2 allows it to be recursed twice and so on. - // A value greater than 10 is not allowed, and if a protobuf record has more depth for - // recursive fields than the allowed value, it will be truncated and some fields may be - // discarded. + // and recursive fields are not permitted. Setting it to 1 drops all recursive fields, + // 2 allows it to be recursed once, and 3 allows it to be recursed twice and so on. + // A value less than or equal to 0 or greater than 10 is not allowed, and if a protobuf + // record has more depth for recursive fields than the allowed value, it will be truncated + // and some fields may be discarded. // SQL Schema for protobuf `message Person { string name = 1; Person bff = 2;}` // will vary based on the value of "recursive.fields.max.depth". // 1: struct - // 2: struct> - // 3: struct>> + // 2: struct> + // 3: struct>> // and so on. // TODO(rangadi): A better way to terminate would be replace the remaining recursive struct // with the byte array of corresponding protobuf. This way no information is lost.
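The recursive.fields.max.depth semantics described in the two hunks above pair naturally with the from_protobuf SQL function this patch exercises in ProtobufFunctionsSuite further down. The following is an illustrative sketch only, not part of the patch: the protobuf_events table, its binary value column, the Person message and the /tmp/persons.desc descriptor path are assumed placeholders.

    // Assumes a descriptor set generated beforehand, e.g.:
    //   protoc --descriptor_set_out=/tmp/persons.desc --include_imports person.proto
    // for `message Person { string name = 1; Person friend = 2; }`.
    // Under the revised semantics, depth '2' keeps one level of recursion, i.e. the
    // decoded column is roughly struct<name: string, friend: struct<name: string>>.
    val parsed = spark.sql(
      """
        |SELECT from_protobuf(
        |  value, 'Person', '/tmp/persons.desc',
        |  map('recursive.fields.max.depth', '2')) AS person
        |FROM protobuf_events
        |""".stripMargin)
    parsed.printSchema()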
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index ad6a88640140a..abae1d622d3cf 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters} import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -39,7 +40,7 @@ class ProtobufCatalystDataConversionSuite with ProtobufTestBase { private val testFileDescFile = protobufDescriptorFile("catalyst_types.desc") - private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile) + private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile) private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$" private def checkResultWithEval( @@ -47,7 +48,7 @@ class ProtobufCatalystDataConversionSuite descFilePath: String, messageName: String, expected: Any): Unit = { - val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath) + val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath) withClue("(Eval check with Java class name)") { val className = s"$javaClassNamePrefix$messageName" checkEvaluation( @@ -72,7 +73,7 @@ class ProtobufCatalystDataConversionSuite actualSchema: String, badSchema: String): Unit = { - val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath) + val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath) val binary = CatalystDataToProtobuf(data, actualSchema, Some(descBytes)) intercept[Exception] { diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 5233e06883498..44a8339ac1f02 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.protobuf.utils.ProtobufOptions import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils} class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase with Serializable { @@ -40,11 +41,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot import testImplicits._ val testFileDescFile = protobufDescriptorFile("functions_suite.desc") - private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile) + private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile) private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$" val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc") - val proto2FileDesc = 
ProtobufUtils.readDescriptorFileContent(proto2FileDescFile) + val proto2FileDesc = CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile) private val proto2JavaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.Proto2Messages$" private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary") @@ -467,7 +468,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot test("Handle extra fields : oldProducer -> newConsumer") { val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc") - val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile) + val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile) val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer") val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer") @@ -509,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot test("Handle extra fields : newProducer -> oldConsumer") { val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc") - val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile) + val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile) val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer") val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer") @@ -708,7 +709,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = e, - errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND", + condition = "PROTOBUF_DEPENDENCY_NOT_FOUND", parameters = Map("dependencyName" -> "nestedenum.proto")) } @@ -1057,7 +1058,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( ex, - errorClass = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", + condition = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", parameters = Map("filePath" -> "/non/existent/path.desc") ) assert(ex.getCause != null) @@ -1699,7 +1700,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = parseError, - errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + condition = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", parameters = Map( "sqlColumn" -> "`basic_enum`", "protobufColumn" -> "field 'basic_enum'", @@ -1711,7 +1712,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = parseError, - errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + condition = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", parameters = Map( "sqlColumn" -> "`basic_enum`", "protobufColumn" -> "field 'basic_enum'", @@ -2042,6 +2043,168 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } } + test("SPARK-49121: from_protobuf and to_protobuf SQL functions") { + withTable("protobuf_test_table") { + sql( + """ + |CREATE TABLE protobuf_test_table AS + | SELECT named_struct( + | 'id', 1L, + | 'string_value', 'test_string', + | 'int32_value', 32, + | 'int64_value', 64L, + | 'double_value', CAST(123.456 AS DOUBLE), + | 'float_value', CAST(789.01 AS FLOAT), + | 'bool_value', true, + | 'bytes_value', CAST('sample_bytes' AS BINARY) + | ) AS complex_struct + |""".stripMargin) + + val toProtobufSql = + s""" + |SELECT + | to_protobuf( + | complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', map() + | ) AS protobuf_data + |FROM protobuf_test_table + |""".stripMargin + + val protobufResult = 
spark.sql(toProtobufSql).collect() + assert(protobufResult != null) + + val fromProtobufSql = + s""" + |SELECT + | from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', map()) + |FROM + | ($toProtobufSql) + |""".stripMargin + + checkAnswer( + spark.sql(fromProtobufSql), + Seq(Row(Row(1L, "test_string", 32, 64L, 123.456, 789.01F, true, "sample_bytes".getBytes))) + ) + + // Negative tests for to_protobuf. + var fragment = s"to_protobuf(complex_struct, 42, '$testFileDescFile', map())" + checkError( + exception = intercept[AnalysisException](sql( + s""" + |SELECT + | to_protobuf(complex_struct, 42, '$testFileDescFile', map()) + |FROM protobuf_test_table + |""".stripMargin)), + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map( + "sqlExpr" -> s"""\"to_protobuf(complex_struct, 42, $testFileDescFile, map())\"""", + "msg" -> ("The second argument of the TO_PROTOBUF SQL function must be a constant " + + "string representing the Protobuf message name"), + "hint" -> ""), + queryContext = Array(ExpectedContext( + fragment = fragment, + start = 10, + stop = fragment.length + 9)) + ) + checkError( + exception = intercept[AnalysisException](sql( + s""" + |SELECT + | to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map()) + |FROM protobuf_test_table + |""".stripMargin)), + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map( + "sqlExpr" -> "\"to_protobuf(complex_struct, SimpleMessageJavaTypes, 42, map())\"", + "msg" -> ("The third argument of the TO_PROTOBUF SQL function must be a constant " + + "string or binary data representing the Protobuf descriptor file path"), + "hint" -> ""), + queryContext = Array(ExpectedContext( + fragment = "to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())", + start = 10, + stop = 73)) + ) + fragment = s"to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)" + checkError( + exception = intercept[AnalysisException](sql( + s""" + |SELECT + | to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42) + |FROM protobuf_test_table + |""".stripMargin)), + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map( + "sqlExpr" -> + s"""\"to_protobuf(complex_struct, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""", + "msg" -> ("The fourth argument of the TO_PROTOBUF SQL function must be a constant " + + "map of strings to strings containing the options to use for converting the value " + + "to Protobuf format"), + "hint" -> ""), + queryContext = Array(ExpectedContext( + fragment = fragment, + start = 10, + stop = fragment.length + 9)) + ) + + // Negative tests for from_protobuf. 
+ fragment = s"from_protobuf(protobuf_data, 42, '$testFileDescFile', map())" + checkError( + exception = intercept[AnalysisException](sql( + s""" + |SELECT from_protobuf(protobuf_data, 42, '$testFileDescFile', map()) + |FROM ($toProtobufSql) + |""".stripMargin)), + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map( + "sqlExpr" -> s"""\"from_protobuf(protobuf_data, 42, $testFileDescFile, map())\"""", + "msg" -> ("The second argument of the FROM_PROTOBUF SQL function must be a constant " + + "string representing the Protobuf message name"), + "hint" -> ""), + queryContext = Array(ExpectedContext( + fragment = fragment, + start = 8, + stop = fragment.length + 7)) + ) + checkError( + exception = intercept[AnalysisException](sql( + s""" + |SELECT from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map()) + |FROM ($toProtobufSql) + |""".stripMargin)), + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map( + "sqlExpr" -> "\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, 42, map())\"", + "msg" -> ("The third argument of the FROM_PROTOBUF SQL function must be a constant " + + "string or binary data representing the Protobuf descriptor file path"), + "hint" -> ""), + queryContext = Array(ExpectedContext( + fragment = "from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())", + start = 8, + stop = 72)) + ) + fragment = s"from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)" + checkError( + exception = intercept[AnalysisException](sql( + s""" + |SELECT + | from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42) + |FROM ($toProtobufSql) + |""".stripMargin)), + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = Map( + "sqlExpr" -> + s"""\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""", + "msg" -> ("The fourth argument of the FROM_PROTOBUF SQL function must be a constant " + + "map of strings to strings containing the options to use for converting the value " + + "from Protobuf format"), + "hint" -> ""), + queryContext = Array(ExpectedContext( + fragment = fragment, + start = 10, + stop = fragment.length + 9)) + ) + } + } + def testFromProtobufWithOptions( df: DataFrame, expectedDf: DataFrame, diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala index 03285c73f1ff1..f3bd49e1b24af 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils} /** * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on @@ -37,12 +38,12 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { import ProtoSerdeSuite.MatchType._ private val testFileDescFile = protobufDescriptorFile("serde_suite.desc") - private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile) + private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile) private val 
javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$" private val proto2DescFile = protobufDescriptorFile("proto2_messages.desc") - private val proto2Desc = ProtobufUtils.readDescriptorFileContent(proto2DescFile) + private val proto2Desc = CommonProtobufUtils.readDescriptorFileContent(proto2DescFile) test("Test basic conversion") { withFieldMatchType { fieldMatch => @@ -95,7 +96,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Deserializer, fieldMatch, - errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", + condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", params = Map( "protobufType" -> "MissMatchTypeInRoot", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -104,7 +105,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Serializer, fieldMatch, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "MissMatchTypeInRoot", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -122,7 +123,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Serializer, BY_NAME, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FieldMissingInProto", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -132,7 +133,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, nonnullCatalyst, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FieldMissingInProto", "toType" -> toSQLType(nonnullCatalyst))) @@ -150,7 +151,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Deserializer, fieldMatch, catalyst, - errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", + condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", params = Map( "protobufType" -> "MissMatchTypeInDeepNested", "toType" -> toSQLType(catalyst))) @@ -160,7 +161,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, fieldMatch, catalyst, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "MissMatchTypeInDeepNested", "toType" -> toSQLType(catalyst))) @@ -177,7 +178,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, catalystSchema = foobarSQLType, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FoobarWithRequiredFieldBar", "toType" -> toSQLType(foobarSQLType))) @@ -199,7 +200,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, catalystSchema = nestedFoobarSQLType, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "NestedFoobarWithRequiredFieldBar", "toType" -> toSQLType(nestedFoobarSQLType))) @@ -215,17 +216,17 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { val e1 = intercept[AnalysisException] { ProtobufUtils.buildDescriptor( - ProtobufUtils.readDescriptorFileContent(fileDescFile), + CommonProtobufUtils.readDescriptorFileContent(fileDescFile), 
"SerdeBasicMessage" ) } checkError( exception = e1, - errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR") + condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR") val basicMessageDescWithoutImports = descriptorSetWithoutImports( - ProtobufUtils.readDescriptorFileContent( + CommonProtobufUtils.readDescriptorFileContent( protobufDescriptorFile("basicmessage.desc") ), "BasicMessage" @@ -240,7 +241,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { checkError( exception = e2, - errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND", + condition = "PROTOBUF_DEPENDENCY_NOT_FOUND", parameters = Map("dependencyName" -> "nestedenum.proto")) } @@ -254,7 +255,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { serdeFactory: SerdeFactory[_], fieldMatchType: MatchType, catalystSchema: StructType = CATALYST_STRUCT, - errorClass: String, + condition: String, params: Map[String, String]): Unit = { val e = intercept[AnalysisException] { @@ -274,7 +275,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { assert(e.getMessage === expectMsg) checkError( exception = e, - errorClass = errorClass, + condition = condition, parameters = params) } diff --git a/core/benchmarks/ChecksumBenchmark-jdk21-results.txt b/core/benchmarks/ChecksumBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..85370450f355c --- /dev/null +++ b/core/benchmarks/ChecksumBenchmark-jdk21-results.txt @@ -0,0 +1,14 @@ +================================================================================================ +Benchmark Checksum Algorithms +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +AMD EPYC 7763 64-Core Processor +Checksum Algorithms: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +CRC32 2743 2746 3 0.0 2678409.9 1.0X +CRC32C 1974 2055 70 0.0 1928129.2 1.4X +Adler32 12689 12709 17 0.0 12391425.9 0.2X +hadoop PureJavaCrc32C 23027 23041 13 0.0 22487098.9 0.1X + + diff --git a/core/benchmarks/ChecksumBenchmark-results.txt b/core/benchmarks/ChecksumBenchmark-results.txt new file mode 100644 index 0000000000000..cce5a61abf637 --- /dev/null +++ b/core/benchmarks/ChecksumBenchmark-results.txt @@ -0,0 +1,14 @@ +================================================================================================ +Benchmark Checksum Algorithms +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +AMD EPYC 7763 64-Core Processor +Checksum Algorithms: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +CRC32 2757 2758 1 0.0 2692250.2 1.0X +CRC32C 2142 2244 116 0.0 2091901.8 1.3X +Adler32 12699 12712 15 0.0 12401205.6 0.2X +hadoop PureJavaCrc32C 23049 23066 15 0.0 22508320.3 0.1X + + diff --git a/core/benchmarks/CoalescedRDDBenchmark-jdk21-results.txt b/core/benchmarks/CoalescedRDDBenchmark-jdk21-results.txt index 2c1c69ac42db7..1daac7b710bbf 100644 --- a/core/benchmarks/CoalescedRDDBenchmark-jdk21-results.txt +++ b/core/benchmarks/CoalescedRDDBenchmark-jdk21-results.txt @@ -2,39 +2,39 @@ Coalesced RDD , large scale 
================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Coalesced RDD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ---------------------------------------------------------------------------------------------------------------------------- -Coalesce Num Partitions: 100 Num Hosts: 1 269 278 10 0.4 2693.1 1.0X -Coalesce Num Partitions: 100 Num Hosts: 5 135 140 5 0.7 1345.0 2.0X -Coalesce Num Partitions: 100 Num Hosts: 10 109 115 6 0.9 1091.6 2.5X -Coalesce Num Partitions: 100 Num Hosts: 20 100 101 1 1.0 999.4 2.7X -Coalesce Num Partitions: 100 Num Hosts: 40 96 98 3 1.0 961.5 2.8X -Coalesce Num Partitions: 100 Num Hosts: 80 93 100 9 1.1 933.4 2.9X -Coalesce Num Partitions: 500 Num Hosts: 1 875 902 28 0.1 8754.7 0.3X -Coalesce Num Partitions: 500 Num Hosts: 5 262 265 3 0.4 2619.9 1.0X -Coalesce Num Partitions: 500 Num Hosts: 10 179 182 4 0.6 1792.2 1.5X -Coalesce Num Partitions: 500 Num Hosts: 20 136 139 6 0.7 1357.2 2.0X -Coalesce Num Partitions: 500 Num Hosts: 40 115 116 1 0.9 1145.7 2.4X -Coalesce Num Partitions: 500 Num Hosts: 80 105 110 7 1.0 1047.2 2.6X -Coalesce Num Partitions: 1000 Num Hosts: 1 1655 1656 2 0.1 16546.0 0.2X -Coalesce Num Partitions: 1000 Num Hosts: 5 425 428 4 0.2 4251.4 0.6X -Coalesce Num Partitions: 1000 Num Hosts: 10 263 267 3 0.4 2634.4 1.0X -Coalesce Num Partitions: 1000 Num Hosts: 20 182 188 9 0.5 1822.5 1.5X -Coalesce Num Partitions: 1000 Num Hosts: 40 142 143 1 0.7 1424.5 1.9X -Coalesce Num Partitions: 1000 Num Hosts: 80 123 131 8 0.8 1226.2 2.2X -Coalesce Num Partitions: 5000 Num Hosts: 1 7484 7491 10 0.0 74836.3 0.0X -Coalesce Num Partitions: 5000 Num Hosts: 5 1873 1880 11 0.1 18725.7 0.1X -Coalesce Num Partitions: 5000 Num Hosts: 10 995 1005 11 0.1 9950.1 0.3X -Coalesce Num Partitions: 5000 Num Hosts: 20 557 561 6 0.2 5570.9 0.5X -Coalesce Num Partitions: 5000 Num Hosts: 40 337 341 4 0.3 3369.7 0.8X -Coalesce Num Partitions: 5000 Num Hosts: 80 222 223 1 0.5 2222.0 1.2X -Coalesce Num Partitions: 10000 Num Hosts: 1 14102 14133 45 0.0 141020.7 0.0X -Coalesce Num Partitions: 10000 Num Hosts: 5 3681 3702 31 0.0 36811.9 0.1X -Coalesce Num Partitions: 10000 Num Hosts: 10 1871 1884 22 0.1 18706.8 0.1X -Coalesce Num Partitions: 10000 Num Hosts: 20 998 1004 5 0.1 9980.9 0.3X -Coalesce Num Partitions: 10000 Num Hosts: 40 570 575 5 0.2 5696.7 0.5X -Coalesce Num Partitions: 10000 Num Hosts: 80 345 346 2 0.3 3447.7 0.8X +Coalesce Num Partitions: 100 Num Hosts: 1 359 371 12 0.3 3586.9 1.0X +Coalesce Num Partitions: 100 Num Hosts: 5 181 190 14 0.6 1812.1 2.0X +Coalesce Num Partitions: 100 Num Hosts: 10 178 185 6 0.6 1779.9 2.0X +Coalesce Num Partitions: 100 Num Hosts: 20 153 156 4 0.7 1531.2 2.3X +Coalesce Num Partitions: 100 Num Hosts: 40 148 149 1 0.7 1479.1 2.4X +Coalesce Num Partitions: 100 Num Hosts: 80 166 170 5 0.6 1657.8 2.2X +Coalesce Num Partitions: 500 Num Hosts: 1 1054 1064 14 0.1 10543.7 0.3X +Coalesce Num Partitions: 500 Num Hosts: 5 331 339 13 0.3 3311.1 1.1X +Coalesce Num Partitions: 500 Num Hosts: 10 230 235 8 0.4 2295.7 1.6X +Coalesce Num Partitions: 500 Num Hosts: 20 218 220 1 0.5 2182.0 1.6X +Coalesce Num Partitions: 500 Num Hosts: 40 161 164 2 0.6 1614.8 2.2X +Coalesce Num Partitions: 500 Num Hosts: 80 137 142 7 0.7 1371.6 2.6X +Coalesce Num Partitions: 1000 Num Hosts: 1 1926 1929 3 0.1 19264.6 0.2X +Coalesce Num Partitions: 1000 
Num Hosts: 5 501 507 10 0.2 5011.1 0.7X +Coalesce Num Partitions: 1000 Num Hosts: 10 327 331 4 0.3 3268.5 1.1X +Coalesce Num Partitions: 1000 Num Hosts: 20 256 264 8 0.4 2556.1 1.4X +Coalesce Num Partitions: 1000 Num Hosts: 40 185 191 7 0.5 1853.2 1.9X +Coalesce Num Partitions: 1000 Num Hosts: 80 160 166 5 0.6 1603.5 2.2X +Coalesce Num Partitions: 5000 Num Hosts: 1 8672 9054 615 0.0 86716.9 0.0X +Coalesce Num Partitions: 5000 Num Hosts: 5 2016 2020 6 0.0 20159.9 0.2X +Coalesce Num Partitions: 5000 Num Hosts: 10 1084 1096 10 0.1 10844.7 0.3X +Coalesce Num Partitions: 5000 Num Hosts: 20 625 636 11 0.2 6245.6 0.6X +Coalesce Num Partitions: 5000 Num Hosts: 40 418 425 6 0.2 4182.3 0.9X +Coalesce Num Partitions: 5000 Num Hosts: 80 270 276 8 0.4 2704.6 1.3X +Coalesce Num Partitions: 10000 Num Hosts: 1 16208 16391 226 0.0 162076.8 0.0X +Coalesce Num Partitions: 10000 Num Hosts: 5 3930 3949 23 0.0 39300.4 0.1X +Coalesce Num Partitions: 10000 Num Hosts: 10 2021 2031 11 0.0 20213.1 0.2X +Coalesce Num Partitions: 10000 Num Hosts: 20 1114 1115 1 0.1 11139.0 0.3X +Coalesce Num Partitions: 10000 Num Hosts: 40 628 639 17 0.2 6275.3 0.6X +Coalesce Num Partitions: 10000 Num Hosts: 80 402 408 10 0.2 4016.4 0.9X diff --git a/core/benchmarks/CoalescedRDDBenchmark-results.txt b/core/benchmarks/CoalescedRDDBenchmark-results.txt index aba428c1729e7..d370e6956116d 100644 --- a/core/benchmarks/CoalescedRDDBenchmark-results.txt +++ b/core/benchmarks/CoalescedRDDBenchmark-results.txt @@ -2,39 +2,39 @@ Coalesced RDD , large scale ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Coalesced RDD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ---------------------------------------------------------------------------------------------------------------------------- -Coalesce Num Partitions: 100 Num Hosts: 1 134 143 8 0.7 1343.4 1.0X -Coalesce Num Partitions: 100 Num Hosts: 5 96 97 2 1.0 962.0 1.4X -Coalesce Num Partitions: 100 Num Hosts: 10 90 92 3 1.1 898.7 1.5X -Coalesce Num Partitions: 100 Num Hosts: 20 90 91 1 1.1 898.9 1.5X -Coalesce Num Partitions: 100 Num Hosts: 40 96 100 4 1.0 957.0 1.4X -Coalesce Num Partitions: 100 Num Hosts: 80 87 93 9 1.2 866.3 1.6X -Coalesce Num Partitions: 500 Num Hosts: 1 312 314 3 0.3 3115.2 0.4X -Coalesce Num Partitions: 500 Num Hosts: 5 135 136 1 0.7 1352.9 1.0X -Coalesce Num Partitions: 500 Num Hosts: 10 110 111 1 0.9 1103.7 1.2X -Coalesce Num Partitions: 500 Num Hosts: 20 103 109 9 1.0 1034.3 1.3X -Coalesce Num Partitions: 500 Num Hosts: 40 95 97 2 1.1 948.3 1.4X -Coalesce Num Partitions: 500 Num Hosts: 80 90 93 3 1.1 899.9 1.5X -Coalesce Num Partitions: 1000 Num Hosts: 1 527 535 13 0.2 5267.7 0.3X -Coalesce Num Partitions: 1000 Num Hosts: 5 179 180 1 0.6 1788.4 0.8X -Coalesce Num Partitions: 1000 Num Hosts: 10 132 138 8 0.8 1321.6 1.0X -Coalesce Num Partitions: 1000 Num Hosts: 20 116 122 5 0.9 1157.1 1.2X -Coalesce Num Partitions: 1000 Num Hosts: 40 99 104 7 1.0 988.1 1.4X -Coalesce Num Partitions: 1000 Num Hosts: 80 95 97 2 1.1 948.9 1.4X -Coalesce Num Partitions: 5000 Num Hosts: 1 2326 2336 10 0.0 23263.2 0.1X -Coalesce Num Partitions: 5000 Num Hosts: 5 735 743 7 0.1 7351.3 0.2X -Coalesce Num Partitions: 5000 Num Hosts: 10 400 405 6 0.2 4002.9 0.3X -Coalesce Num Partitions: 5000 Num Hosts: 20 263 266 2 0.4 2631.4 0.5X -Coalesce Num Partitions: 5000 Num 
Hosts: 40 175 180 7 0.6 1746.4 0.8X -Coalesce Num Partitions: 5000 Num Hosts: 80 139 141 1 0.7 1389.1 1.0X -Coalesce Num Partitions: 10000 Num Hosts: 1 4250 4263 21 0.0 42497.5 0.0X -Coalesce Num Partitions: 10000 Num Hosts: 5 1508 1512 6 0.1 15082.5 0.1X -Coalesce Num Partitions: 10000 Num Hosts: 10 765 770 7 0.1 7645.5 0.2X -Coalesce Num Partitions: 10000 Num Hosts: 20 420 424 4 0.2 4198.5 0.3X -Coalesce Num Partitions: 10000 Num Hosts: 40 277 279 3 0.4 2768.6 0.5X -Coalesce Num Partitions: 10000 Num Hosts: 80 187 189 2 0.5 1868.9 0.7X +Coalesce Num Partitions: 100 Num Hosts: 1 126 145 24 0.8 1257.7 1.0X +Coalesce Num Partitions: 100 Num Hosts: 5 104 105 1 1.0 1039.8 1.2X +Coalesce Num Partitions: 100 Num Hosts: 10 87 89 2 1.1 873.7 1.4X +Coalesce Num Partitions: 100 Num Hosts: 20 89 90 1 1.1 893.4 1.4X +Coalesce Num Partitions: 100 Num Hosts: 40 88 91 6 1.1 880.4 1.4X +Coalesce Num Partitions: 100 Num Hosts: 80 88 94 10 1.1 875.9 1.4X +Coalesce Num Partitions: 500 Num Hosts: 1 308 314 8 0.3 3078.0 0.4X +Coalesce Num Partitions: 500 Num Hosts: 5 133 136 2 0.7 1334.6 0.9X +Coalesce Num Partitions: 500 Num Hosts: 10 112 114 4 0.9 1118.8 1.1X +Coalesce Num Partitions: 500 Num Hosts: 20 100 111 16 1.0 1004.0 1.3X +Coalesce Num Partitions: 500 Num Hosts: 40 100 106 6 1.0 999.7 1.3X +Coalesce Num Partitions: 500 Num Hosts: 80 95 100 6 1.0 954.0 1.3X +Coalesce Num Partitions: 1000 Num Hosts: 1 531 539 14 0.2 5311.0 0.2X +Coalesce Num Partitions: 1000 Num Hosts: 5 197 201 4 0.5 1970.3 0.6X +Coalesce Num Partitions: 1000 Num Hosts: 10 139 141 2 0.7 1392.4 0.9X +Coalesce Num Partitions: 1000 Num Hosts: 20 114 115 1 0.9 1137.9 1.1X +Coalesce Num Partitions: 1000 Num Hosts: 40 105 108 3 0.9 1054.9 1.2X +Coalesce Num Partitions: 1000 Num Hosts: 80 105 109 4 1.0 1047.4 1.2X +Coalesce Num Partitions: 5000 Num Hosts: 1 2336 2354 16 0.0 23362.8 0.1X +Coalesce Num Partitions: 5000 Num Hosts: 5 680 684 4 0.1 6798.7 0.2X +Coalesce Num Partitions: 5000 Num Hosts: 10 381 390 8 0.3 3810.5 0.3X +Coalesce Num Partitions: 5000 Num Hosts: 20 253 255 2 0.4 2529.6 0.5X +Coalesce Num Partitions: 5000 Num Hosts: 40 171 174 3 0.6 1706.1 0.7X +Coalesce Num Partitions: 5000 Num Hosts: 80 137 139 3 0.7 1365.5 0.9X +Coalesce Num Partitions: 10000 Num Hosts: 1 4220 4253 30 0.0 42203.4 0.0X +Coalesce Num Partitions: 10000 Num Hosts: 5 1377 1394 16 0.1 13769.0 0.1X +Coalesce Num Partitions: 10000 Num Hosts: 10 704 717 12 0.1 7036.4 0.2X +Coalesce Num Partitions: 10000 Num Hosts: 20 420 422 1 0.2 4201.7 0.3X +Coalesce Num Partitions: 10000 Num Hosts: 40 267 271 4 0.4 2669.3 0.5X +Coalesce Num Partitions: 10000 Num Hosts: 80 184 191 6 0.5 1842.2 0.7X diff --git a/core/benchmarks/KryoBenchmark-jdk21-results.txt b/core/benchmarks/KryoBenchmark-jdk21-results.txt index 8488d15602228..aee420e8ca26a 100644 --- a/core/benchmarks/KryoBenchmark-jdk21-results.txt +++ b/core/benchmarks/KryoBenchmark-jdk21-results.txt @@ -2,27 +2,27 @@ Benchmark Kryo Unsafe vs safe Serialization ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -basicTypes: Int with unsafe:true 173 174 1 5.8 172.9 1.0X -basicTypes: 
Long with unsafe:true 178 179 1 5.6 177.7 1.0X -basicTypes: Float with unsafe:true 180 182 1 5.6 179.6 1.0X -basicTypes: Double with unsafe:true 183 191 15 5.5 182.8 0.9X -Array: Int with unsafe:true 1 1 0 753.9 1.3 130.3X -Array: Long with unsafe:true 2 2 0 486.0 2.1 84.0X -Array: Float with unsafe:true 1 1 0 759.7 1.3 131.3X -Array: Double with unsafe:true 2 2 0 473.8 2.1 81.9X -Map of string->Double with unsafe:true 27 27 1 37.5 26.7 6.5X -basicTypes: Int with unsafe:false 204 205 1 4.9 203.7 0.8X -basicTypes: Long with unsafe:false 229 230 1 4.4 229.1 0.8X -basicTypes: Float with unsafe:false 208 209 1 4.8 208.1 0.8X -basicTypes: Double with unsafe:false 204 205 2 4.9 203.7 0.8X -Array: Int with unsafe:false 13 13 0 79.4 12.6 13.7X -Array: Long with unsafe:false 21 21 0 47.8 20.9 8.3X -Array: Float with unsafe:false 6 6 0 168.3 5.9 29.1X -Array: Double with unsafe:false 11 12 0 87.2 11.5 15.1X -Map of string->Double with unsafe:false 30 31 0 32.9 30.4 5.7X +basicTypes: Int with unsafe:true 174 176 1 5.7 174.3 1.0X +basicTypes: Long with unsafe:true 178 184 5 5.6 178.1 1.0X +basicTypes: Float with unsafe:true 185 187 1 5.4 185.3 0.9X +basicTypes: Double with unsafe:true 187 189 1 5.3 187.0 0.9X +Array: Int with unsafe:true 1 1 0 752.1 1.3 131.1X +Array: Long with unsafe:true 2 2 0 490.6 2.0 85.5X +Array: Float with unsafe:true 1 1 0 757.7 1.3 132.1X +Array: Double with unsafe:true 2 2 0 483.9 2.1 84.4X +Map of string->Double with unsafe:true 26 26 2 38.5 26.0 6.7X +basicTypes: Int with unsafe:false 206 207 1 4.9 205.7 0.8X +basicTypes: Long with unsafe:false 222 223 1 4.5 221.7 0.8X +basicTypes: Float with unsafe:false 217 218 1 4.6 216.5 0.8X +basicTypes: Double with unsafe:false 217 218 2 4.6 216.6 0.8X +Array: Int with unsafe:false 13 13 0 79.5 12.6 13.9X +Array: Long with unsafe:false 21 22 0 46.6 21.4 8.1X +Array: Float with unsafe:false 6 6 0 167.8 6.0 29.3X +Array: Double with unsafe:false 16 16 0 64.2 15.6 11.2X +Map of string->Double with unsafe:false 28 28 1 36.3 27.5 6.3X diff --git a/core/benchmarks/KryoBenchmark-results.txt b/core/benchmarks/KryoBenchmark-results.txt index 5c35cc6affc1f..ca80b13a5346d 100644 --- a/core/benchmarks/KryoBenchmark-results.txt +++ b/core/benchmarks/KryoBenchmark-results.txt @@ -2,27 +2,27 @@ Benchmark Kryo Unsafe vs safe Serialization ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -basicTypes: Int with unsafe:true 171 172 1 5.9 170.5 1.0X -basicTypes: Long with unsafe:true 185 187 2 5.4 185.5 0.9X -basicTypes: Float with unsafe:true 190 192 1 5.3 190.0 0.9X -basicTypes: Double with unsafe:true 199 200 1 5.0 198.8 0.9X -Array: Int with unsafe:true 1 1 0 751.1 1.3 128.1X -Array: Long with unsafe:true 2 2 0 483.9 2.1 82.5X -Array: Float with unsafe:true 1 1 0 734.7 1.4 125.3X -Array: Double with unsafe:true 2 2 0 478.2 2.1 81.6X -Map of string->Double with unsafe:true 26 26 0 38.8 25.8 6.6X -basicTypes: Int with unsafe:false 207 209 1 4.8 207.5 0.8X -basicTypes: Long with unsafe:false 239 241 2 4.2 239.1 0.7X -basicTypes: Float with unsafe:false 213 213 1 4.7 212.6 0.8X -basicTypes: Double with 
unsafe:false 224 226 1 4.5 224.2 0.8X -Array: Int with unsafe:false 14 14 0 73.5 13.6 12.5X -Array: Long with unsafe:false 21 21 0 47.8 20.9 8.1X -Array: Float with unsafe:false 6 6 0 169.9 5.9 29.0X -Array: Double with unsafe:false 10 10 0 101.4 9.9 17.3X -Map of string->Double with unsafe:false 28 28 0 35.9 27.9 6.1X +basicTypes: Int with unsafe:true 171 173 2 5.9 170.8 1.0X +basicTypes: Long with unsafe:true 190 193 3 5.3 189.6 0.9X +basicTypes: Float with unsafe:true 186 188 2 5.4 186.1 0.9X +basicTypes: Double with unsafe:true 189 190 1 5.3 188.5 0.9X +Array: Int with unsafe:true 1 2 0 720.0 1.4 123.0X +Array: Long with unsafe:true 2 3 0 462.4 2.2 79.0X +Array: Float with unsafe:true 1 2 0 719.2 1.4 122.9X +Array: Double with unsafe:true 2 3 0 459.8 2.2 78.5X +Map of string->Double with unsafe:true 27 28 1 37.2 26.9 6.3X +basicTypes: Int with unsafe:false 219 220 1 4.6 218.7 0.8X +basicTypes: Long with unsafe:false 242 244 2 4.1 242.1 0.7X +basicTypes: Float with unsafe:false 215 220 10 4.7 214.8 0.8X +basicTypes: Double with unsafe:false 222 224 2 4.5 221.7 0.8X +Array: Int with unsafe:false 15 15 0 66.9 15.0 11.4X +Array: Long with unsafe:false 22 22 0 45.9 21.8 7.8X +Array: Float with unsafe:false 6 6 1 170.1 5.9 29.1X +Array: Double with unsafe:false 10 10 0 103.0 9.7 17.6X +Map of string->Double with unsafe:false 31 32 2 32.4 30.9 5.5X diff --git a/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt b/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt index da82b05fd59e6..e3922382068dd 100644 --- a/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt +++ b/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt @@ -2,27 +2,27 @@ Benchmark of kryo asIterator on deserialization stream ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark of kryo asIterator on deserialization stream: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------------- -Colletion of int with 1 elements, useIterator: true 6 6 0 1.6 618.2 1.0X -Colletion of int with 10 elements, useIterator: true 14 15 0 0.7 1444.0 0.4X -Colletion of int with 100 elements, useIterator: true 92 92 1 0.1 9168.2 0.1X -Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 777.9 0.8X -Colletion of string with 10 elements, useIterator: true 22 23 0 0.5 2221.0 0.3X -Colletion of string with 100 elements, useIterator: true 166 167 1 0.1 16617.2 0.0X -Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.4 730.2 0.8X -Colletion of Array[int] with 10 elements, useIterator: true 20 20 0 0.5 1967.8 0.3X -Colletion of Array[int] with 100 elements, useIterator: true 145 146 1 0.1 14469.9 0.0X -Colletion of int with 1 elements, useIterator: false 7 7 0 1.5 653.6 0.9X -Colletion of int with 10 elements, useIterator: false 15 16 0 0.7 1528.7 0.4X -Colletion of int with 100 elements, useIterator: false 98 98 1 0.1 9755.3 0.1X -Colletion of string with 1 elements, useIterator: false 7 7 0 1.4 718.5 0.9X -Colletion of string with 10 elements, useIterator: false 21 22 2 0.5 2093.0 0.3X -Colletion of string with 100 elements, useIterator: false 157 157 1 0.1 15666.5 0.0X -Colletion of Array[int] with 1 elements, useIterator: false 7 7 0 1.4 698.1 0.9X 
-Colletion of Array[int] with 10 elements, useIterator: false 18 19 0 0.5 1831.7 0.3X -Colletion of Array[int] with 100 elements, useIterator: false 134 135 0 0.1 13430.8 0.0X +Colletion of int with 1 elements, useIterator: true 6 6 0 1.7 590.7 1.0X +Colletion of int with 10 elements, useIterator: true 14 15 1 0.7 1431.1 0.4X +Colletion of int with 100 elements, useIterator: true 94 96 1 0.1 9429.2 0.1X +Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 760.8 0.8X +Colletion of string with 10 elements, useIterator: true 22 23 0 0.4 2246.1 0.3X +Colletion of string with 100 elements, useIterator: true 167 167 1 0.1 16659.4 0.0X +Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.4 735.4 0.8X +Colletion of Array[int] with 10 elements, useIterator: true 20 20 1 0.5 1976.5 0.3X +Colletion of Array[int] with 100 elements, useIterator: true 154 154 1 0.1 15356.7 0.0X +Colletion of int with 1 elements, useIterator: false 6 7 0 1.6 623.6 0.9X +Colletion of int with 10 elements, useIterator: false 13 14 0 0.7 1334.3 0.4X +Colletion of int with 100 elements, useIterator: false 82 82 0 0.1 8164.4 0.1X +Colletion of string with 1 elements, useIterator: false 7 8 0 1.4 727.0 0.8X +Colletion of string with 10 elements, useIterator: false 22 22 0 0.5 2166.5 0.3X +Colletion of string with 100 elements, useIterator: false 159 160 0 0.1 15925.0 0.0X +Colletion of Array[int] with 1 elements, useIterator: false 7 7 0 1.4 712.1 0.8X +Colletion of Array[int] with 10 elements, useIterator: false 19 20 0 0.5 1932.3 0.3X +Colletion of Array[int] with 100 elements, useIterator: false 142 143 1 0.1 14220.2 0.0X diff --git a/core/benchmarks/KryoIteratorBenchmark-results.txt b/core/benchmarks/KryoIteratorBenchmark-results.txt index e2cbfb871e7ba..77452144ac01d 100644 --- a/core/benchmarks/KryoIteratorBenchmark-results.txt +++ b/core/benchmarks/KryoIteratorBenchmark-results.txt @@ -2,27 +2,27 @@ Benchmark of kryo asIterator on deserialization stream ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark of kryo asIterator on deserialization stream: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------------- -Colletion of int with 1 elements, useIterator: true 6 6 1 1.6 611.4 1.0X -Colletion of int with 10 elements, useIterator: true 14 15 0 0.7 1443.2 0.4X -Colletion of int with 100 elements, useIterator: true 93 94 0 0.1 9331.3 0.1X -Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 753.6 0.8X -Colletion of string with 10 elements, useIterator: true 22 22 0 0.5 2150.8 0.3X -Colletion of string with 100 elements, useIterator: true 163 164 1 0.1 16325.7 0.0X -Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.3 741.1 0.8X -Colletion of Array[int] with 10 elements, useIterator: true 20 20 0 0.5 1989.2 0.3X -Colletion of Array[int] with 100 elements, useIterator: true 147 147 1 0.1 14659.2 0.0X -Colletion of int with 1 elements, useIterator: false 6 6 0 1.7 597.3 1.0X -Colletion of int with 10 elements, useIterator: false 13 14 0 0.8 1323.4 0.5X -Colletion of int with 100 elements, useIterator: false 83 84 3 0.1 8272.9 0.1X -Colletion of string with 1 elements, useIterator: false 7 7 0 1.4 714.5 0.9X 
-Colletion of string with 10 elements, useIterator: false 21 22 1 0.5 2146.0 0.3X -Colletion of string with 100 elements, useIterator: false 157 157 0 0.1 15690.5 0.0X -Colletion of Array[int] with 1 elements, useIterator: false 7 7 0 1.5 668.5 0.9X -Colletion of Array[int] with 10 elements, useIterator: false 18 18 0 0.6 1802.1 0.3X -Colletion of Array[int] with 100 elements, useIterator: false 134 135 1 0.1 13393.9 0.0X +Colletion of int with 1 elements, useIterator: true 6 6 0 1.6 621.6 1.0X +Colletion of int with 10 elements, useIterator: true 14 14 0 0.7 1422.5 0.4X +Colletion of int with 100 elements, useIterator: true 90 92 1 0.1 9030.9 0.1X +Colletion of string with 1 elements, useIterator: true 7 8 0 1.4 726.3 0.9X +Colletion of string with 10 elements, useIterator: true 23 23 1 0.4 2251.0 0.3X +Colletion of string with 100 elements, useIterator: true 172 172 0 0.1 17183.3 0.0X +Colletion of Array[int] with 1 elements, useIterator: true 7 7 0 1.4 718.6 0.9X +Colletion of Array[int] with 10 elements, useIterator: true 21 21 1 0.5 2078.4 0.3X +Colletion of Array[int] with 100 elements, useIterator: true 162 162 0 0.1 16189.6 0.0X +Colletion of int with 1 elements, useIterator: false 6 6 0 1.7 599.6 1.0X +Colletion of int with 10 elements, useIterator: false 13 13 0 0.8 1313.8 0.5X +Colletion of int with 100 elements, useIterator: false 81 82 0 0.1 8132.1 0.1X +Colletion of string with 1 elements, useIterator: false 7 7 0 1.4 705.6 0.9X +Colletion of string with 10 elements, useIterator: false 22 23 0 0.4 2240.1 0.3X +Colletion of string with 100 elements, useIterator: false 170 170 0 0.1 16995.3 0.0X +Colletion of Array[int] with 1 elements, useIterator: false 7 7 0 1.5 675.7 0.9X +Colletion of Array[int] with 10 elements, useIterator: false 18 19 0 0.5 1842.7 0.3X +Colletion of Array[int] with 100 elements, useIterator: false 138 139 0 0.1 13801.7 0.0X diff --git a/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt b/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt index bb234fd6a13b3..c00cd9152b278 100644 --- a/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt +++ b/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt @@ -2,11 +2,11 @@ Benchmark KryoPool vs old"pool of 1" implementation ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------- -KryoPool:true 3937 5805 1793 0.0 7874251.0 1.0X -KryoPool:false 6392 8372 1513 0.0 12783860.9 0.6X +KryoPool:true 4166 5737 1977 0.0 8331992.4 1.0X +KryoPool:false 6201 7778 1281 0.0 12402118.8 0.7X diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt index 79b87a83b34fa..a86338957cc37 100644 --- a/core/benchmarks/KryoSerializerBenchmark-results.txt +++ b/core/benchmarks/KryoSerializerBenchmark-results.txt @@ -2,11 +2,11 @@ Benchmark KryoPool vs old"pool of 1" implementation ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure 
AMD EPYC 7763 64-Core Processor Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------- -KryoPool:true 4180 6016 1956 0.0 8359077.2 1.0X -KryoPool:false 6204 8232 1447 0.0 12408361.0 0.7X +KryoPool:true 3973 5797 1879 0.0 7945107.3 1.0X +KryoPool:false 6041 7623 1484 0.0 12082153.5 0.7X diff --git a/core/benchmarks/LZFBenchmark-jdk21-results.txt b/core/benchmarks/LZFBenchmark-jdk21-results.txt index e1566f201a1f6..7104879c5c753 100644 --- a/core/benchmarks/LZFBenchmark-jdk21-results.txt +++ b/core/benchmarks/LZFBenchmark-jdk21-results.txt @@ -2,18 +2,18 @@ Benchmark LZFCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Compress small objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------- -Compression 256000000 int values in parallel 598 600 2 428.2 2.3 1.0X -Compression 256000000 int values single-threaded 568 570 2 451.0 2.2 1.1X +Compression 256000000 int values in parallel 599 601 4 427.4 2.3 1.0X +Compression 256000000 int values single-threaded 608 615 7 420.9 2.4 1.0X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Compress large objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------- -Compression 1024 array values in 1 threads 39 45 5 0.0 38475.4 1.0X -Compression 1024 array values single-threaded 32 33 1 0.0 31154.5 1.2X +Compression 1024 array values in 1 threads 37 45 5 0.0 35857.3 1.0X +Compression 1024 array values single-threaded 31 31 0 0.0 30334.5 1.2X diff --git a/core/benchmarks/LZFBenchmark-results.txt b/core/benchmarks/LZFBenchmark-results.txt index facc67f9cf4a8..142d3aad2f1ba 100644 --- a/core/benchmarks/LZFBenchmark-results.txt +++ b/core/benchmarks/LZFBenchmark-results.txt @@ -2,18 +2,18 @@ Benchmark LZFCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Compress small objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------- -Compression 256000000 int values in parallel 602 612 6 425.1 2.4 1.0X -Compression 256000000 int values single-threaded 610 617 5 419.8 2.4 1.0X +Compression 256000000 int values in parallel 601 610 9 426.3 2.3 1.0X +Compression 256000000 int values single-threaded 610 619 7 419.3 2.4 1.0X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Compress large objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
----------------------------------------------------------------------------------------------------------------------------- -Compression 1024 array values in 1 threads 35 43 6 0.0 33806.8 1.0X -Compression 1024 array values single-threaded 32 32 0 0.0 30990.4 1.1X +Compression 1024 array values in 1 threads 35 44 5 0.0 34512.1 1.0X +Compression 1024 array values single-threaded 31 32 1 0.0 30396.7 1.1X diff --git a/core/benchmarks/MapStatusesConvertBenchmark-jdk21-results.txt b/core/benchmarks/MapStatusesConvertBenchmark-jdk21-results.txt index a15442496b244..7c1b1eb4ac803 100644 --- a/core/benchmarks/MapStatusesConvertBenchmark-jdk21-results.txt +++ b/core/benchmarks/MapStatusesConvertBenchmark-jdk21-results.txt @@ -2,12 +2,12 @@ MapStatuses Convert Benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1022-azure +OpenJDK 64-Bit Server VM 21.0.5+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor MapStatuses Convert: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Num Maps: 50000 Fetch partitions:500 674 685 12 0.0 673772738.0 1.0X -Num Maps: 50000 Fetch partitions:1000 1579 1590 12 0.0 1579383970.0 0.4X -Num Maps: 50000 Fetch partitions:1500 2435 2472 37 0.0 2434530380.0 0.3X +Num Maps: 50000 Fetch partitions:500 696 705 13 0.0 696430567.0 1.0X +Num Maps: 50000 Fetch partitions:1000 1620 1628 7 0.0 1620094001.0 0.4X +Num Maps: 50000 Fetch partitions:1500 2507 2522 13 0.0 2507485825.0 0.3X diff --git a/core/benchmarks/MapStatusesConvertBenchmark-results.txt b/core/benchmarks/MapStatusesConvertBenchmark-results.txt index b9f36af4a6531..4ca2e502b9404 100644 --- a/core/benchmarks/MapStatusesConvertBenchmark-results.txt +++ b/core/benchmarks/MapStatusesConvertBenchmark-results.txt @@ -2,12 +2,12 @@ MapStatuses Convert Benchmark ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1022-azure +OpenJDK 64-Bit Server VM 17.0.13+11-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor MapStatuses Convert: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Num Maps: 50000 Fetch partitions:500 703 716 11 0.0 703103575.0 1.0X -Num Maps: 50000 Fetch partitions:1000 1707 1723 14 0.0 1707060398.0 0.4X -Num Maps: 50000 Fetch partitions:1500 2626 2638 14 0.0 2625981097.0 0.3X +Num Maps: 50000 Fetch partitions:500 775 782 8 0.0 774584162.0 1.0X +Num Maps: 50000 Fetch partitions:1000 1605 1634 29 0.0 1604801022.0 0.5X +Num Maps: 50000 Fetch partitions:1500 2568 2585 22 0.0 2568404459.0 0.3X diff --git a/core/benchmarks/MapStatusesSerDeserBenchmark-jdk21-results.txt b/core/benchmarks/MapStatusesSerDeserBenchmark-jdk21-results.txt index f4846ce8b0fb3..1ffe7594c22cd 100644 --- a/core/benchmarks/MapStatusesSerDeserBenchmark-jdk21-results.txt +++ b/core/benchmarks/MapStatusesSerDeserBenchmark-jdk21-results.txt @@ -1,64 +1,64 @@ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 10 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------- -Serialization 84 88 7 2.4 419.8 1.0X -Deserialization 140 148 8 1.4 700.9 0.6X +Serialization 81 85 4 2.5 406.5 1.0X +Deserialization 147 155 9 1.4 734.2 0.6X -Compressed Serialized MapStatus sizes: 427.0 B +Compressed Serialized MapStatus sizes: 426.0 B Compressed Serialized Broadcast MapStatus sizes: 2.5 MiB -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 10 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -Serialization 80 81 2 2.5 400.1 1.0X -Deserialization 139 146 6 1.4 694.8 0.6X +Serialization 73 75 3 2.7 365.9 1.0X +Deserialization 146 153 10 1.4 732.1 0.5X Compressed Serialized MapStatus sizes: 2.5 MiB Compressed Serialized Broadcast MapStatus sizes: 0.0 B -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 100 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -Serialization 153 158 11 1.3 767.3 1.0X -Deserialization 159 165 5 1.3 794.1 1.0X +Serialization 152 157 6 1.3 759.6 1.0X +Deserialization 162 166 4 1.2 811.8 0.9X Compressed Serialized MapStatus sizes: 442.0 B Compressed Serialized Broadcast MapStatus sizes: 13.6 MiB -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 100 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -Serialization 140 143 4 1.4 701.0 1.0X -Deserialization 154 159 7 1.3 771.3 0.9X +Serialization 141 142 1 1.4 703.4 1.0X +Deserialization 161 165 5 1.2 807.0 0.9X Compressed Serialized MapStatus sizes: 13.6 MiB Compressed Serialized Broadcast MapStatus sizes: 0.0 B -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 1000 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -Serialization 707 726 21 0.3 3533.4 1.0X -Deserialization 316 342 17 0.6 1582.2 2.2X +Serialization 686 714 31 0.3 3431.6 1.0X +Deserialization 329 346 18 0.6 1645.6 2.1X -Compressed Serialized MapStatus sizes: 570.0 B +Compressed Serialized MapStatus sizes: 569.0 B Compressed Serialized Broadcast MapStatus sizes: 122.3 MiB -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 1000 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ---------------------------------------------------------------------------------------------------------------------------- 
-Serialization 566 569 5 0.4 2828.3 1.0X -Deserialization 324 343 19 0.6 1617.8 1.7X +Serialization 560 562 3 0.4 2797.7 1.0X +Deserialization 317 334 20 0.6 1587.0 1.8X Compressed Serialized MapStatus sizes: 122.3 MiB Compressed Serialized Broadcast MapStatus sizes: 0.0 B diff --git a/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt b/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt index 3b0b9b756d32b..edd6207a12f8b 100644 --- a/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt +++ b/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt @@ -1,64 +1,64 @@ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 10 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -Serialization 89 95 4 2.2 445.9 1.0X -Deserialization 138 145 6 1.4 689.8 0.6X +Serialization 93 99 8 2.1 466.1 1.0X +Deserialization 140 151 12 1.4 698.8 0.7X -Compressed Serialized MapStatus sizes: 427.0 B +Compressed Serialized MapStatus sizes: 426.0 B Compressed Serialized Broadcast MapStatus sizes: 2.5 MiB -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 10 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -Serialization 80 83 2 2.5 399.4 1.0X -Deserialization 134 145 10 1.5 671.0 0.6X +Serialization 82 83 1 2.4 409.0 1.0X +Deserialization 139 142 8 1.4 692.8 0.6X Compressed Serialized MapStatus sizes: 2.5 MiB Compressed Serialized Broadcast MapStatus sizes: 0.0 B -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 100 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -Serialization 159 163 7 1.3 793.4 1.0X -Deserialization 155 163 8 1.3 776.8 1.0X +Serialization 161 168 7 1.2 802.6 1.0X +Deserialization 155 169 13 1.3 777.3 1.0X Compressed Serialized MapStatus sizes: 442.0 B Compressed Serialized Broadcast MapStatus sizes: 13.6 MiB -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 100 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -Serialization 144 147 1 1.4 718.8 1.0X -Deserialization 154 159 4 1.3 770.4 0.9X +Serialization 145 147 1 1.4 726.2 1.0X +Deserialization 155 160 10 1.3 772.7 0.9X Compressed Serialized MapStatus sizes: 13.6 MiB Compressed Serialized Broadcast MapStatus sizes: 0.0 B -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 1000 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) 
Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------- -Serialization 710 713 4 0.3 3549.7 1.0X -Deserialization 346 355 7 0.6 1730.4 2.1X +Serialization 693 714 23 0.3 3465.9 1.0X +Deserialization 326 351 13 0.6 1628.7 2.1X -Compressed Serialized MapStatus sizes: 569.0 B +Compressed Serialized MapStatus sizes: 568.0 B Compressed Serialized Broadcast MapStatus sizes: 122.3 MiB -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 200000 MapOutputs, 1000 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ---------------------------------------------------------------------------------------------------------------------------- -Serialization 557 564 5 0.4 2783.1 1.0X -Deserialization 325 337 10 0.6 1626.5 1.7X +Serialization 564 576 13 0.4 2817.6 1.0X +Deserialization 339 354 9 0.6 1694.4 1.7X Compressed Serialized MapStatus sizes: 122.3 MiB Compressed Serialized Broadcast MapStatus sizes: 0.0 B diff --git a/core/benchmarks/PersistenceEngineBenchmark-jdk21-results.txt b/core/benchmarks/PersistenceEngineBenchmark-jdk21-results.txt index 8da4b4953cad0..7262ea63a6ef9 100644 --- a/core/benchmarks/PersistenceEngineBenchmark-jdk21-results.txt +++ b/core/benchmarks/PersistenceEngineBenchmark-jdk21-results.txt @@ -2,17 +2,17 @@ PersistenceEngineBenchmark ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 1000 Workers: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ---------------------------------------------------------------------------------------------------------------------------------------- -ZooKeeperPersistenceEngine with JavaSerializer 6876 7518 567 0.0 6875910.8 1.0X -FileSystemPersistenceEngine with JavaSerializer 2973 3015 55 0.0 2973365.8 2.3X -FileSystemPersistenceEngine with JavaSerializer (lz4) 813 836 26 0.0 813019.4 8.5X -FileSystemPersistenceEngine with JavaSerializer (lzf) 756 772 21 0.0 755574.7 9.1X -FileSystemPersistenceEngine with JavaSerializer (snappy) 755 788 46 0.0 754897.8 9.1X -FileSystemPersistenceEngine with JavaSerializer (zstd) 930 964 35 0.0 930157.5 7.4X -RocksDBPersistenceEngine with JavaSerializer 302 305 3 0.0 302099.4 22.8X -BlackHolePersistenceEngine 0 0 0 5.9 168.2 40871.1X +ZooKeeperPersistenceEngine with JavaSerializer 5620 5811 236 0.0 5619587.2 1.0X +FileSystemPersistenceEngine with JavaSerializer 2876 2924 42 0.0 2876068.8 2.0X +FileSystemPersistenceEngine with JavaSerializer (lz4) 825 829 4 0.0 824880.6 6.8X +FileSystemPersistenceEngine with JavaSerializer (lzf) 742 774 39 0.0 742492.3 7.6X +FileSystemPersistenceEngine with JavaSerializer (snappy) 785 832 42 0.0 784738.0 7.2X +FileSystemPersistenceEngine with JavaSerializer (zstd) 966 982 14 0.0 965925.8 5.8X +RocksDBPersistenceEngine with JavaSerializer 299 301 2 0.0 299470.1 18.8X +BlackHolePersistenceEngine 0 0 0 6.0 166.6 33740.5X diff --git a/core/benchmarks/PersistenceEngineBenchmark-results.txt b/core/benchmarks/PersistenceEngineBenchmark-results.txt index f927e3c57aa33..c373d88842d2e 100644 --- a/core/benchmarks/PersistenceEngineBenchmark-results.txt +++ b/core/benchmarks/PersistenceEngineBenchmark-results.txt @@ -2,17 
+2,17 @@ PersistenceEngineBenchmark ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor 1000 Workers: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ---------------------------------------------------------------------------------------------------------------------------------------- -ZooKeeperPersistenceEngine with JavaSerializer 6080 6179 119 0.0 6079694.4 1.0X -FileSystemPersistenceEngine with JavaSerializer 3011 3060 42 0.0 3011244.1 2.0X -FileSystemPersistenceEngine with JavaSerializer (lz4) 799 827 28 0.0 799357.3 7.6X -FileSystemPersistenceEngine with JavaSerializer (lzf) 800 839 35 0.0 800038.9 7.6X -FileSystemPersistenceEngine with JavaSerializer (snappy) 786 797 11 0.0 785847.0 7.7X -FileSystemPersistenceEngine with JavaSerializer (zstd) 1025 1028 3 0.0 1024806.3 5.9X -RocksDBPersistenceEngine with JavaSerializer 309 311 3 0.0 308522.6 19.7X -BlackHolePersistenceEngine 0 0 0 5.8 173.5 35032.8X +ZooKeeperPersistenceEngine with JavaSerializer 6146 6314 215 0.0 6146007.1 1.0X +FileSystemPersistenceEngine with JavaSerializer 2944 2957 17 0.0 2944099.7 2.1X +FileSystemPersistenceEngine with JavaSerializer (lz4) 827 869 37 0.0 827379.0 7.4X +FileSystemPersistenceEngine with JavaSerializer (lzf) 799 826 25 0.0 799318.1 7.7X +FileSystemPersistenceEngine with JavaSerializer (snappy) 775 805 50 0.0 774802.8 7.9X +FileSystemPersistenceEngine with JavaSerializer (zstd) 972 1002 28 0.0 971773.9 6.3X +RocksDBPersistenceEngine with JavaSerializer 310 312 3 0.0 310401.2 19.8X +BlackHolePersistenceEngine 0 0 0 6.0 165.7 37097.7X diff --git a/core/benchmarks/PropertiesCloneBenchmark-jdk21-results.txt b/core/benchmarks/PropertiesCloneBenchmark-jdk21-results.txt index 06701357609a5..ccae104413f6a 100644 --- a/core/benchmarks/PropertiesCloneBenchmark-jdk21-results.txt +++ b/core/benchmarks/PropertiesCloneBenchmark-jdk21-results.txt @@ -2,39 +2,39 @@ Properties Cloning ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Empty Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 0 0 0 0.3 3356.0 1.0X -Utils.cloneProperties 0 0 0 34.5 29.0 115.7X +SerializationUtils.clone 0 0 0 0.3 3296.0 1.0X +Utils.cloneProperties 0 0 0 34.5 29.0 113.7X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor System Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 0 0 0 0.0 160560.0 1.0X -Utils.cloneProperties 0 0 0 0.3 3085.0 52.0X +SerializationUtils.clone 0 0 0 0.0 156331.0 1.0X +Utils.cloneProperties 0 0 0 0.4 2595.0 60.2X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Small Properties: Best Time(ms) Avg Time(ms) Stdev(ms) 
Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 0 0 0 0.0 253143.0 1.0X -Utils.cloneProperties 0 0 0 0.3 3696.0 68.5X +SerializationUtils.clone 0 0 0 0.0 255154.0 1.0X +Utils.cloneProperties 0 0 0 0.4 2624.0 97.2X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Medium Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 1 1 0 0.0 907866.0 1.0X -Utils.cloneProperties 0 0 0 0.0 20328.0 44.7X +SerializationUtils.clone 1 1 0 0.0 887239.0 1.0X +Utils.cloneProperties 0 0 0 0.1 14908.0 59.5X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Large Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 2 2 0 0.0 1699706.0 1.0X -Utils.cloneProperties 0 0 0 0.0 40867.0 41.6X +SerializationUtils.clone 2 2 0 0.0 1655264.0 1.0X +Utils.cloneProperties 0 0 0 0.0 30837.0 53.7X diff --git a/core/benchmarks/PropertiesCloneBenchmark-results.txt b/core/benchmarks/PropertiesCloneBenchmark-results.txt index 13c241f53d182..f6c6c8781dc25 100644 --- a/core/benchmarks/PropertiesCloneBenchmark-results.txt +++ b/core/benchmarks/PropertiesCloneBenchmark-results.txt @@ -2,39 +2,39 @@ Properties Cloning ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Empty Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 0 0 0 0.3 3617.0 1.0X -Utils.cloneProperties 0 0 0 34.5 29.0 124.7X +SerializationUtils.clone 0 0 0 0.3 3466.0 1.0X +Utils.cloneProperties 0 0 0 34.5 29.0 119.5X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor System Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 0 0 0 0.0 148347.0 1.0X -Utils.cloneProperties 0 0 0 0.4 2815.0 52.7X +SerializationUtils.clone 0 0 0 0.0 156422.0 1.0X +Utils.cloneProperties 0 0 0 0.4 2685.0 58.3X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Small Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 0 0 0 0.0 280142.0 1.0X -Utils.cloneProperties 0 0 0 0.3 3686.0 76.0X +SerializationUtils.clone 0 0 0 0.0 277017.0 1.0X +Utils.cloneProperties 0 0 
0 0.3 3666.0 75.6X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Medium Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 1 1 0 0.0 941434.0 1.0X -Utils.cloneProperties 0 0 0 0.0 20278.0 46.4X +SerializationUtils.clone 1 1 0 0.0 920141.0 1.0X +Utils.cloneProperties 0 0 0 0.0 20097.0 45.8X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Large Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SerializationUtils.clone 2 2 0 0.0 1737205.0 1.0X -Utils.cloneProperties 0 0 0 0.0 40696.0 42.7X +SerializationUtils.clone 2 2 0 0.0 1714798.0 1.0X +Utils.cloneProperties 0 0 0 0.0 40385.0 42.5X diff --git a/core/benchmarks/XORShiftRandomBenchmark-jdk21-results.txt b/core/benchmarks/XORShiftRandomBenchmark-jdk21-results.txt index 6f1f333363e96..9f2baa5d9bf80 100644 --- a/core/benchmarks/XORShiftRandomBenchmark-jdk21-results.txt +++ b/core/benchmarks/XORShiftRandomBenchmark-jdk21-results.txt @@ -2,43 +2,43 @@ Pseudo random ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextInt: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 451 451 0 221.9 4.5 1.0X +java.util.Random 451 451 0 222.0 4.5 1.0X XORShiftRandom 185 185 0 539.4 1.9 2.4X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextLong: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 894 899 7 111.8 8.9 1.0X -XORShiftRandom 371 372 2 269.4 3.7 2.4X +java.util.Random 897 897 0 111.5 9.0 1.0X +XORShiftRandom 371 371 0 269.5 3.7 2.4X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextDouble: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 900 900 0 111.1 9.0 1.0X -XORShiftRandom 371 371 0 269.6 3.7 2.4X +java.util.Random 914 914 0 109.4 9.1 1.0X +XORShiftRandom 371 371 1 269.5 3.7 2.5X -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextGaussian: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 3373 3374 2 29.7 33.7 1.0X -XORShiftRandom 
2454 2460 11 40.8 24.5 1.4X +java.util.Random 3381 3384 3 29.6 33.8 1.0X +XORShiftRandom 2480 2498 29 40.3 24.8 1.4X ================================================================================================ hash seed ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Hash seed: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -XORShiftRandom.hashSeed 1 1 0 12974.9 0.1 1.0X +XORShiftRandom.hashSeed 1 1 0 12973.9 0.1 1.0X diff --git a/core/benchmarks/XORShiftRandomBenchmark-results.txt b/core/benchmarks/XORShiftRandomBenchmark-results.txt index 9701c0ca237bf..de5f7c04fddfc 100644 --- a/core/benchmarks/XORShiftRandomBenchmark-results.txt +++ b/core/benchmarks/XORShiftRandomBenchmark-results.txt @@ -2,43 +2,43 @@ Pseudo random ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextInt: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 442 442 0 226.4 4.4 1.0X -XORShiftRandom 185 185 0 539.4 1.9 2.4X +java.util.Random 439 439 0 227.7 4.4 1.0X +XORShiftRandom 185 185 0 539.5 1.9 2.4X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextLong: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 913 913 0 109.6 9.1 1.0X -XORShiftRandom 371 372 1 269.6 3.7 2.5X +java.util.Random 911 912 1 109.8 9.1 1.0X +XORShiftRandom 371 371 1 269.7 3.7 2.5X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextDouble: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 906 906 0 110.4 9.1 1.0X -XORShiftRandom 371 371 1 269.6 3.7 2.4X +java.util.Random 904 904 0 110.6 9.0 1.0X +XORShiftRandom 371 371 0 269.7 3.7 2.4X -OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor nextGaussian: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -java.util.Random 4170 4171 1 24.0 41.7 1.0X -XORShiftRandom 2993 2996 3 33.4 29.9 1.4X +java.util.Random 3997 3998 1 25.0 40.0 1.0X +XORShiftRandom 2926 2929 2 34.2 29.3 1.4X ================================================================================================ hash seed ================================================================================================ -OpenJDK 64-Bit Server VM 
17.0.11+9-LTS on Linux 6.5.0-1018-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Hash seed: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -XORShiftRandom.hashSeed 1 1 0 12975.0 0.1 1.0X +XORShiftRandom.hashSeed 2 2 0 6487.7 0.2 1.0X diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index 66c89ac139dc2..f6bd681451d5e 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 654 675 20 0.0 65380.3 1.0X -Compression 10000 times at level 2 without buffer pool 714 715 1 0.0 71445.4 0.9X -Compression 10000 times at level 3 without buffer pool 811 814 3 0.0 81142.7 0.8X -Compression 10000 times at level 1 with buffer pool 605 606 1 0.0 60518.7 1.1X -Compression 10000 times at level 2 with buffer pool 634 637 2 0.0 63441.4 1.0X -Compression 10000 times at level 3 with buffer pool 743 743 0 0.0 74258.7 0.9X +Compression 10000 times at level 1 without buffer pool 659 676 16 0.0 65860.7 1.0X +Compression 10000 times at level 2 without buffer pool 721 723 2 0.0 72135.5 0.9X +Compression 10000 times at level 3 without buffer pool 815 816 1 0.0 81500.6 0.8X +Compression 10000 times at level 1 with buffer pool 608 609 0 0.0 60846.6 1.1X +Compression 10000 times at level 2 with buffer pool 645 647 3 0.0 64476.3 1.0X +Compression 10000 times at level 3 with buffer pool 746 746 1 0.0 74584.0 0.9X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 824 826 4 0.0 82358.5 1.0X -Decompression 10000 times from level 2 without buffer pool 824 824 0 0.0 82394.3 1.0X -Decompression 10000 times from level 3 without buffer pool 823 824 1 0.0 82343.3 1.0X -Decompression 10000 times from level 1 with buffer pool 748 749 1 0.0 74792.0 1.1X -Decompression 10000 times from level 2 with buffer pool 748 749 2 0.0 74773.6 1.1X -Decompression 10000 times from level 3 with buffer pool 749 749 1 0.0 74868.3 1.1X +Decompression 10000 times from level 1 without buffer pool 828 829 1 0.0 82822.6 1.0X +Decompression 10000 times from level 2 without buffer pool 829 829 1 0.0 82900.7 1.0X +Decompression 10000 times from level 3 without buffer pool 828 833 8 0.0 82784.4 1.0X +Decompression 10000 times from level 1 with buffer pool 758 760 2 0.0 75756.5 1.1X +Decompression 10000 times from level 2 
with buffer pool 758 758 1 0.0 75772.3 1.1X +Decompression 10000 times from level 3 with buffer pool 759 759 0 0.0 75852.7 1.1X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 48 49 1 0.0 377356.2 1.0X -Parallel Compression with 1 workers 36 37 2 0.0 279079.1 1.4X -Parallel Compression with 2 workers 36 38 1 0.0 283760.8 1.3X -Parallel Compression with 4 workers 38 40 2 0.0 298581.6 1.3X -Parallel Compression with 8 workers 41 43 1 0.0 320669.0 1.2X -Parallel Compression with 16 workers 46 48 1 0.0 356997.0 1.1X +Parallel Compression with 0 workers 58 59 1 0.0 452489.9 1.0X +Parallel Compression with 1 workers 42 45 4 0.0 330066.0 1.4X +Parallel Compression with 2 workers 40 42 1 0.0 312560.3 1.4X +Parallel Compression with 4 workers 40 42 2 0.0 308802.7 1.5X +Parallel Compression with 8 workers 41 45 3 0.0 321331.3 1.4X +Parallel Compression with 16 workers 44 45 1 0.0 343311.5 1.3X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 159 161 1 0.0 1242020.8 1.0X -Parallel Compression with 1 workers 187 188 1 0.0 1463507.3 0.8X -Parallel Compression with 2 workers 114 118 5 0.0 888481.5 1.4X -Parallel Compression with 4 workers 107 110 2 0.0 836926.1 1.5X -Parallel Compression with 8 workers 110 115 3 0.0 856838.0 1.4X -Parallel Compression with 16 workers 112 115 2 0.0 874554.2 1.4X +Parallel Compression with 0 workers 158 160 2 0.0 1234257.6 1.0X +Parallel Compression with 1 workers 193 194 1 0.0 1507686.4 0.8X +Parallel Compression with 2 workers 113 127 11 0.0 881068.0 1.4X +Parallel Compression with 4 workers 109 111 2 0.0 849241.3 1.5X +Parallel Compression with 8 workers 111 115 3 0.0 869455.2 1.4X +Parallel Compression with 16 workers 113 116 2 0.0 881832.5 1.4X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index 7093636c4bc2b..136f0333590cc 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 259 260 1 0.0 25854.0 1.0X -Compression 10000 times at level 2 without buffer pool 678 678 1 0.0 67756.2 0.4X -Compression 10000 times at level 3 without buffer pool 775 775 1 0.0 77452.6 0.3X -Compression 10000 times at level 1 with buffer pool 
572 573 1 0.0 57218.3 0.5X -Compression 10000 times at level 2 with buffer pool 600 602 1 0.0 60041.6 0.4X -Compression 10000 times at level 3 with buffer pool 710 711 2 0.0 70967.2 0.4X +Compression 10000 times at level 1 without buffer pool 257 259 2 0.0 25704.2 1.0X +Compression 10000 times at level 2 without buffer pool 674 676 2 0.0 67396.3 0.4X +Compression 10000 times at level 3 without buffer pool 775 787 11 0.0 77497.9 0.3X +Compression 10000 times at level 1 with buffer pool 573 574 0 0.0 57347.3 0.4X +Compression 10000 times at level 2 with buffer pool 602 603 2 0.0 60162.8 0.4X +Compression 10000 times at level 3 with buffer pool 722 725 3 0.0 72247.3 0.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 588 588 1 0.0 58764.1 1.0X -Decompression 10000 times from level 2 without buffer pool 589 595 10 0.0 58919.5 1.0X -Decompression 10000 times from level 3 without buffer pool 591 597 12 0.0 59059.2 1.0X -Decompression 10000 times from level 1 with buffer pool 520 521 1 0.0 52039.6 1.1X -Decompression 10000 times from level 2 with buffer pool 521 522 1 0.0 52116.4 1.1X -Decompression 10000 times from level 3 with buffer pool 521 521 0 0.0 52100.6 1.1X +Decompression 10000 times from level 1 without buffer pool 176 177 1 0.1 17641.2 1.0X +Decompression 10000 times from level 2 without buffer pool 176 178 1 0.1 17628.9 1.0X +Decompression 10000 times from level 3 without buffer pool 175 176 0 0.1 17506.1 1.0X +Decompression 10000 times from level 1 with buffer pool 151 152 1 0.1 15051.5 1.2X +Decompression 10000 times from level 2 with buffer pool 150 151 1 0.1 14998.0 1.2X +Decompression 10000 times from level 3 with buffer pool 150 151 0 0.1 15019.4 1.2X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 46 48 1 0.0 362926.1 1.0X -Parallel Compression with 1 workers 34 36 4 0.0 265302.5 1.4X -Parallel Compression with 2 workers 32 36 1 0.0 252423.6 1.4X -Parallel Compression with 4 workers 36 38 2 0.0 282974.7 1.3X -Parallel Compression with 8 workers 38 40 1 0.0 298633.6 1.2X -Parallel Compression with 16 workers 42 44 1 0.0 329766.4 1.1X +Parallel Compression with 0 workers 57 57 0 0.0 444425.2 1.0X +Parallel Compression with 1 workers 42 44 3 0.0 325107.6 1.4X +Parallel Compression with 2 workers 38 39 2 0.0 294840.0 1.5X +Parallel Compression with 4 workers 36 37 1 0.0 282143.1 1.6X +Parallel Compression with 8 workers 39 40 1 0.0 303793.6 1.5X +Parallel Compression with 16 workers 41 43 1 0.0 324165.5 1.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1023-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 155 156 1 0.0 1212983.5 1.0X -Parallel Compression with 1 workers 191 195 7 0.0 1492776.4 0.8X -Parallel Compression with 2 workers 110 118 7 0.0 861496.4 1.4X -Parallel Compression with 4 workers 105 108 2 0.0 821249.1 1.5X -Parallel Compression with 8 workers 110 114 3 0.0 856234.4 1.4X -Parallel Compression with 16 workers 110 113 2 0.0 855949.6 1.4X +Parallel Compression with 0 workers 156 158 1 0.0 1220298.8 1.0X +Parallel Compression with 1 workers 188 189 1 0.0 1467911.4 0.8X +Parallel Compression with 2 workers 111 118 7 0.0 866985.2 1.4X +Parallel Compression with 4 workers 106 109 2 0.0 827592.1 1.5X +Parallel Compression with 8 workers 114 116 2 0.0 888419.5 1.4X +Parallel Compression with 16 workers 111 115 2 0.0 868463.5 1.4X diff --git a/core/pom.xml b/core/pom.xml index adb1b3034b427..7805a3f37ae53 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -118,6 +118,18 @@ org.apache.zookeeper zookeeper + + io.jsonwebtoken + jjwt-api + + + io.jsonwebtoken + jjwt-impl + + + io.jsonwebtoken + jjwt-jackson + @@ -381,6 +393,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} @@ -546,6 +568,7 @@ org.eclipse.jetty:jetty-util org.eclipse.jetty:jetty-server com.google.guava:guava + com.google.guava:failureaccess com.google.protobuf:* @@ -599,6 +622,12 @@ .sh + + jjwt + + compile + + sparkr diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 4e251a1c2901b..412d612c7f1d5 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -17,6 +17,7 @@ package org.apache.spark.io; import org.apache.spark.storage.StorageUtils; +import org.apache.spark.unsafe.Platform; import java.io.File; import java.io.IOException; @@ -47,7 +48,7 @@ public final class NioBufferedFileInputStream extends InputStream { private final FileChannel fileChannel; public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException { - byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes); + byteBuffer = Platform.allocateDirectBuffer(bufferSizeInBytes); fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); byteBuffer.flip(); this.cleanable = CLEANER.register(this, new ResourceCleaner(fileChannel, byteBuffer)); diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5e9f1b78273a5..7dd87df713e6e 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -120,7 +120,8 @@ private boolean isEndOfStream() { private void checkReadException() throws IOException { if (readAborted) { - Throwables.propagateIfPossible(readException, IOException.class); + Throwables.throwIfInstanceOf(readException, IOException.class); + Throwables.throwIfUnchecked(readException); throw new IOException(readException); } } diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index 8ec5c2221b6e9..0e35ebecfd270 100644 --- 
a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -32,14 +32,6 @@ public final class SparkOutOfMemoryError extends OutOfMemoryError implements Spa String errorClass; Map messageParameters; - public SparkOutOfMemoryError(String s) { - super(s); - } - - public SparkOutOfMemoryError(OutOfMemoryError e) { - super(e.getMessage()); - } - public SparkOutOfMemoryError(String errorClass, Map messageParameters) { super(SparkThrowableHelper.getMessage(errorClass, messageParameters)); this.errorClass = errorClass; @@ -52,7 +44,7 @@ public Map getMessageParameters() { } @Override - public String getErrorClass() { + public String getCondition() { return errorClass; } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index fe798e40a6ad7..bd9f58bf7415f 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -21,13 +21,7 @@ import java.io.InterruptedIOException; import java.io.IOException; import java.nio.channels.ClosedByInterruptException; -import java.util.Arrays; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; +import java.util.*; import com.google.common.annotations.VisibleForTesting; @@ -122,6 +116,30 @@ public class TaskMemoryManager { */ private volatile long acquiredButNotUsed = 0L; + /** + * Current off heap memory usage by this task. + */ + private long currentOffHeapMemory = 0L; + + private final Object offHeapMemoryLock = new Object(); + + /* + * Current on heap memory usage by this task. + */ + private long currentOnHeapMemory = 0L; + + private final Object onHeapMemoryLock = new Object(); + + /** + * Peak off heap memory usage by this task. + */ + private volatile long peakOffHeapMemory = 0L; + + /** + * Peak on heap memory usage by this task. + */ + private volatile long peakOnHeapMemory = 0L; + /** * Construct a new TaskMemoryManager. 
*/ @@ -202,6 +220,19 @@ public long acquireExecutionMemory(long required, MemoryConsumer requestingConsu logger.debug("Task {} acquired {} for {}", taskAttemptId, Utils.bytesToString(got), requestingConsumer); } + + if (mode == MemoryMode.OFF_HEAP) { + synchronized (offHeapMemoryLock) { + currentOffHeapMemory += got; + peakOffHeapMemory = Math.max(peakOffHeapMemory, currentOffHeapMemory); + } + } else { + synchronized (onHeapMemoryLock) { + currentOnHeapMemory += got; + peakOnHeapMemory = Math.max(peakOnHeapMemory, currentOnHeapMemory); + } + } + return got; } } @@ -254,8 +285,12 @@ private long trySpillAndAcquire( logger.error("error while calling spill() on {}", e, MDC.of(LogKeys.MEMORY_CONSUMER$.MODULE$, consumerToSpill)); // checkstyle.off: RegexpSinglelineJava - throw new SparkOutOfMemoryError("error while calling spill() on " + consumerToSpill + " : " - + e.getMessage()); + throw new SparkOutOfMemoryError( + "_LEGACY_ERROR_TEMP_3300", + new HashMap() {{ + put("consumerToSpill", consumerToSpill.toString()); + put("message", e.getMessage()); + }}); // checkstyle.on: RegexpSinglelineJava } } @@ -269,6 +304,15 @@ public void releaseExecutionMemory(long size, MemoryConsumer consumer) { consumer); } memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode()); + if (consumer.getMode() == MemoryMode.OFF_HEAP) { + synchronized (offHeapMemoryLock) { + currentOffHeapMemory -= size; + } + } else { + synchronized (onHeapMemoryLock) { + currentOnHeapMemory -= size; + } + } } /** @@ -507,4 +551,19 @@ public long getMemoryConsumptionForThisTask() { public MemoryMode getTungstenMemoryMode() { return tungstenMemoryMode; } + + /** + * Returns peak task-level on-heap memory usage in bytes. + */ + public long getPeakOnHeapExecutionMemory() { + return peakOnHeapMemory; + } + + /** + * Returns peak task-level off-heap memory usage in bytes. + */ + public long getPeakOffHeapExecutionMemory() { + return peakOffHeapMemory; + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 13fd18c0942b1..ac9d335d63591 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -87,7 +87,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; private final int initialSortBufferSize; - private final int inputBufferSizeInBytes; + private final int mergeBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -140,8 +140,8 @@ public UnsafeShuffleWriter( this.transferToEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_MERGE_PREFER_NIO()); this.initialSortBufferSize = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); - this.inputBufferSizeInBytes = - (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.mergeBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024; open(); } @@ -372,7 +372,7 @@ private void mergeSpillsWithFileStream( for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new NioBufferedFileInputStream( spills[i].file, - inputBufferSizeInBytes); + mergeBufferSizeInBytes); // Only convert the partitionLengths when debug level is enabled.
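Note on the TaskMemoryManager change above: the peak tracking is a small, self-contained pattern, a per-mode running counter updated under a lock plus a volatile field raised to the running maximum. A minimal Scala sketch of that pattern (the `PeakTracker` class is a hypothetical illustration, not part of this patch):

  // Hypothetical illustration of the per-mode peak tracking pattern above.
  class PeakTracker {
    private val lock = new Object
    private var current = 0L          // bytes currently acquired
    @volatile private var peak = 0L   // running maximum ever observed

    def acquire(bytes: Long): Unit = lock.synchronized {
      current += bytes
      peak = math.max(peak, current)
    }

    def release(bytes: Long): Unit = lock.synchronized {
      current -= bytes
    }

    // Volatile read; safe to call without holding the lock.
    def peakBytes: Long = peak
  }

The release path deliberately never lowers the peak, so the value reported at task end is the high-water mark of execution memory for that mode.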
if (logger.isDebugEnabled()) { logger.debug("Partition lengths for mapId {} in Spill {}: {}", mapId, i, diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 7579c0aefb250..761ced66f78cf 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -18,6 +18,7 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; +import java.util.HashMap; import java.util.LinkedList; import javax.annotation.Nullable; @@ -215,7 +216,7 @@ public void expandPointerArray(LongArray newArray) { if (array != null) { if (newArray.size() < array.size()) { // checkstyle.off: RegexpSinglelineJava - throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3301", new HashMap()); // checkstyle.on: RegexpSinglelineJava } Platform.copyMemory( diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css deleted file mode 100644 index 6db36f6e75d39..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css +++ /dev/null @@ -1 +0,0 @@ -:root{--dt-row-selected: 2, 117, 216;--dt-row-selected-text: 255, 255, 255;--dt-row-selected-link: 9, 10, 11;--dt-row-stripe: 0, 0, 0;--dt-row-hover: 0, 0, 0;--dt-column-ordering: 0, 0, 0;--dt-html-background: white}:root.dark{--dt-html-background: rgb(33, 37, 41)}table.dataTable td.dt-control{text-align:center;cursor:pointer}table.dataTable td.dt-control:before{display:inline-block;color:rgba(0, 0, 0, 0.5);content:"►"}table.dataTable tr.dt-hasChild td.dt-control:before{content:"▼"}html.dark table.dataTable td.dt-control:before{color:rgba(255, 255, 255, 0.5)}html.dark table.dataTable tr.dt-hasChild td.dt-control:before{color:rgba(255, 255, 255, 0.5)}table.dataTable thead>tr>th.sorting,table.dataTable thead>tr>th.sorting_asc,table.dataTable thead>tr>th.sorting_desc,table.dataTable thead>tr>th.sorting_asc_disabled,table.dataTable thead>tr>th.sorting_desc_disabled,table.dataTable thead>tr>td.sorting,table.dataTable thead>tr>td.sorting_asc,table.dataTable thead>tr>td.sorting_desc,table.dataTable thead>tr>td.sorting_asc_disabled,table.dataTable thead>tr>td.sorting_desc_disabled{cursor:pointer;position:relative;padding-right:26px}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable 
thead>tr>td.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after{position:absolute;display:block;opacity:.125;right:10px;line-height:9px;font-size:.8em}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:before{bottom:50%;content:"▲";content:"▲"/""}table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:after{top:50%;content:"▼";content:"▼"/""}table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:after{opacity:.6}table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting_asc_disabled:before{display:none}table.dataTable thead>tr>th:active,table.dataTable thead>tr>td:active{outline:none}div.dataTables_scrollBody>table.dataTable>thead>tr>th:before,div.dataTables_scrollBody>table.dataTable>thead>tr>th:after,div.dataTables_scrollBody>table.dataTable>thead>tr>td:before,div.dataTables_scrollBody>table.dataTable>thead>tr>td:after{display:none}div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:2px}div.dataTables_processing>div:last-child{position:relative;width:80px;height:15px;margin:1em auto}div.dataTables_processing>div:last-child>div{position:absolute;top:0;width:13px;height:13px;border-radius:50%;background:rgb(2, 117, 216);background:rgb(var(--dt-row-selected));animation-timing-function:cubic-bezier(0, 1, 1, 0)}div.dataTables_processing>div:last-child>div:nth-child(1){left:8px;animation:datatables-loader-1 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(2){left:8px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(3){left:32px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(4){left:56px;animation:datatables-loader-3 .6s infinite}@keyframes datatables-loader-1{0%{transform:scale(0)}100%{transform:scale(1)}}@keyframes datatables-loader-3{0%{transform:scale(1)}100%{transform:scale(0)}}@keyframes datatables-loader-2{0%{transform:translate(0, 0)}100%{transform:translate(24px, 0)}}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable 
td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th,table.dataTable thead td,table.dataTable tfoot th,table.dataTable tfoot td{text-align:left}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable{clear:both;margin-top:6px !important;margin-bottom:6px !important;max-width:none !important;border-collapse:separate !important;border-spacing:0}table.dataTable td,table.dataTable th{-webkit-box-sizing:content-box;box-sizing:content-box}table.dataTable td.dataTables_empty,table.dataTable th.dataTables_empty{text-align:center}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.table-striped>tbody>tr:nth-of-type(2n+1){background-color:transparent}table.dataTable>tbody>tr{background-color:transparent}table.dataTable>tbody>tr.selected>*{box-shadow:inset 0 0 0 9999px rgb(2, 117, 216);box-shadow:inset 0 0 0 9999px rgb(var(--dt-row-selected));color:rgb(255, 255, 255);color:rgb(var(--dt-row-selected-text))}table.dataTable>tbody>tr.selected a{color:rgb(9, 10, 11);color:rgb(var(--dt-row-selected-link))}table.dataTable.table-striped>tbody>tr.odd>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-stripe), 0.05)}table.dataTable.table-striped>tbody>tr.odd.selected>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.95);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.95)}table.dataTable.table-hover>tbody>tr:hover>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-hover), 0.075)}table.dataTable.table-hover>tbody>tr.selected:hover>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.975);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.975)}div.dataTables_wrapper div.dataTables_length label{font-weight:normal;text-align:left;white-space:nowrap}div.dataTables_wrapper div.dataTables_length select{width:auto;display:inline-block}div.dataTables_wrapper div.dataTables_filter{text-align:right}div.dataTables_wrapper div.dataTables_filter label{font-weight:normal;white-space:nowrap;text-align:left}div.dataTables_wrapper div.dataTables_filter input{margin-left:.5em;display:inline-block;width:auto}div.dataTables_wrapper div.dataTables_info{padding-top:.85em}div.dataTables_wrapper 
div.dataTables_paginate{margin:0;white-space:nowrap;text-align:right}div.dataTables_wrapper div.dataTables_paginate ul.pagination{margin:2px 0;white-space:nowrap;justify-content:flex-end}div.dataTables_wrapper div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:1em 0}div.dataTables_scrollHead table.dataTable{margin-bottom:0 !important}div.dataTables_scrollBody>table{border-top:none;margin-top:0 !important;margin-bottom:0 !important}div.dataTables_scrollBody>table>thead .sorting:before,div.dataTables_scrollBody>table>thead .sorting_asc:before,div.dataTables_scrollBody>table>thead .sorting_desc:before,div.dataTables_scrollBody>table>thead .sorting:after,div.dataTables_scrollBody>table>thead .sorting_asc:after,div.dataTables_scrollBody>table>thead .sorting_desc:after{display:none}div.dataTables_scrollBody>table>tbody tr:first-child th,div.dataTables_scrollBody>table>tbody tr:first-child td{border-top:none}div.dataTables_scrollFoot>.dataTables_scrollFootInner{box-sizing:content-box}div.dataTables_scrollFoot>.dataTables_scrollFootInner>table{margin-top:0 !important;border-top:none}@media screen and (max-width: 767px){div.dataTables_wrapper div.dataTables_length,div.dataTables_wrapper div.dataTables_filter,div.dataTables_wrapper div.dataTables_info,div.dataTables_wrapper div.dataTables_paginate{text-align:center}div.dataTables_wrapper div.dataTables_paginate ul.pagination{justify-content:center !important}}table.dataTable.table-sm>thead>tr>th:not(.sorting_disabled){padding-right:20px}table.table-bordered.dataTable{border-right-width:0}table.table-bordered.dataTable th,table.table-bordered.dataTable td{border-left-width:0}table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable td:last-child,table.table-bordered.dataTable td:last-child{border-right-width:1px}table.table-bordered.dataTable tbody th,table.table-bordered.dataTable tbody td{border-bottom-width:0}div.dataTables_scrollHead table.table-bordered{border-bottom-width:0}div.table-responsive>div.dataTables_wrapper>div.row{margin:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:first-child{padding-left:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:last-child{padding-right:0} diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.css new file mode 100644 index 0000000000000..d344f78a39748 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.css @@ -0,0 +1 @@ +:root{--dt-row-selected: 2, 117, 216;--dt-row-selected-text: 255, 255, 255;--dt-row-selected-link: 9, 10, 11;--dt-row-stripe: 0, 0, 0;--dt-row-hover: 0, 0, 0;--dt-column-ordering: 0, 0, 0;--dt-html-background: white}:root.dark{--dt-html-background: rgb(33, 37, 41)}table.dataTable td.dt-control{text-align:center;cursor:pointer}table.dataTable td.dt-control:before{display:inline-block;color:rgba(0, 0, 0, 0.5);content:"▶"}table.dataTable tr.dt-hasChild td.dt-control:before{content:"▼"}html.dark table.dataTable td.dt-control:before,:root[data-bs-theme=dark] table.dataTable td.dt-control:before{color:rgba(255, 255, 255, 0.5)}html.dark table.dataTable tr.dt-hasChild td.dt-control:before,:root[data-bs-theme=dark] table.dataTable tr.dt-hasChild td.dt-control:before{color:rgba(255, 255, 255, 0.5)}table.dataTable thead>tr>th.sorting,table.dataTable 
thead>tr>th.sorting_asc,table.dataTable thead>tr>th.sorting_desc,table.dataTable thead>tr>th.sorting_asc_disabled,table.dataTable thead>tr>th.sorting_desc_disabled,table.dataTable thead>tr>td.sorting,table.dataTable thead>tr>td.sorting_asc,table.dataTable thead>tr>td.sorting_desc,table.dataTable thead>tr>td.sorting_asc_disabled,table.dataTable thead>tr>td.sorting_desc_disabled{cursor:pointer;position:relative;padding-right:26px}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after{position:absolute;display:block;opacity:.125;right:10px;line-height:9px;font-size:.8em}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:before{bottom:50%;content:"▲";content:"▲"/""}table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:after{top:50%;content:"▼";content:"▼"/""}table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:after{opacity:.6}table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting_asc_disabled:before{display:none}table.dataTable thead>tr>th:active,table.dataTable 
thead>tr>td:active{outline:none}div.dataTables_scrollBody>table.dataTable>thead>tr>th:before,div.dataTables_scrollBody>table.dataTable>thead>tr>th:after,div.dataTables_scrollBody>table.dataTable>thead>tr>td:before,div.dataTables_scrollBody>table.dataTable>thead>tr>td:after{display:none}div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:2px;z-index:10}div.dataTables_processing>div:last-child{position:relative;width:80px;height:15px;margin:1em auto}div.dataTables_processing>div:last-child>div{position:absolute;top:0;width:13px;height:13px;border-radius:50%;background:rgb(2, 117, 216);background:rgb(var(--dt-row-selected));animation-timing-function:cubic-bezier(0, 1, 1, 0)}div.dataTables_processing>div:last-child>div:nth-child(1){left:8px;animation:datatables-loader-1 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(2){left:8px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(3){left:32px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(4){left:56px;animation:datatables-loader-3 .6s infinite}@keyframes datatables-loader-1{0%{transform:scale(0)}100%{transform:scale(1)}}@keyframes datatables-loader-3{0%{transform:scale(1)}100%{transform:scale(0)}}@keyframes datatables-loader-2{0%{transform:translate(0, 0)}100%{transform:translate(24px, 0)}}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th,table.dataTable thead td,table.dataTable tfoot th,table.dataTable tfoot td{text-align:left}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable{clear:both;margin-top:6px !important;margin-bottom:6px !important;max-width:none !important;border-collapse:separate 
!important;border-spacing:0}table.dataTable td,table.dataTable th{-webkit-box-sizing:content-box;box-sizing:content-box}table.dataTable td.dataTables_empty,table.dataTable th.dataTables_empty{text-align:center}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.table-striped>tbody>tr:nth-of-type(2n+1){background-color:transparent}table.dataTable>tbody>tr{background-color:transparent}table.dataTable>tbody>tr.selected>*{box-shadow:inset 0 0 0 9999px rgb(2, 117, 216);box-shadow:inset 0 0 0 9999px rgb(var(--dt-row-selected));color:rgb(255, 255, 255);color:rgb(var(--dt-row-selected-text))}table.dataTable>tbody>tr.selected a{color:rgb(9, 10, 11);color:rgb(var(--dt-row-selected-link))}table.dataTable.table-striped>tbody>tr.odd>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-stripe), 0.05)}table.dataTable.table-striped>tbody>tr.odd.selected>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.95);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.95)}table.dataTable.table-hover>tbody>tr:hover>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-hover), 0.075)}table.dataTable.table-hover>tbody>tr.selected:hover>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.975);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.975)}div.dataTables_wrapper div.dataTables_length label{font-weight:normal;text-align:left;white-space:nowrap}div.dataTables_wrapper div.dataTables_length select{width:auto;display:inline-block}div.dataTables_wrapper div.dataTables_filter{text-align:right}div.dataTables_wrapper div.dataTables_filter label{font-weight:normal;white-space:nowrap;text-align:left}div.dataTables_wrapper div.dataTables_filter input{margin-left:.5em;display:inline-block;width:auto}div.dataTables_wrapper div.dataTables_info{padding-top:.85em}div.dataTables_wrapper div.dataTables_paginate{margin:0;white-space:nowrap;text-align:right}div.dataTables_wrapper div.dataTables_paginate ul.pagination{margin:2px 0;white-space:nowrap;justify-content:flex-end}div.dataTables_wrapper div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:1em 0}div.dataTables_scrollHead table.dataTable{margin-bottom:0 !important}div.dataTables_scrollBody>table{border-top:none;margin-top:0 !important;margin-bottom:0 !important}div.dataTables_scrollBody>table>thead .sorting:before,div.dataTables_scrollBody>table>thead .sorting_asc:before,div.dataTables_scrollBody>table>thead .sorting_desc:before,div.dataTables_scrollBody>table>thead .sorting:after,div.dataTables_scrollBody>table>thead .sorting_asc:after,div.dataTables_scrollBody>table>thead .sorting_desc:after{display:none}div.dataTables_scrollBody>table>tbody tr:first-child th,div.dataTables_scrollBody>table>tbody tr:first-child td{border-top:none}div.dataTables_scrollFoot>.dataTables_scrollFootInner{box-sizing:content-box}div.dataTables_scrollFoot>.dataTables_scrollFootInner>table{margin-top:0 !important;border-top:none}@media screen and (max-width: 767px){div.dataTables_wrapper div.dataTables_length,div.dataTables_wrapper div.dataTables_filter,div.dataTables_wrapper div.dataTables_info,div.dataTables_wrapper div.dataTables_paginate{text-align:center}div.dataTables_wrapper div.dataTables_paginate ul.pagination{justify-content:center !important}}table.dataTable.table-sm>thead>tr>th:not(.sorting_disabled){padding-right:20px}table.table-bordered.dataTable{border-right-width:0}table.table-bordered.dataTable th,table.table-bordered.dataTable 
td{border-left-width:0}table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable td:last-child,table.table-bordered.dataTable td:last-child{border-right-width:1px}table.table-bordered.dataTable tbody th,table.table-bordered.dataTable tbody td{border-bottom-width:0}div.dataTables_scrollHead table.table-bordered{border-bottom-width:0}div.table-responsive>div.dataTables_wrapper>div.row{margin:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:first-child{padding-left:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:last-child{padding-right:0} diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.js similarity index 83% rename from core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js rename to core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.js index 04de9c97cc514..c99016713ab1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.js @@ -1,4 +1,4 @@ /*! DataTables Bootstrap 4 integration * ©2011-2017 SpryMedia Ltd - datatables.net/license */ -!function(t){var n,o;"function"==typeof define&&define.amd?define(["jquery","datatables.net"],function(e){return t(e,window,document)}):"object"==typeof exports?(n=require("jquery"),o=function(e,a){a.fn.dataTable||require("datatables.net")(e,a)},"undefined"==typeof window?module.exports=function(e,a){return e=e||window,a=a||n(e),o(e,a),t(a,0,e.document)}:(o(window,n),module.exports=t(n,window,window.document))):t(jQuery,window,document)}(function(x,e,n,o){"use strict";var r=x.fn.dataTable;return x.extend(!0,r.defaults,{dom:"<'row'<'col-sm-12 col-md-6'l><'col-sm-12 col-md-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-12 col-md-5'i><'col-sm-12 col-md-7'p>>",renderer:"bootstrap"}),x.extend(r.ext.classes,{sWrapper:"dataTables_wrapper dt-bootstrap4",sFilterInput:"form-control form-control-sm",sLengthSelect:"custom-select custom-select-sm form-control form-control-sm",sProcessing:"dataTables_processing card",sPageButton:"paginate_button page-item"}),r.ext.renderer.pageButton.bootstrap=function(i,e,d,a,l,c){function u(e,a){for(var t,n,o=function(e){e.preventDefault(),x(e.currentTarget).hasClass("disabled")||m.page()==e.data.action||m.page(e.data.action).draw("page")},r=0,s=a.length;r",{class:b.sPageButton+" "+f,id:0===d&&"string"==typeof t?i.sTableId+"_"+t:null}).append(x("",{href:n?null:"#","aria-controls":i.sTableId,"aria-disabled":n?"true":null,"aria-label":w[t],role:"link","aria-current":"active"===f?"page":null,"data-dt-idx":t,tabindex:i.iTabIndex,class:"page-link"}).html(p)).appendTo(e),i.oApi._fnBindAction(n,{action:t},o))}}var p,f,t,m=new r.Api(i),b=i.oClasses,g=i.oLanguage.oPaginate,w=i.oLanguage.oAria.paginate||{};try{t=x(e).find(n.activeElement).data("dt-idx")}catch(e){}u(x(e).empty().html('
{ - if (app.isFinished) { + if (app.isFinished || app.desc.appUiUrl.isBlank()) { app.desc.name } else { sparkVersion, SparkConf} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ +import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS, MASTER_REST_SERVER_VIRTUAL_THREADS} import org.apache.spark.util.Utils /** @@ -59,7 +64,8 @@ private[spark] abstract class RestSubmissionServer( protected val clearRequestServlet: ClearRequestServlet protected val readyzRequestServlet: ReadyzRequestServlet - private var _server: Option[Server] = None + // Visible for testing + private[rest] var _server: Option[Server] = None // A mapping from URL prefixes to servlets that serve them. Exposed for testing. protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" @@ -87,7 +93,14 @@ private[spark] abstract class RestSubmissionServer( * Return a 2-tuple of the started server and the bound port. */ private def doStart(startPort: Int): (Server, Int) = { - val threadPool = new QueuedThreadPool + val threadPool = new QueuedThreadPool(masterConf.get(MASTER_REST_SERVER_MAX_THREADS)) + threadPool.setName(getClass().getSimpleName()) + if (Utils.isJavaVersionAtLeast21 && masterConf.get(MASTER_REST_SERVER_VIRTUAL_THREADS)) { + val newVirtualThreadPerTaskExecutor = + classOf[Executors].getMethod("newVirtualThreadPerTaskExecutor") + val service = newVirtualThreadPerTaskExecutor.invoke(null).asInstanceOf[ExecutorService] + threadPool.setVirtualThreadsExecutor(service) + } threadPool.setDaemon(true) val server = new Server(threadPool) @@ -111,12 +124,26 @@ private[spark] abstract class RestSubmissionServer( contextToServlet.foreach { case (prefix, servlet) => mainHandler.addServlet(new ServletHolder(servlet), prefix) } + addFilters(mainHandler) server.setHandler(mainHandler) server.start() val boundPort = connector.getLocalPort (server, boundPort) } + /** + * Add filters, if any, to the given ServletContextHandlers. 
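Note on the RestSubmissionServer thread-pool change above: `Executors.newVirtualThreadPerTaskExecutor` is looked up reflectively only so the class still compiles and runs on pre-21 JDKs. On Java 21+ the wiring is equivalent to the direct calls in this sketch (the pool size of 200 is an arbitrary example, not the configured default):

  import java.util.concurrent.Executors
  import org.eclipse.jetty.util.thread.QueuedThreadPool

  // Direct (non-reflective) form of the setup above; requires building on Java 21+.
  val threadPool = new QueuedThreadPool(200)   // max threads comes from configuration in the patch
  threadPool.setName("RestSubmissionServer")
  threadPool.setVirtualThreadsExecutor(Executors.newVirtualThreadPerTaskExecutor())
  threadPool.setDaemon(true)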
+ */ + private def addFilters(handler: ServletContextHandler): Unit = { + masterConf.get(MASTER_REST_SERVER_FILTERS).foreach { filter => + val params = masterConf.getAllWithPrefix(s"spark.$filter.param.").toMap + val holder = new FilterHolder() + holder.setClassName(filter) + params.foreach { case (k, v) => holder.setInitParameter(k, v) } + handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType])) + } + } + def stop(): Unit = { _server.foreach(_.stop()) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 8e3dcdc9d3787..5e29199a352ff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -22,7 +22,7 @@ import java.io.File import jakarta.servlet.http.HttpServletResponse import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} -import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription, SparkSubmit} import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.internal.config import org.apache.spark.launcher.{JavaModuleOptions, SparkLauncher} @@ -174,6 +174,11 @@ private[rest] class StandaloneSubmitRequestServlet( conf: SparkConf) extends SubmitRequestServlet { + private def replacePlaceHolder(variable: String) = variable match { + case s"{{$name}}" if System.getenv(name) != null => System.getenv(name) + case _ => variable + } + /** * Build a driver description from the fields specified in the submit request. * @@ -196,6 +201,7 @@ private[rest] class StandaloneSubmitRequestServlet( // Optional fields val sparkProperties = request.sparkProperties + .map(x => (x._1, replacePlaceHolder(x._2))) val driverMemory = sparkProperties.get(config.DRIVER_MEMORY.key) val driverCores = sparkProperties.get(config.DRIVER_CORES.key) val driverDefaultJavaOptions = sparkProperties.get(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS) @@ -212,10 +218,13 @@ private[rest] class StandaloneSubmitRequestServlet( val (_, masterPort) = Utils.extractHostPortFromSparkUrl(masterUrl) val updatedMasters = masters.map( _.replace(s":$masterRestPort", s":$masterPort")).getOrElse(masterUrl) - val appArgs = request.appArgs + val appArgs = Option(request.appArgs).getOrElse(Array[String]()) // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. + // In addition, the placeholders are replaced into the values of environment variables. 
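Note on `replacePlaceHolder` above: it substitutes a whole-string `{{NAME}}` token with the value of the corresponding environment variable and leaves the string untouched when the variable is unset. A small standalone sketch (the variable name and value are made up):

  // Same helper as above, shown standalone; relies on Scala's interpolated-string patterns.
  def replacePlaceHolder(variable: String): String = variable match {
    case s"{{$name}}" if System.getenv(name) != null => System.getenv(name)
    case _ => variable
  }

  // With AWS_ENDPOINT_URL=http://localhost:9000 exported in the server's environment:
  //   replacePlaceHolder("{{AWS_ENDPOINT_URL}}")  // -> "http://localhost:9000"
  //   replacePlaceHolder("{{NOT_SET}}")           // -> "{{NOT_SET}}" (unchanged)
  //   replacePlaceHolder("plain-value")           // -> "plain-value"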
val environmentVariables = - request.environmentVariables.filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) + Option(request.environmentVariables).getOrElse(Map.empty[String, String]) + .filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) + .map(x => (x._1, replacePlaceHolder(x._2))) // Construct driver description val conf = new SparkConf(false) @@ -229,9 +238,16 @@ private[rest] class StandaloneSubmitRequestServlet( val sparkJavaOpts = Utils.sparkJavaOpts(conf) val javaModuleOptions = JavaModuleOptions.defaultModuleOptionArray().toImmutableArraySeq val javaOpts = javaModuleOptions ++ sparkJavaOpts ++ defaultJavaOpts ++ extraJavaOpts + val sparkSubmitOpts = if (mainClass.equals(classOf[SparkSubmit].getName)) { + sparkProperties.get("spark.app.name") + .map { v => Seq("-c", s"spark.app.name=$v") } + .getOrElse(Seq.empty[String]) + } else { + Seq.empty[String] + } val command = new Command( "org.apache.spark.deploy.worker.DriverWrapper", - Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper + Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ sparkSubmitOpts ++ appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 7f462148c71a1..c6ff3dbb33cbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -47,9 +47,6 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { super.doValidate() assert(sparkProperties != null, "No Spark properties set!") assertFieldIsSet(appResource, "appResource") - assertFieldIsSet(appArgs, "appArgs") - assertFieldIsSet(environmentVariables, "environmentVariables") - assertPropertyIsSet("spark.app.name") assertPropertyIsBoolean(config.DRIVER_SUPERVISE.key) assertPropertyIsNumeric(config.DRIVER_CORES.key) assertPropertyIsNumeric(config.CORES_MAX.key) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index de517acbf8c5b..b9d88266ed538 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -151,7 +151,7 @@ private[spark] class HadoopDelegationTokenManager( creds.addAll(newTokens) } }) - if(!currentUser.equals(freshUGI)) { + if (!currentUser.equals(freshUGI)) { FileSystem.closeAllForUGI(freshUGI) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index b47f9e5a43afc..fc750b54d0b8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -150,6 +150,9 @@ private[deploy] class HadoopFSDelegationTokenProvider val interval = newExpiration - getIssueDate(tokenKind, identifier) logInfo(log"Renewal interval is ${MDC(TOTAL_TIME, interval)} 
for" + log" token ${MDC(TOKEN_KIND, tokenKind)}") + // The token here is only used to obtain renewal intervals. We should cancel it in + // a timely manner to avoid causing additional pressure on the server. + token.cancel(hadoopConf) interval }.toOption } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index bb96ecb38a640..ca0e024ad1aed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import scala.jdk.CollectionConverters._ -import com.google.common.io.Files +import com.google.common.io.{Files, FileWriteMode} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} @@ -216,7 +216,7 @@ private[deploy] class DriverRunner( val redactedCommand = Utils.redactCommandLineArgs(conf, builder.command.asScala.toSeq) .mkString("\"", "\" \"", "\"") val header = "Launch Command: %s\n%s\n\n".format(redactedCommand, "=" * 40) - Files.append(header, stderr, StandardCharsets.UTF_8) + Files.asCharSink(stderr, StandardCharsets.UTF_8, FileWriteMode.APPEND).write(header) CommandUtils.redirectStream(process.getErrorStream, stderr) } runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 8d0fb7a54f72a..d21904dd16ea7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -191,7 +191,7 @@ private[deploy] class ExecutorRunner( stdoutAppender = FileAppender(process.getInputStream, stdout, conf, true) val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, StandardCharsets.UTF_8) + Files.asCharSink(stderr, StandardCharsets.UTF_8).write(header) stderrAppender = FileAppender(process.getErrorStream, stderr, conf, true) state = ExecutorState.RUNNING diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7ff7974ab59f6..b2ec23887a400 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -980,6 +980,7 @@ private[deploy] object Worker extends Logging { def main(argStrings: Array[String]): Unit = { Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler( exitOnUncaughtException = false)) + Utils.resetStructuredLogging() Utils.initDaemon(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 94a27e1a3e6da..87ca01fe82a97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -22,6 +22,7 @@ import java.lang.management.ManagementFactory import scala.annotation.tailrec import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Worker._ import org.apache.spark.util.{IntParam, MemoryParam, Utils} @@ -59,6 +60,9 @@ private[worker] 
class WorkerArguments(args: Array[String], conf: SparkConf) { // This mutates the SparkConf, so all accesses to it must be made after this line propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(conf) + Logging.uninitialize() conf.get(WORKER_UI_PORT).foreach { webUiPort = _ } @@ -66,11 +70,6 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { - case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value) - host = value - parse(tail) - case ("--host" | "-h") :: value :: tail => Utils.checkHost(value) host = value @@ -133,7 +132,6 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)\n" + diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 1b1053a7013e0..e880cf8da9ec2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -422,6 +422,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { backendCreateFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) => CoarseGrainedExecutorBackend): Unit = { + Utils.resetStructuredLogging() Utils.initDaemon(log) SparkHadoopUtil.get.runAsSparkUser { () => @@ -467,6 +468,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } } + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(driverConf) + Logging.uninitialize() + cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 586a8a7db28a3..c299f38526aeb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer -import java.util.{Locale, Properties} +import java.util.{Locale, Properties, Timer, TimerTask} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.locks.ReentrantLock @@ -177,7 +177,7 @@ private[spark] class Executor( // Whether to monitor killed / interrupted tasks private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED) - private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH) + private val killOnFatalErrorDepth = conf.get(KILL_ON_FATAL_ERROR_DEPTH) private val systemLoader = Utils.getContextOrSparkClassLoader @@ -209,9 +209,10 @@ private[spark] class Executor( // The default isolation group val 
defaultSessionState: IsolatedSessionState = newSessionState(JobArtifactState("default", None)) + private val cacheExpiryTime = 30 * 60 * 1000 val isolatedSessionCache: Cache[String, IsolatedSessionState] = CacheBuilder.newBuilder() .maximumSize(100) - .expireAfterAccess(30, TimeUnit.MINUTES) + .expireAfterAccess(cacheExpiryTime, TimeUnit.MILLISECONDS) .removalListener(new RemovalListener[String, IsolatedSessionState]() { override def onRemoval( notification: RemovalNotification[String, IsolatedSessionState]): Unit = { @@ -295,6 +296,8 @@ private[spark] class Executor( private val pollOnHeartbeat = if (METRICS_POLLING_INTERVAL_MS > 0) false else true + private val timer = new Timer("executor-state-timer", true) + // Poller for the memory metrics. Visible for testing. private[executor] val metricsPoller = new ExecutorMetricsPoller( env.memoryManager, @@ -445,6 +448,9 @@ private[spark] class Executor( case NonFatal(e) => logWarning("Unable to stop heartbeater", e) } + if (timer != null) { + timer.cancel() + } ShuffleBlockPusher.stop() if (threadPool != null) { threadPool.shutdown() @@ -550,7 +556,7 @@ private[spark] class Executor( // Collect latest accumulator values to report back to the driver val accums: Seq[AccumulatorV2[_, _]] = Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val accUpdates = accums.map(acc => acc.toInfoUpdate) setTaskFinishedAndClearInterruptStatus() (accums, accUpdates) @@ -559,9 +565,17 @@ private[spark] class Executor( override def run(): Unit = { // Classloader isolation + var maybeTimerTask: Option[TimerTask] = None val isolatedSession = taskDescription.artifacts.state match { case Some(jobArtifactState) => - isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState)) + val state = isolatedSessionCache.get( + jobArtifactState.uuid, () => newSessionState(jobArtifactState)) + maybeTimerTask = Some(new TimerTask { + // Resets the expire time till the task ends. + def run(): Unit = isolatedSessionCache.getIfPresent(jobArtifactState.uuid) + }) + maybeTimerTask.foreach(timer.schedule(_, cacheExpiryTime / 10, cacheExpiryTime / 10)) + state case _ => defaultSessionState } @@ -706,6 +720,8 @@ private[spark] class Executor( task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) task.metrics.setResultSerializationTime(TimeUnit.NANOSECONDS.toMillis( afterSerializationNs - beforeSerializationNs)) + task.metrics.setPeakOnHeapExecutionMemory(taskMemoryManager.getPeakOnHeapExecutionMemory) + task.metrics.setPeakOffHeapExecutionMemory(taskMemoryManager.getPeakOffHeapExecutionMemory) // Expose task metrics using the Dropwizard metrics system. // Update task metrics counters executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime) @@ -860,6 +876,7 @@ private[spark] class Executor( uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { + maybeTimerTask.foreach(_.cancel) cleanMDCForTask(taskName, mdcProperties) runningTasks.remove(taskId) if (taskStarted) { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index 5300598ef53eb..7f0be5c1b704f 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -49,6 +49,9 @@ object ExecutorExitCode { * TaskReaper. 
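Note on the timer added to Executor above: a Guava cache built with `expireAfterAccess` drops an isolated session after the expiry interval of inactivity, so for a still-running task the executor schedules a `TimerTask` that simply touches the entry (via `getIfPresent`) at one tenth of the expiry interval and cancels it when the task finishes. A self-contained sketch of the same keep-alive pattern, with hypothetical names:

  import java.util.{Timer, TimerTask}
  import java.util.concurrent.TimeUnit
  import com.google.common.cache.CacheBuilder

  object KeepAliveSketch {   // illustration only, not a Spark class
    private val expiryMs = 30 * 60 * 1000L
    private val cache = CacheBuilder.newBuilder()
      .expireAfterAccess(expiryMs, TimeUnit.MILLISECONDS)
      .build[String, AnyRef]()
    private val timer = new Timer("keep-alive", true)   // daemon timer thread

    // Touch the entry periodically so expireAfterAccess never fires while the
    // work keyed by `key` is still running; cancel the returned task when done.
    def keepAlive(key: String): TimerTask = {
      val task = new TimerTask {
        override def run(): Unit = { cache.getIfPresent(key); () }
      }
      timer.schedule(task, expiryMs / 10, expiryMs / 10)
      task
    }
  }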
*/ val KILLED_BY_TASK_REAPER = 57 + /** Executor is unable to re-register BlockManager. */ + val BLOCK_MANAGER_REREGISTRATION_FAILED = 58 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -63,6 +66,8 @@ object ExecutorExitCode { "ExternalBlockStore failed to create a local temporary directory." case HEARTBEAT_FAILURE => "Unable to send heartbeats to driver." + case BLOCK_MANAGER_REREGISTRATION_FAILED => + "Executor killed due to a failure of block manager re-registration." case KILLED_BY_TASK_REAPER => "Executor killed by TaskReaper." case _ => diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 227e7d84654dd..582c93007f4f5 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -29,8 +29,6 @@ import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.util._ -import org.apache.spark.util.ArrayImplicits._ - /** * :: DeveloperApi :: @@ -58,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable { private val _memoryBytesSpilled = new LongAccumulator private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator + private val _peakOnHeapExecutionMemory = new LongAccumulator + private val _peakOffHeapExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] /** @@ -111,9 +111,22 @@ class TaskMetrics private[spark] () extends Serializable { * joins. The value of this accumulator should be approximately the sum of the peak sizes * across all such data structures created in this task. For SQL jobs, this only tracks all * unsafe operators and ExternalSort. + * This is not equal to peakOnHeapExecutionMemory + peakOffHeapExecutionMemory */ + // TODO: SPARK-48789: the naming is confusing since this does not really reflect the whole + // execution memory. We'd better deprecate this once we have a replacement. def peakExecutionMemory: Long = _peakExecutionMemory.sum + /** + * Peak on heap execution memory as tracked by TaskMemoryManager. + */ + def peakOnHeapExecutionMemory: Long = _peakOnHeapExecutionMemory.sum + + /** + * Peak off heap execution memory as tracked by TaskMemoryManager. + */ + def peakOffHeapExecutionMemory: Long = _peakOffHeapExecutionMemory.sum + /** * Storage statuses of any blocks that have been updated as a result of this task. 
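Since `TaskMetrics` now exposes the on-heap and off-heap peaks separately, a listener can report them per task. A rough sketch that assumes only the accessors added above; the listener itself is illustrative and not part of this change:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Logs the new per-task peak execution memory split, alongside the legacy
// peakExecutionMemory value that only covers certain operators.
class PeakExecutionMemoryListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    val metrics = taskEnd.taskMetrics
    if (metrics != null) {
      println(s"task ${taskEnd.taskInfo.taskId}: " +
        s"legacyPeak=${metrics.peakExecutionMemory} " +
        s"onHeapPeak=${metrics.peakOnHeapExecutionMemory} " +
        s"offHeapPeak=${metrics.peakOffHeapExecutionMemory}")
    }
  }
}

// Register it with: sc.addSparkListener(new PeakExecutionMemoryListener)
```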
* @@ -141,6 +154,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def setResultSerializationTime(v: Long): Unit = _resultSerializationTime.setValue(v) private[spark] def setPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.setValue(v) + private[spark] def setPeakOnHeapExecutionMemory(v: Long): Unit = + _peakOnHeapExecutionMemory.setValue(v) + private[spark] def setPeakOffHeapExecutionMemory(v: Long): Unit = + _peakOffHeapExecutionMemory.setValue(v) private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v) private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v) private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) @@ -227,6 +244,8 @@ class TaskMetrics private[spark] () extends Serializable { MEMORY_BYTES_SPILLED -> _memoryBytesSpilled, DISK_BYTES_SPILLED -> _diskBytesSpilled, PEAK_EXECUTION_MEMORY -> _peakExecutionMemory, + PEAK_ON_HEAP_EXECUTION_MEMORY -> _peakOnHeapExecutionMemory, + PEAK_OFF_HEAP_EXECUTION_MEMORY -> _peakOffHeapExecutionMemory, UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses, shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched, shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched, @@ -272,8 +291,17 @@ class TaskMetrics private[spark] () extends Serializable { */ @transient private[spark] lazy val _externalAccums = new ArrayBuffer[AccumulatorV2[_, _]] - private[spark] def externalAccums: Seq[AccumulatorV2[_, _]] = withReadLock { - _externalAccums.toArray.toImmutableArraySeq + /** + * Perform an `op` conversion on `_externalAccums` within the read lock. + * + * Note that `op` is expected not to modify `_externalAccums` and not to be + * evaluated lazily, for safety, since `ArrayBuffer` is lazily evaluated. + * We also intentionally keep `_externalAccums` mutable instead of converting + * it to an immutable collection, for performance reasons.
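The contract described in this comment, namely run the caller's function strictly and entirely inside the read lock over a buffer that stays mutable, can be shown with a small stand-alone sketch (all names here are invented for the example):

```scala
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection.mutable.ArrayBuffer

// A buffer whose readers pass an eager function that runs entirely under the read
// lock, instead of receiving a view that might be evaluated after the lock is gone.
class GuardedBuffer[A] {
  private val lock = new ReentrantReadWriteLock()
  private val buf = new ArrayBuffer[A]

  def add(a: A): Unit = {
    lock.writeLock().lock()
    try buf += a finally lock.writeLock().unlock()
  }

  // `op` must not retain a reference to `buf` or defer its evaluation.
  def withAll[T](op: ArrayBuffer[A] => T): T = {
    lock.readLock().lock()
    try op(buf) finally lock.readLock().unlock()
  }
}

// Usage: take a strict snapshot while the lock is held.
// val snapshot = guarded.withAll(_.toList)
```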
+ */ + private[spark] def withExternalAccums[T](op: ArrayBuffer[AccumulatorV2[_, _]] => T) + : T = withReadLock { + op(_externalAccums) } private def withReadLock[B](fn: => B): B = { @@ -298,7 +326,9 @@ class TaskMetrics private[spark] () extends Serializable { _externalAccums += a } - private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums + private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = withReadLock { + internalAccums ++ _externalAccums + } private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f0d6cba6ae734..3c3017a9a64c1 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -45,8 +45,8 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int): Unit = { - val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) - val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultMaxSplitBytes = sc.conf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.conf.get(config.FILES_OPEN_COST_IN_BYTES) val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions) val files = listStatus(context).asScala val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum diff --git a/core/src/main/scala/org/apache/spark/internal/config/History.scala b/core/src/main/scala/org/apache/spark/internal/config/History.scala index 64a8681ca2954..bbd4afcaebab4 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/History.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/History.scala @@ -109,6 +109,12 @@ private[spark] object History { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_TITLE = ConfigBuilder("spark.history.ui.title") + .version("4.0.0") + .doc("Specifies the title of the History Server UI page.") + .stringConf + .createWithDefault("History Server") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") .doc("Web UI port to bind Spark History Server") .version("1.0.0") @@ -201,6 +207,14 @@ private[spark] object History { .toSequence .createWithDefault(Nil) + val HISTORY_UI_MAX_APPS = ConfigBuilder("spark.history.ui.maxApplications") + .version("2.0.1") + .doc("The number of applications to display on the history summary page. 
Application UIs " + + "are still available by accessing their URLs directly even if they are not displayed on " + + "the history summary page.") + .intConf + .createWithDefault(Integer.MAX_VALUE) + val NUM_REPLAY_THREADS = ConfigBuilder("spark.history.fs.numReplayThreads") .version("2.0.0") .doc("Number of threads that will be used by history server to process event logs.") diff --git a/core/src/main/scala/org/apache/spark/internal/config/UI.scala b/core/src/main/scala/org/apache/spark/internal/config/UI.scala index c4c5796a2b278..fe5a52debdafc 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/UI.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/UI.scala @@ -244,6 +244,20 @@ private[spark] object UI { .checkValues(Set("ALLOW", "LOCAL", "DENY")) .createWithDefault("LOCAL") + val MASTER_UI_TITLE = ConfigBuilder("spark.master.ui.title") + .version("4.0.0") + .doc("Specifies the title of the Master UI page. If unset, `Spark Master at ` " + + "is used by default.") + .stringConf + .createOptional + + val MASTER_UI_VISIBLE_ENV_VAR_PREFIXES = ConfigBuilder("spark.master.ui.visibleEnvVarPrefixes") + .doc("Comma-separated list of key-prefix strings to show environment variables") + .version("4.0.0") + .stringConf + .toSequence + .createWithDefault(Seq.empty[String]) + val UI_SQL_GROUP_SUB_EXECUTION_ENABLED = ConfigBuilder("spark.ui.groupSQLSubExecutionEnabled") .doc("Whether to group sub executions together in SQL UI when they belong to the same " + "root execution") diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6de500024816f..324ef701c4266 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -271,6 +271,18 @@ package object config { .toSequence .createWithDefault(GarbageCollectionMetrics.OLD_GENERATION_BUILTIN_GARBAGE_COLLECTORS) + private[spark] val EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS = + ConfigBuilder("spark.eventLog.includeTaskMetricsAccumulators") + .doc("Whether to include TaskMetrics' underlying accumulator values in the event log " + + "(as part of the Task/Stage/Job metrics' 'Accumulables' fields. The TaskMetrics " + + "values are already logged in the 'Task Metrics' fields (so the accumulator updates " + + "are redundant). This flag defaults to true for behavioral backwards compatibility " + + "for applications that might rely on the redundant logging. 
" + + "See SPARK-42204 for details.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite") .version("1.0.0") @@ -918,6 +930,18 @@ package object config { .booleanConf .createOptional + private[spark] val EXCLUDE_ON_FAILURE_ENABLED_APPLICATION = + ConfigBuilder("spark.excludeOnFailure.application.enabled") + .version("4.0.0") + .booleanConf + .createOptional + + private[spark] val EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE = + ConfigBuilder("spark.excludeOnFailure.taskAndStage.enabled") + .version("4.0.0") + .booleanConf + .createOptional + private[spark] val MAX_TASK_ATTEMPTS_PER_EXECUTOR = ConfigBuilder("spark.excludeOnFailure.task.maxTaskAttemptsPerExecutor") .version("3.1.0") @@ -1110,13 +1134,6 @@ package object config { .stringConf .createOptional - // To limit how many applications are shown in the History Server summary ui - private[spark] val HISTORY_UI_MAX_APPS = - ConfigBuilder("spark.history.ui.maxApplications") - .version("2.0.1") - .intConf - .createWithDefault(Integer.MAX_VALUE) - private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") .version("2.1.0") .booleanConf @@ -1369,7 +1386,6 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR = ConfigBuilder("spark.shuffle.accurateBlockSkewedFactor") - .internal() .doc("A shuffle block is considered as skewed and will be accurately recorded in " + "HighlyCompressedMapStatus if its size is larger than this factor multiplying " + "the median shuffle block size or SHUFFLE_ACCURATE_BLOCK_THRESHOLD. It is " + @@ -1473,6 +1489,14 @@ package object config { s" ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024}.") .createWithDefaultString("32k") + private[spark] val SHUFFLE_FILE_MERGE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.file.merge.buffer") + .doc("Size of the in-memory buffer for each shuffle file input stream, in KiB unless " + + "otherwise specified. These buffers use off-heap buffers and are related to the number " + + "of files in the shuffle file. 
Too large buffers should be avoided.") + .version("4.0.0") + .fallbackConf(SHUFFLE_FILE_BUFFER_SIZE) + private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") .doc("(Deprecated since Spark 4.0, please use 'spark.shuffle.localDisk.file.output.buffer'.)") @@ -1617,8 +1641,7 @@ package object config { .version("3.2.0") .stringConf .transform(_.toUpperCase(Locale.ROOT)) - .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " + - "should be either ADLER32 or CRC32.") + .checkValues(Set("ADLER32", "CRC32", "CRC32C")) .createWithDefault("ADLER32") private[spark] val SHUFFLE_COMPRESS = @@ -1964,6 +1987,26 @@ package object config { .intConf .createWithDefault(6066) + private[spark] val MASTER_REST_SERVER_MAX_THREADS = ConfigBuilder("spark.master.rest.maxThreads") + .doc("Maximum number of threads to use in the Spark Master REST API Server.") + .version("4.0.0") + .intConf + .createWithDefault(200) + + private[spark] val MASTER_REST_SERVER_FILTERS = ConfigBuilder("spark.master.rest.filters") + .doc("Comma separated list of filter class names to apply to the Spark Master REST API.") + .version("4.0.0") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val MASTER_REST_SERVER_VIRTUAL_THREADS = + ConfigBuilder("spark.master.rest.virtualThread.enabled") + .doc("If true, Spark master tries to use Java 21 virtual thread for REST API.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + private[spark] val MASTER_UI_PORT = ConfigBuilder("spark.master.ui.port") .version("1.1.0") .intConf @@ -1985,6 +2028,14 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val MASTER_USE_DRIVER_ID_AS_APP_NAME = + ConfigBuilder("spark.master.useDriverIdAsAppName.enabled") + .internal() + .doc("(Experimental) If true, Spark master tries to set driver ID as appName.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + private[spark] val IO_COMPRESSION_SNAPPY_BLOCKSIZE = ConfigBuilder("spark.io.compression.snappy.blockSize") .doc("Block size in bytes used in Snappy compression, in the case when " + @@ -2424,11 +2475,11 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH = + private[spark] val KILL_ON_FATAL_ERROR_DEPTH = ConfigBuilder("spark.executor.killOnFatalError.depth") .doc("The max depth of the exception chain in a failed task Spark will search for a fatal " + - "error to check whether it should kill an executor. 0 means not checking any fatal " + - "error, 1 means checking only the exception but not the cause, and so on.") + "error to check whether it should kill the JVM process. 0 means not checking any fatal" + + " error, 1 means checking only the exception but not the cause, and so on.") .internal() .version("3.1.0") .intConf diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index f245d2d4e4074..476cddc643954 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -168,7 +168,8 @@ class HadoopMapReduceCommitProtocol( // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, // the file name is fine and won't overflow. 
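Several of the configuration entries introduced or reworked above are ordinary SparkConf keys. A hedged example of setting a few of them together; the values are arbitrary, and `spark.shuffle.file.merge.buffer` is deliberately left unset to show the fallback to `spark.shuffle.file.buffer`:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Drop the redundant TaskMetrics accumulator updates from the event log.
  .set("spark.eventLog.includeTaskMetricsAccumulators", "false")
  // CRC32C is now accepted in addition to ADLER32 and CRC32.
  .set("spark.shuffle.checksum.algorithm", "CRC32C")
  // Existing buffer size; spark.shuffle.file.merge.buffer falls back to this value.
  .set("spark.shuffle.file.buffer", "64k")
  // History Server side settings (read from the History Server's own conf,
  // shown here only for reference).
  .set("spark.history.ui.title", "Team X History Server")
  .set("spark.history.ui.maxApplications", "500")
```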
val split = taskContext.getTaskAttemptID.getTaskID.getId - f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}" + val basename = taskContext.getConfiguration.get("mapreduce.output.basename", "part") + f"${spec.prefix}$basename-$split%05d-$jobId${spec.suffix}" } override def setupJob(jobContext: JobContext): Unit = { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala index 7571395289967..2ab49eae8cd85 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala @@ -24,15 +24,21 @@ import jakarta.servlet.http.HttpServletRequest import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkConf +import org.apache.spark.annotation.{DeveloperApi, Since, Unstable} import org.apache.spark.ui.JettyUtils._ /** + * :: DeveloperApi :: * This exposes the metrics of the given registry with Prometheus format. * * The output is consistent with /metrics/json result in terms of item ordering * and with the previous result of Spark JMX Sink + Prometheus JMX Converter combination * in terms of key string format. + * + * This is used by Spark MetricsSystem internally and Spark K8s operator. */ +@Unstable +@DeveloperApi private[spark] class PrometheusServlet( val property: Properties, val registry: MetricRegistry) extends Sink { @@ -47,7 +53,10 @@ private[spark] class PrometheusServlet( ) } - def getMetricsSnapshot(request: HttpServletRequest): String = { + def getMetricsSnapshot(request: HttpServletRequest): String = getMetricsSnapshot() + + @Since("4.0.0") + def getMetricsSnapshot(): String = { import scala.jdk.CollectionConverters._ val gaugesLabel = """{type="gauges"}""" diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index bf539320b5985..2b6f322d1805d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -242,10 +242,12 @@ class NewHadoopRDD[K, V]( private var finished = false private var reader = try { - val _reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - _reader + Utils.tryInitializeResource( + format.createRecordReader(split.serializableHadoopSplit.value, hadoopAttemptContext) + ) { reader => + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + reader + } } catch { case e: FileNotFoundException if ignoreMissingFiles => logWarning(log"Skipped missing file: ${MDC(PATH, split.serializableHadoopSplit)}", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index a806b72766c6f..126c92e4cb656 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -240,7 +240,7 @@ private object PipedRDD { def tokenize(command: String): Seq[String] = { val buf = new ArrayBuffer[String] val tok = new StringTokenizer(command) - while(tok.hasMoreElements) { + while (tok.hasMoreElements) { buf += tok.nextToken() } buf.toSeq diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ac93abf3fe7a0..0db0133f632bf 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -211,6 +211,11 @@ abstract class RDD[T: ClassTag]( * @return This RDD. */ def unpersist(blocking: Boolean = false): this.type = { + if (isLocallyCheckpointed) { + // This means its lineage has been truncated and cannot be recomputed once unpersisted. + logWarning(log"RDD ${MDC(RDD_ID, id)} was locally checkpointed, its lineage has been" + + log" truncated and cannot be recomputed after unpersisting") + } logInfo(log"Removing RDD ${MDC(RDD_ID, id)} from persistence list") sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 7dcde35de2518..0baa578764d08 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -327,7 +327,7 @@ object ResourceProfile extends Logging { */ val CORES = "cores" /** - * built-in executor resource: cores + * built-in executor resource: memory */ val MEMORY = "memory" /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6c824e2fdeaed..4f7338f74e298 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -27,6 +27,7 @@ import scala.annotation.tailrec import scala.collection.Map import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -173,7 +174,7 @@ private[spark] class DAGScheduler( // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in // this set, the job will be immediately cancelled. private[scheduler] val cancelledJobGroups = - new LimitedSizeFIFOSet[String](sc.getConf.get(config.NUM_CANCELLED_JOB_GROUPS_TO_TRACK)) + new LimitedSizeFIFOSet[String](sc.conf.get(config.NUM_CANCELLED_JOB_GROUPS_TO_TRACK)) /** * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids @@ -223,9 +224,9 @@ private[spark] class DAGScheduler( private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ - private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) + private val disallowStageRetryForTest = sc.conf.get(TEST_NO_STAGE_RETRY) - private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) + private val shouldMergeResourceProfiles = sc.conf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) /** * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, @@ -233,19 +234,19 @@ private[spark] class DAGScheduler( * executor(instead of the host) on a FetchFailure. */ private[scheduler] val unRegisterOutputOnHostOnFetchFailure = - sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + sc.conf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) /** * Number of consecutive stage attempts allowed before a stage is aborted. 
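The new warning in `RDD.unpersist` above targets locally checkpointed RDDs, whose lineage is gone once the cached blocks are dropped. A small sketch of the situation it warns about, assuming an existing SparkContext `sc` (illustrative only):

```scala
// After localCheckpoint(), the blocks written by the first action are the only copy
// of the data; unpersisting them makes lost partitions unrecoverable.
val rdd = sc.parallelize(1 to 1000000).map(_ * 2).localCheckpoint()
rdd.count()      // materializes the checkpoint blocks
rdd.unpersist()  // now logs the warning added above; recomputation is no longer possible
```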
*/ private[scheduler] val maxConsecutiveStageAttempts = - sc.getConf.get(config.STAGE_MAX_CONSECUTIVE_ATTEMPTS) + sc.conf.get(config.STAGE_MAX_CONSECUTIVE_ATTEMPTS) /** * Max stage attempts allowed before a stage is aborted. */ private[scheduler] val maxStageAttempts: Int = { - Math.max(maxConsecutiveStageAttempts, sc.getConf.get(config.STAGE_MAX_ATTEMPTS)) + Math.max(maxConsecutiveStageAttempts, sc.conf.get(config.STAGE_MAX_ATTEMPTS)) } /** @@ -253,7 +254,7 @@ private[spark] class DAGScheduler( * count spark.stage.maxConsecutiveAttempts */ private[scheduler] val ignoreDecommissionFetchFailure = - sc.getConf.get(config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE) + sc.conf.get(config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE) /** * Number of max concurrent tasks check failures for each barrier job. @@ -263,14 +264,14 @@ private[spark] class DAGScheduler( /** * Time in seconds to wait between a max concurrent tasks check failure and the next check. */ - private val timeIntervalNumTasksCheck = sc.getConf + private val timeIntervalNumTasksCheck = sc.conf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) /** * Max number of max concurrent tasks check failures allowed for a job before fail the job * submission. */ - private val maxFailureNumTasksCheck = sc.getConf + private val maxFailureNumTasksCheck = sc.conf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) private val messageScheduler = @@ -285,26 +286,26 @@ private[spark] class DAGScheduler( taskScheduler.setDAGScheduler(this) - private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.conf, isDriver = true) private val blockManagerMasterDriverHeartbeatTimeout = - sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis + sc.conf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis private val shuffleMergeResultsTimeoutSec = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) private val shuffleMergeFinalizeWaitSec = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) private val shuffleMergeWaitMinSizeThreshold = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) - private val shufflePushMinRatio = sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) + private val shufflePushMinRatio = sc.conf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) private val shuffleMergeFinalizeNumThreads = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) - private val shuffleFinalizeRpcThreads = sc.getConf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) + private val shuffleFinalizeRpcThreads = sc.conf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be // initialized lazily @@ -327,11 +328,10 @@ private[spark] class DAGScheduler( ThreadUtils.newDaemonFixedThreadPool(shuffleFinalizeRpcThreads, "shuffle-merge-finalize-rpc") /** Whether rdd cache visibility tracking is enabled. 
*/ - private val trackingCacheVisibility: Boolean = - sc.getConf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED) + private val trackingCacheVisibility: Boolean = sc.conf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED) /** Whether to abort a stage after canceling all of its tasks. */ - private val legacyAbortStageAfterKillTasks = sc.getConf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS) + private val legacyAbortStageAfterKillTasks = sc.conf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS) /** * Called by the TaskSetManager to report task's starting. @@ -556,7 +556,7 @@ private[spark] class DAGScheduler( * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage */ private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { - if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { + if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.conf)) { throw SparkCoreErrors.barrierStageWithDynamicAllocationError() } } @@ -1116,11 +1116,18 @@ private[spark] class DAGScheduler( /** * Cancel all jobs with a given tag. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation. + * @param cancelledJobs a promise to be completed with operation IDs being cancelled. */ - def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = { + def cancelJobsWithTag( + tag: String, + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason)) + eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs)) } /** @@ -1234,17 +1241,22 @@ private[spark] class DAGScheduler( jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) } - private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = { - // Cancel all jobs belonging that have this tag. + private[scheduler] def handleJobTagCancelled( + tag: String, + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { + // Cancel all jobs that have all provided tags. // First finds all active jobs with this group id, and then kill stages for them. - val jobIds = activeJobs.filter { activeJob => + val jobsToBeCancelled = activeJobs.filter { activeJob => Option(activeJob.properties).exists { properties => Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("") .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } - }.map(_.jobId) - val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) + } + val updatedReason = + reason.getOrElse("part of cancelled job tags %s".format(tag)) + jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason))) + cancelledJobs.map(_.success(jobsToBeCancelled.toSeq)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -2150,7 +2162,7 @@ private[spark] class DAGScheduler( case mapStage: ShuffleMapStage => val numMissingPartitions = mapStage.findMissingPartitions().length if (numMissingPartitions < mapStage.numTasks) { - if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { val reason = "A shuffle map stage with indeterminate output was failed " + "and retried. 
However, Spark can only do this while using the new " + "shuffle block fetching protocol. Please check the config " + @@ -2880,8 +2892,8 @@ private[spark] class DAGScheduler( val finalException = exception.collect { // If the error is user-facing (defines error class and is not internal error), we don't // wrap it with "Job aborted" and expose this error to the end users directly. - case st: Exception with SparkThrowable if st.getErrorClass != null && - !SparkThrowableHelper.isInternalError(st.getErrorClass) => + case st: Exception with SparkThrowable if st.getCondition != null && + !SparkThrowableHelper.isInternalError(st.getCondition) => st }.getOrElse { new SparkException(s"Job aborted due to stage failure: $reason", cause = exception.orNull) @@ -3113,8 +3125,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason) => - dagScheduler.handleJobTagCancelled(tag, reason) + case JobTagCancelled(tag, reason, cancelledJobs) => + dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs) case AllJobsCancelled => dagScheduler.doCancelAllJobs() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index c9ad54d1fdc7e..8932d2ef323ba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler import java.util.Properties +import scala.concurrent.Promise + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.util.{AccumulatorV2, CallSite} @@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, - reason: Option[String]) extends DAGSchedulerEvent + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index efd8fecb974e8..1e46142fab255 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -31,7 +31,7 @@ import org.apache.spark.deploy.history.EventLogFileWriter import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.{JsonProtocol, JsonProtocolOptions, Utils} /** * A SparkListener that logs events to persistent storage. @@ -74,6 +74,8 @@ private[spark] class EventLoggingListener( private val liveStageExecutorMetrics = mutable.HashMap.empty[(Int, Int), mutable.HashMap[String, ExecutorMetrics]] + private[this] val jsonProtocolOptions = new JsonProtocolOptions(sparkConf) + /** * Creates the log file in the configured log directory. 
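The `cancelJobsWithTag` change above threads an optional reason and a `Promise` of the cancelled jobs through the scheduler; the user-facing entry point remains the tag API on SparkContext. A hedged usage sketch, assuming an existing SparkContext `sc`:

```scala
import scala.concurrent.{ExecutionContext, Future}

implicit val ec: ExecutionContext = ExecutionContext.global

// Tags are attached per submitting thread, so set the tag on the thread that runs the job.
val running = Future {
  sc.addJobTag("nightly-etl")
  sc.range(0, Long.MaxValue / 1000).count()
}

// Later, from another thread: cancel every active job carrying the tag.
Thread.sleep(2000)
sc.cancelJobsWithTag("nightly-etl")
```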
*/ @@ -84,7 +86,7 @@ private[spark] class EventLoggingListener( private def initEventLog(): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val eventJson = JsonProtocol.sparkEventToJsonString(metadata) + val eventJson = JsonProtocol.sparkEventToJsonString(metadata, jsonProtocolOptions) logWriter.writeEvent(eventJson, flushLogger = true) if (testing && loggedEvents != null) { loggedEvents += eventJson @@ -93,7 +95,7 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. */ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false): Unit = { - val eventJson = JsonProtocol.sparkEventToJsonString(event) + val eventJson = JsonProtocol.sparkEventToJsonString(event, jsonProtocolOptions) logWriter.writeEvent(eventJson, flushLogger) if (testing) { loggedEvents += eventJson diff --git a/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala index 1606072153906..82ec0ef91f4fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala @@ -425,14 +425,16 @@ private[spark] object HealthTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" /** - * Returns true if the excludeOnFailure is enabled, based on checking the configuration - * in the following order: - * 1. Is it specifically enabled or disabled? - * 2. Is it enabled via the legacy timeout conf? - * 3. Default is off + * Returns true if the excludeOnFailure is enabled on the application level, + * based on checking the configuration in the following order: + * 1. Is application level exclusion specifically enabled or disabled? + * 2. Is overall exclusion feature enabled or disabled? + * 3. Is it enabled via the legacy timeout conf? + * 4. Default is off */ def isExcludeOnFailureEnabled(conf: SparkConf): Boolean = { - conf.get(config.EXCLUDE_ON_FAILURE_ENABLED) match { + conf.get(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION) + .orElse(conf.get(config.EXCLUDE_ON_FAILURE_ENABLED)) match { case Some(enabled) => enabled case None => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index feed831620840..bfd6759387034 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -45,14 +45,21 @@ private[spark] class JobWaiter[T]( def completionFuture: Future[Unit] = jobPromise.future /** - * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled - * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it - * will fail this job with a SparkException. + * Sends a signal to the DAGScheduler to cancel the job with an optional reason. The + * cancellation itself is handled asynchronously. After the low level scheduler cancels + * all the tasks belonging to this job, it will fail this job with a SparkException. */ - def cancel(): Unit = { - dagScheduler.cancelJob(jobId, None) + def cancel(reason: Option[String]): Unit = { + dagScheduler.cancelJob(jobId, reason) } + /** + * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is + * handled asynchronously. After the low level scheduler cancels all the tasks belonging + * to this job, it will fail this job with a SparkException. 
+ */ + def cancel(): Unit = cancel(None) + override def taskSucceeded(index: Int, result: Any): Unit = { // resultHandler call must be synchronized in case resultHandler itself is not thread safe. synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index cc19b71bfc4d6..384f939a843bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,8 +22,6 @@ import javax.annotation.Nullable import scala.collection.Map -import com.fasterxml.jackson.annotation.JsonTypeInfo - import org.apache.spark.TaskEndReason import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} @@ -31,13 +29,6 @@ import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} -@DeveloperApi -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") -trait SparkListenerEvent { - /* Whether output this event to the event log */ - protected[spark] def logEvent: Boolean = true -} - @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 6e449e4dc1112..f511aed6d2166 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -215,7 +215,7 @@ private[spark] abstract class Task[T]( context.taskMetrics().nonZeroInternalAccums() ++ // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not // filter them out. - context.taskMetrics().externalAccums.filter(a => !taskFailed || a.countFailedValues) + context.taskMetrics().withExternalAccums(_.filter(a => !taskFailed || a.countFailedValues)) } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ec678256a708e..8e3cb1379339d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -875,7 +875,7 @@ private[spark] class TaskSchedulerImpl( executorRunTime = acc.value.asInstanceOf[Long] } } - acc.toInfo(Some(acc.value), None) + acc.toInfoUpdate } val taskProcessRate = if (efficientTaskCalcualtionEnabled) { getTaskProcessRate(recordsRead, executorRunTime) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala index c9aa74e0852be..3637305293107 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala @@ -31,6 +31,9 @@ import org.apache.spark.util.Clock * which is handled by [[HealthTracker]]. Note that HealthTracker does not know anything * about task failures until a taskset completes successfully. * + * If isDryRun is true, then this class will only function to store information for application + * level exclusion, and will not actually exclude any tasks in task/stage level. 
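`JobWaiter.cancel` above gains a reason-carrying overload; the waiter is what backs the asynchronous RDD actions, so cancellation usually arrives through a `FutureAction`. A minimal illustration, assuming an existing SparkContext `sc`:

```scala
// countAsync() returns a FutureAction backed by a JobWaiter; cancelling it asks the
// DAGScheduler to cancel the job and eventually fails the action with a SparkException.
val action = sc.range(0, Long.MaxValue / 1000).countAsync()
Thread.sleep(1000)
action.cancel()
```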
+ * * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. @@ -40,7 +43,8 @@ private[scheduler] class TaskSetExcludelist( val conf: SparkConf, val stageId: Int, val stageAttemptId: Int, - val clock: Clock) extends Logging { + val clock: Clock, + val isDryRun: Boolean = false) extends Logging { private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -80,13 +84,13 @@ private[scheduler] class TaskSetExcludelist( * of the scheduler, where those filters will have already been applied. */ def isExecutorExcludedForTask(executorId: String, index: Int): Boolean = { - execToFailures.get(executorId).exists { execFailures => + !isDryRun && execToFailures.get(executorId).exists { execFailures => execFailures.getNumTaskFailures(index) >= MAX_TASK_ATTEMPTS_PER_EXECUTOR } } def isNodeExcludedForTask(node: String, index: Int): Boolean = { - nodeToExcludedTaskIndexes.get(node).exists(_.contains(index)) + !isDryRun && nodeToExcludedTaskIndexes.get(node).exists(_.contains(index)) } /** @@ -96,11 +100,11 @@ private[scheduler] class TaskSetExcludelist( * scheduler, where those filters will already have been applied. */ def isExecutorExcludedForTaskSet(executorId: String): Boolean = { - excludedExecs.contains(executorId) + !isDryRun && excludedExecs.contains(executorId) } def isNodeExcludedForTaskSet(node: String): Boolean = { - excludedNodes.contains(node) + !isDryRun && excludedNodes.contains(node) } private[scheduler] def updateExcludedForFailedTask( @@ -163,3 +167,18 @@ private[scheduler] class TaskSetExcludelist( } } } + +private[scheduler] object TaskSetExcludelist { + + /** + * Returns true if the excludeOnFailure is enabled on the task/stage level, + * based on checking the configuration in the following order: + * 1. Is taskset level exclusion specifically enabled or disabled? + * 2. Is overall exclusion feature enabled or disabled? + * 3. Default is off + */ + def isExcludeOnFailureEnabled(conf: SparkConf): Boolean = { + conf.get(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE) + .orElse(conf.get(config.EXCLUDE_ON_FAILURE_ENABLED)).getOrElse(false) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7dba4a6dc8fc4..fdc82285b76bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -143,8 +143,18 @@ private[spark] class TaskSetManager( private var calculatedTasks = 0 private[scheduler] val taskSetExcludelistHelperOpt: Option[TaskSetExcludelist] = { - healthTracker.map { _ => - new TaskSetExcludelist(sched.sc.listenerBus, conf, stageId, taskSet.stageAttemptId, clock) + if (TaskSetExcludelist.isExcludeOnFailureEnabled(conf)) { + Some(new TaskSetExcludelist(sched.sc.listenerBus, conf, stageId, + taskSet.stageAttemptId, clock)) + } else if (healthTracker.isDefined) { + // If we enabled exclusion at application level but not at taskset level exclusion, we create + // TaskSetExcludelist in dry run mode. + // In this mode, TaskSetExcludeList would not exclude any executors but only store + // task failure information. 
+ Some(new TaskSetExcludelist(sched.sc.listenerBus, conf, stageId, + taskSet.stageAttemptId, clock, isDryRun = true)) + } else { + None } } @@ -698,7 +708,6 @@ private[spark] class TaskSetManager( private[scheduler] def getCompletelyExcludedTaskIfAny( hostToExecutors: HashMap[String, HashSet[String]]): Option[Int] = { taskSetExcludelistHelperOpt.flatMap { taskSetExcludelist => - val appHealthTracker = healthTracker.get // Only look for unschedulable tasks when at least one executor has registered. Otherwise, // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. if (hostToExecutors.nonEmpty) { @@ -725,7 +734,7 @@ private[spark] class TaskSetManager( hostToExecutors.forall { case (host, execsOnHost) => // Check if the task can run on the node val nodeExcluded = - appHealthTracker.isNodeExcluded(host) || + healthTracker.exists(_.isNodeExcluded(host)) || taskSetExcludelist.isNodeExcludedForTaskSet(host) || taskSetExcludelist.isNodeExcludedForTask(host, indexInTaskSet) if (nodeExcluded) { @@ -733,7 +742,7 @@ private[spark] class TaskSetManager( } else { // Check if the task can run on any of the executors execsOnHost.forall { exec => - appHealthTracker.isExecutorExcluded(exec) || + healthTracker.exists(_.isExecutorExcluded(exec)) || taskSetExcludelist.isExecutorExcludedForTaskSet(exec) || taskSetExcludelist.isExecutorExcludedForTask(exec, indexInTaskSet) } @@ -797,7 +806,7 @@ private[spark] class TaskSetManager( val info = taskInfos(tid) // SPARK-37300: when the task was already finished state, just ignore it, // so that there won't cause successful and tasksSuccessful wrong result. - if(info.finished) { + if (info.finished) { if (dropTaskInfoAccumulablesOnTaskCompletion) { // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable // lifetime. 
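The dry-run `TaskSetExcludelist` above is what the scheduler builds when exclusion is enabled for the application but not for task sets. A configuration sketch matching that combination, using the keys added earlier in this patch:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Track failures and exclude executors/nodes for the whole application...
  .set("spark.excludeOnFailure.application.enabled", "true")
  // ...but only record (not act on) failures at the task/stage level.
  .set("spark.excludeOnFailure.taskAndStage.enabled", "false")
```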
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index f4caecd7d6741..eb408a95589f7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -186,6 +186,9 @@ private[spark] class StandaloneSchedulerBackend( val reason: ExecutorLossReason = exitStatus match { case Some(ExecutorExitCode.HEARTBEAT_FAILURE) => ExecutorExited(ExecutorExitCode.HEARTBEAT_FAILURE, exitCausedByApp = false, message) + case Some(ExecutorExitCode.BLOCK_MANAGER_REREGISTRATION_FAILED) => + ExecutorExited(ExecutorExitCode.BLOCK_MANAGER_REREGISTRATION_FAILED, + exitCausedByApp = false, message) case Some(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) => ExecutorExited(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR, exitCausedByApp = false, message) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index ec5d53e91b3e0..783da1fa4c286 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -38,6 +38,7 @@ import org.apache.avro.generic.{GenericContainer, GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.CLASS_NAME @@ -519,6 +520,7 @@ private[spark] class KryoSerializerInstance( * Interface implemented by clients to register their classes with Kryo when using Kryo * serialization. */ +@DeveloperApi trait KryoRegistrator { def registerClasses(kryo: Kryo): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 73e72b7f1dfc7..80e6ab7c0a663 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -238,7 +238,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false f: BlockInfo => Boolean): Option[BlockInfo] = { var done = false var result: Option[BlockInfo] = None - while(!done) { + while (!done) { val wrapper = blockInfoWrappers.get(blockId) if (wrapper == null) { done = true diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a6d62005e4e66..d99bc5bf30546 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -40,7 +40,7 @@ import org.apache.commons.io.IOUtils import org.apache.spark._ import org.apache.spark.errors.SparkCoreErrors -import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.{DataReadMethod, ExecutorExitCode} import org.apache.spark.internal.{config, Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.config.{Network, RDD_CACHE_VISIBILITY_TRACKING_ENABLED, Tests} @@ -463,7 +463,7 @@ private[spark] class BlockManager( * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. 
Doing * so may corrupt or change the data stored by the `BlockManager`. */ - private case class ByteBufferBlockStoreUpdater[T]( + private[spark] case class ByteBufferBlockStoreUpdater[T]( blockId: BlockId, level: StorageLevel, classTag: ClassTag[T], @@ -671,7 +671,7 @@ private[spark] class BlockManager( reportAllBlocks() } else { logError("Exiting executor due to block manager re-registration failure") - System.exit(-1) + System.exit(ExecutorExitCode.BLOCK_MANAGER_REREGISTRATION_FAILED) } } @@ -1516,7 +1516,7 @@ private[spark] class BlockManager( return true } - if(master.isRDDBlockVisible(blockId)) { + if (master.isRDDBlockVisible(blockId)) { // Cache the visibility status if block exists. blockInfoManager.tryMarkBlockAsVisible(blockId) true @@ -1882,7 +1882,7 @@ private[spark] class BlockManager( blockId, numPeersToReplicateTo) - while(numFailures <= maxReplicationFailureCount && + while (numFailures <= maxReplicationFailureCount && peersForReplication.nonEmpty && peersReplicatedTo.size < numPeersToReplicateTo) { val peer = peersForReplication.head @@ -2126,8 +2126,10 @@ private[spark] class BlockManager( hasRemoveBlock = true if (tellMaster) { // Only update storage level from the captured block status before deleting, so that - // memory size and disk size are being kept for calculating delta. - reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = StorageLevel.NONE)) + // memory size and disk size are being kept for calculating delta. Reset the replica + // count 0 in storage level to notify that it is a remove operation. + val storageLevel = StorageLevel(blockStatus.get.storageLevel.toInt, 0) + reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = storageLevel)) } } finally { if (!hasRemoveBlock) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 73f89ea0e86e5..fc4e6e771aad7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -1059,13 +1059,13 @@ private[spark] class BlockManagerInfo( _blocks.put(blockId, blockStatus) _remainingMem -= memSize if (blockExists) { - logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + + logDebug(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (current size: " + log"${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, original " + log"size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } else { - logInfo(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + + logDebug(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + log"(size: ${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") @@ -1075,12 +1075,12 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) if (blockExists) { - logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + + logDebug(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + log"(current size: ${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))}," + log" original size: ${MDC(ORIGINAL_DISK_SIZE, 
Utils.bytesToString(originalDiskSize))})") } else { - logInfo(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + + logDebug(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (size: " + log"${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))})") } @@ -1098,13 +1098,13 @@ private[spark] class BlockManagerInfo( blockStatus.remove(blockId) } if (originalLevel.useMemory) { - logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + logDebug(log"Removed ${MDC(BLOCK_ID, blockId)} on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} in memory " + log"(size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } if (originalLevel.useDisk) { - logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + logDebug(log"Removed ${MDC(BLOCK_ID, blockId)} on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} on disk" + log" (size: ${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index 686ac1eb786e0..f29e8778da037 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -60,7 +60,13 @@ class BlockManagerStorageEndpoint( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } - SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) + val shuffleManager = SparkEnv.get.shuffleManager + if (shuffleManager != null) { + shuffleManager.unregisterShuffle(shuffleId) + } else { + logDebug(log"Ignore remove shuffle ${MDC(SHUFFLE_ID, shuffleId)}") + true + } } case DecommissionBlockManager => diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1498b224b0c92..3e57094b36a7e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer @@ -324,7 +325,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: private var _transferred = 0L - private val buffer = ByteBuffer.allocateDirect(64 * 1024) + private val buffer = Platform.allocateDirectBuffer(64 * 1024) buffer.flip() override def count(): Long = blockSize diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ff1799d8ff3e1..57f6901a7a735 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1156,6 +1156,12 @@ final class ShuffleBlockFetcherIterator( log"checksum support for push-based shuffle.") s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." 
+ case shuffleBlockBatch: ShuffleBlockBatchId => + logWarning(log"BlockBatch ${MDC(SHUFFLE_BLOCK_INFO, shuffleBlockBatch)} is corrupted " + + log"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + log"ShuffleBlockBatchId") + s"BlockBatch $shuffleBlockBatch is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for ShuffleBlockBatchId" case unexpected: BlockId => throw SparkException.internalError( s"Unexpected type of BlockId, $unexpected", category = "STORAGE") diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 7a2b7d9caec42..fc7a4675429aa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -35,7 +35,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { // Carriage return private val CR = '\r' // Update period of progress bar, in milliseconds - private val updatePeriodMSec = sc.getConf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL) + private val updatePeriodMSec = sc.conf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL) // Delay to show up a progress bar, in milliseconds private val firstDelayMSec = 500L diff --git a/core/src/main/scala/org/apache/spark/ui/JWSFilter.scala b/core/src/main/scala/org/apache/spark/ui/JWSFilter.scala new file mode 100644 index 0000000000000..4031aa6ff192f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/JWSFilter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.crypto.SecretKey + +import io.jsonwebtoken.{JwtException, Jwts} +import io.jsonwebtoken.io.Decoders +import io.jsonwebtoken.security.Keys +import jakarta.servlet.{Filter, FilterChain, FilterConfig, ServletRequest, ServletResponse} +import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse} + +/** + * A servlet filter that requires JWS, a cryptographically signed JSON Web Token, in the header. + * + * Like the other UI filters, the following configurations are required to use this filter. + * {{{ + * - spark.ui.filters=org.apache.spark.ui.JWSFilter + * - spark.org.apache.spark.ui.JWSFilter.param.secretKey=BASE64URL-ENCODED-YOUR-PROVIDED-KEY + * }}} + * The HTTP request should have {@code Authorization: Bearer } header. + * {{{ + * - is a string with three fields, '
<header>.<payload>.<signature>'. + - <header> is supposed to be a base64url-encoded string of '{"alg":"HS256","typ":"JWT"}'. + - <payload> is a base64url-encoded string of fully-user-defined content. + - <signature> is a signature based on '<header>.<payload>
.' and a user-provided key parameter. + * }}} + */ +private class JWSFilter extends Filter { + private val AUTHORIZATION = "Authorization" + + private var key: SecretKey = null + + /** + * Load and validate the configurtions: + * - IllegalArgumentException will happen if the user didn't provide this argument + * - WeakKeyException will happen if the user-provided value is insufficient + */ + override def init(config: FilterConfig): Unit = { + key = Keys.hmacShaKeyFor(Decoders.BASE64URL.decode(config.getInitParameter("secretKey"))); + } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val hres = res.asInstanceOf[HttpServletResponse] + hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + + try { + val header = hreq.getHeader(AUTHORIZATION) + header match { + case null => + hres.sendError(HttpServletResponse.SC_FORBIDDEN, s"${AUTHORIZATION} header is missing.") + case s"Bearer $token" => + val claims = Jwts.parser().verifyWith(key).build().parseSignedClaims(token) + chain.doFilter(req, res) + case _ => + hres.sendError(HttpServletResponse.SC_FORBIDDEN, s"Malformed ${AUTHORIZATION} header.") + } + } catch { + case e: JwtException => + // We intentionally don't expose the detail of JwtException here + hres.sendError(HttpServletResponse.SC_FORBIDDEN, "JWT Validate Fail") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index f503be908c072..068f9b0b2cbcb 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -150,7 +150,10 @@ private[spark] object JettyUtils extends Logging { private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { beforeRedirect(request) // Make sure we don't end up with "//" in the middle - val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString + val requestURL = new URI(request.getRequestURL.toString).toURL + // scalastyle:off URLConstructor + val newUrl = new URL(requestURL, prefixedDestPath).toString + // scalastyle:on URLConstructor response.sendRedirect(newUrl) } // SPARK-5983 ensure TRACE is not supported diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c7e375849278f..fff6ec4f5b170 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -236,14 +236,17 @@ private[spark] object UIUtils extends Logging { def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + - + - + } @@ -446,16 +449,24 @@ private[spark] object UIUtils extends Logging { val startRatio = if (total == 0) 0.0 else (boundedStarted.toDouble / total) * 100 val startWidth = "width: %s%%".format(startRatio) + val killTaskReasonText = reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s" ($count killed: $reason)" + }.mkString + val progressTitle = s"$completed/$total" + { + if (started > 0) s" ($started running)" else "" + } + { + if (failed > 0) s" ($failed failed)" else "" + } + { + if (skipped > 0) s" ($skipped skipped)" else "" + } + killTaskReasonText +
- + {completed}/{total} { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { reasonToNumKilled.toSeq.sortBy(-_._2).map { - case (reason, count) => s"($count killed: $reason)" - } - } + { killTaskReasonText }
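As an illustration of how the JWSFilter introduced above would be configured and exercised (this sketch is not part of the patch), the following hypothetical Scala snippet generates an HS256 secret key and a matching signed token, assuming the jjwt 0.12.x API that the filter itself uses; the object name and token subject are invented for the example.

import io.jsonwebtoken.Jwts
import io.jsonwebtoken.io.Encoders

object JWSFilterSetupSketch {
  def main(args: Array[String]): Unit = {
    // Generate a random HS256 key; its base64url form is the value that would be set as
    // spark.org.apache.spark.ui.JWSFilter.param.secretKey (hypothetical setup, not Spark code).
    val key = Jwts.SIG.HS256.key().build()
    println("secretKey=" + Encoders.BASE64URL.encode(key.getEncoded))

    // Sign a token with the same key; a client would send it to the Spark UI as
    // "Authorization: Bearer <token>".
    val token = Jwts.builder().subject("spark-ui-user").signWith(key).compact()
    println("Authorization: Bearer " + token)
  }
}

The printed secretKey value corresponds to the filter's secretKey init parameter, and the printed header is the form the filter's doFilter method expects before it calls parseSignedClaims.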
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index ce7698a59ae91..a3c2d05414a88 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -52,7 +52,7 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) Option(request.getParameter("id")).map(_.toInt).foreach { id => store.asOption(store.job(id)).foreach { job => if (job.status == JobExecutionStatus.RUNNING) { - sc.foreach(_.cancelJob(id)) + sc.foreach(_.cancelJob(id, "killed via Web UI")) // Do a quick pause here to give Spark time to kill the job so it shows up as // killed after the refresh. Note that this will block the serving thread so the // time should be limited in duration. diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskThreadDumpPage.scala index 49b919ce0de97..5fe542d4fa21e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskThreadDumpPage.scala @@ -71,7 +71,7 @@ private[spark] class TaskThreadDumpPage(
-

Updated at{UIUtils.formatDate(time)}

+

Updated at {UIUtils.formatDate(time)}

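The JobsTab change above switches the kill link from cancelJob(id) to cancelJob(id, "killed via Web UI") so the UI records why a job was cancelled. A minimal hypothetical sketch of the same call made outside the UI (not part of the patch; object and method names are invented):

import org.apache.spark.SparkContext

object KillJobSketch {
  // Cancels a running job the same way JobsTab now does when the kill link is clicked.
  def killLikeTheWebUI(sc: SparkContext, jobId: Int): Unit = {
    sc.cancelJob(jobId, "killed via Web UI")
  }
}

The reason string is propagated into the job's failure message, which is also what the new JobCancellationSuite test further down checks for FutureAction cancellation.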
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 3237a321f1c3e..383a89d40ecee 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -102,16 +102,24 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { metadata.countFailedValues } + private def isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) + /** * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] with the provided * values. */ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { - val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) AccumulableInfo(id, name, internOption(update), internOption(value), isInternal, countFailedValues) } + /** + * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] as an update. + */ + private[spark] def toInfoUpdate: AccumulableInfo = { + AccumulableInfo(id, name, internOption(Some(value)), None, isInternal, countFailedValues) + } + final private[spark] def isAtDriverSide: Boolean = atDriverSide /** diff --git a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala index a4145bb36acc9..fd10d60a13fd1 100644 --- a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import java.io.OutputStream import java.nio.ByteBuffer +import org.apache.spark.SparkException import org.apache.spark.storage.StorageUtils import org.apache.spark.unsafe.Platform @@ -29,16 +30,18 @@ import org.apache.spark.unsafe.Platform * @param capacity The initial capacity of the direct byte buffer */ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputStream { - private var buffer = Platform.allocateDirectBuffer(capacity) + private[this] var buffer = Platform.allocateDirectBuffer(capacity) def this() = this(32) override def write(b: Int): Unit = { + checkNotClosed() ensureCapacity(buffer.position() + 1) buffer.put(b.toByte) } override def write(b: Array[Byte], off: Int, len: Int): Unit = { + checkNotClosed() ensureCapacity(buffer.position() + len) buffer.put(b, off, len) } @@ -57,21 +60,35 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS if (newCapacity < minCapacity) newCapacity = minCapacity val oldBuffer = buffer oldBuffer.flip() - val newBuffer = ByteBuffer.allocateDirect(newCapacity) + val newBuffer = Platform.allocateDirectBuffer(newCapacity) newBuffer.put(oldBuffer) StorageUtils.dispose(oldBuffer) buffer = newBuffer } - def reset(): Unit = buffer.clear() + private def checkNotClosed(): Unit = { + if (buffer == null) { + throw SparkException.internalError( + "Cannot call methods on a closed DirectByteBufferOutputStream") + } + } + + def reset(): Unit = { + checkNotClosed() + buffer.clear() + } - def size(): Int = buffer.position() + def size(): Int = { + checkNotClosed() + buffer.position() + } /** * Any subsequent call to [[close()]], [[write()]], [[reset()]] will invalidate the buffer * returned by this method. 
*/ def toByteBuffer: ByteBuffer = { + checkNotClosed() val outputBuffer = buffer.duplicate() outputBuffer.flip() outputBuffer @@ -80,6 +97,7 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS override def close(): Unit = { // Eagerly free the direct byte buffer without waiting for GC to reduce memory pressure. StorageUtils.dispose(buffer) + buffer = null } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 22dcf6c11e4b4..e30380f41566a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.JsonMethods.compact import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.internal.config._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.{DeterministicLevel, RDDOperationScope} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, ResourceProfile, TaskResourceRequest} @@ -37,6 +38,16 @@ import org.apache.spark.storage._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils.weakIntern +/** + * Helper class for passing configuration options to JsonProtocol. + * We use this instead of passing SparkConf directly because it lets us avoid + * repeated re-parsing of configuration values on each read. + */ +private[spark] class JsonProtocolOptions(conf: SparkConf) { + val includeTaskMetricsAccumulators: Boolean = + conf.get(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS) +} + /** * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- * and forwards-compatibility guarantees: any version of Spark should be able to read JSON output @@ -55,30 +66,41 @@ import org.apache.spark.util.Utils.weakIntern private[spark] object JsonProtocol extends JsonUtils { // TODO: Remove this file and put JSON serialization into each individual class. + private[util] + val defaultOptions: JsonProtocolOptions = new JsonProtocolOptions(new SparkConf(false)) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ + // Only for use in tests. Production code should use the two-argument overload defined below. 
def sparkEventToJsonString(event: SparkListenerEvent): String = { + sparkEventToJsonString(event, defaultOptions) + } + + def sparkEventToJsonString(event: SparkListenerEvent, options: JsonProtocolOptions): String = { toJsonString { generator => - writeSparkEventToJson(event, generator) + writeSparkEventToJson(event, generator, options) } } - def writeSparkEventToJson(event: SparkListenerEvent, g: JsonGenerator): Unit = { + def writeSparkEventToJson( + event: SparkListenerEvent, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => - stageSubmittedToJson(stageSubmitted, g) + stageSubmittedToJson(stageSubmitted, g, options) case stageCompleted: SparkListenerStageCompleted => - stageCompletedToJson(stageCompleted, g) + stageCompletedToJson(stageCompleted, g, options) case taskStart: SparkListenerTaskStart => - taskStartToJson(taskStart, g) + taskStartToJson(taskStart, g, options) case taskGettingResult: SparkListenerTaskGettingResult => - taskGettingResultToJson(taskGettingResult, g) + taskGettingResultToJson(taskGettingResult, g, options) case taskEnd: SparkListenerTaskEnd => - taskEndToJson(taskEnd, g) + taskEndToJson(taskEnd, g, options) case jobStart: SparkListenerJobStart => - jobStartToJson(jobStart, g) + jobStartToJson(jobStart, g, options) case jobEnd: SparkListenerJobEnd => jobEndToJson(jobEnd, g) case environmentUpdate: SparkListenerEnvironmentUpdate => @@ -112,12 +134,15 @@ private[spark] object JsonProtocol extends JsonUtils { } } - def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted, g: JsonGenerator): Unit = { + def stageSubmittedToJson( + stageSubmitted: SparkListenerStageSubmitted, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) g.writeFieldName("Stage Info") // SPARK-42205: don't log accumulables in start events: - stageInfoToJson(stageSubmitted.stageInfo, g, includeAccumulables = false) + stageInfoToJson(stageSubmitted.stageInfo, g, options, includeAccumulables = false) Option(stageSubmitted.properties).foreach { properties => g.writeFieldName("Properties") propertiesToJson(properties, g) @@ -125,38 +150,48 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted, g: JsonGenerator): Unit = { + def stageCompletedToJson( + stageCompleted: SparkListenerStageCompleted, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) g.writeFieldName("Stage Info") - stageInfoToJson(stageCompleted.stageInfo, g, includeAccumulables = true) + stageInfoToJson(stageCompleted.stageInfo, g, options, includeAccumulables = true) g.writeEndObject() } - def taskStartToJson(taskStart: SparkListenerTaskStart, g: JsonGenerator): Unit = { + def taskStartToJson( + taskStart: SparkListenerTaskStart, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) g.writeNumberField("Stage ID", taskStart.stageId) g.writeNumberField("Stage Attempt ID", taskStart.stageAttemptId) g.writeFieldName("Task Info") // SPARK-42205: don't log accumulables in start events: - taskInfoToJson(taskStart.taskInfo, g, includeAccumulables = false) + taskInfoToJson(taskStart.taskInfo, g, options, 
includeAccumulables = false) g.writeEndObject() } def taskGettingResultToJson( taskGettingResult: SparkListenerTaskGettingResult, - g: JsonGenerator): Unit = { + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { val taskInfo = taskGettingResult.taskInfo g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) g.writeFieldName("Task Info") // SPARK-42205: don't log accumulables in "task getting result" events: - taskInfoToJson(taskInfo, g, includeAccumulables = false) + taskInfoToJson(taskInfo, g, options, includeAccumulables = false) g.writeEndObject() } - def taskEndToJson(taskEnd: SparkListenerTaskEnd, g: JsonGenerator): Unit = { + def taskEndToJson( + taskEnd: SparkListenerTaskEnd, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) g.writeNumberField("Stage ID", taskEnd.stageId) @@ -165,7 +200,7 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeFieldName("Task End Reason") taskEndReasonToJson(taskEnd.reason, g) g.writeFieldName("Task Info") - taskInfoToJson(taskEnd.taskInfo, g, includeAccumulables = true) + taskInfoToJson(taskEnd.taskInfo, g, options, includeAccumulables = true) g.writeFieldName("Task Executor Metrics") executorMetricsToJson(taskEnd.taskExecutorMetrics, g) Option(taskEnd.taskMetrics).foreach { m => @@ -175,7 +210,10 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - def jobStartToJson(jobStart: SparkListenerJobStart, g: JsonGenerator): Unit = { + def jobStartToJson( + jobStart: SparkListenerJobStart, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) g.writeNumberField("Job ID", jobStart.jobId) @@ -186,7 +224,7 @@ private[spark] object JsonProtocol extends JsonUtils { // the job was submitted: it is technically possible for a stage to belong to multiple // concurrent jobs, so this situation can arise even without races occurring between // event logging and stage completion. 
- jobStart.stageInfos.foreach(stageInfoToJson(_, g, includeAccumulables = true)) + jobStart.stageInfos.foreach(stageInfoToJson(_, g, options, includeAccumulables = true)) g.writeEndArray() g.writeArrayFieldStart("Stage IDs") jobStart.stageIds.foreach(g.writeNumber) @@ -386,6 +424,7 @@ private[spark] object JsonProtocol extends JsonUtils { def stageInfoToJson( stageInfo: StageInfo, g: JsonGenerator, + options: JsonProtocolOptions, includeAccumulables: Boolean): Unit = { g.writeStartObject() g.writeNumberField("Stage ID", stageInfo.stageId) @@ -404,7 +443,10 @@ private[spark] object JsonProtocol extends JsonUtils { stageInfo.failureReason.foreach(g.writeStringField("Failure Reason", _)) g.writeFieldName("Accumulables") if (includeAccumulables) { - accumulablesToJson(stageInfo.accumulables.values, g) + accumulablesToJson( + stageInfo.accumulables.values, + g, + includeTaskMetricsAccumulators = options.includeTaskMetricsAccumulators) } else { g.writeStartArray() g.writeEndArray() @@ -418,6 +460,7 @@ private[spark] object JsonProtocol extends JsonUtils { def taskInfoToJson( taskInfo: TaskInfo, g: JsonGenerator, + options: JsonProtocolOptions, includeAccumulables: Boolean): Unit = { g.writeStartObject() g.writeNumberField("Task ID", taskInfo.taskId) @@ -435,7 +478,10 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeBooleanField("Killed", taskInfo.killed) g.writeFieldName("Accumulables") if (includeAccumulables) { - accumulablesToJson(taskInfo.accumulables, g) + accumulablesToJson( + taskInfo.accumulables, + g, + includeTaskMetricsAccumulators = options.includeTaskMetricsAccumulators) } else { g.writeStartArray() g.writeEndArray() @@ -443,13 +489,23 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - private lazy val accumulableExcludeList = Set("internal.metrics.updatedBlockStatuses") + private[util] val accumulableExcludeList = Set(InternalAccumulator.UPDATED_BLOCK_STATUSES) + + private[this] val taskMetricAccumulableNames = TaskMetrics.empty.nameToAccums.keySet.toSet - def accumulablesToJson(accumulables: Iterable[AccumulableInfo], g: JsonGenerator): Unit = { + def accumulablesToJson( + accumulables: Iterable[AccumulableInfo], + g: JsonGenerator, + includeTaskMetricsAccumulators: Boolean = true): Unit = { g.writeStartArray() accumulables - .filterNot(_.name.exists(accumulableExcludeList.contains)) - .toList.sortBy(_.id).foreach(a => accumulableInfoToJson(a, g)) + .filterNot { acc => + acc.name.exists(accumulableExcludeList.contains) || + (!includeTaskMetricsAccumulators && acc.name.exists(taskMetricAccumulableNames.contains)) + } + .toList + .sortBy(_.id) + .foreach(a => accumulableInfoToJson(a, g)) g.writeEndArray() } @@ -597,6 +653,8 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeNumberField("Executor Run Time", taskMetrics.executorRunTime) g.writeNumberField("Executor CPU Time", taskMetrics.executorCpuTime) g.writeNumberField("Peak Execution Memory", taskMetrics.peakExecutionMemory) + g.writeNumberField("Peak On Heap Execution Memory", taskMetrics.peakOnHeapExecutionMemory) + g.writeNumberField("Peak Off Heap Execution Memory", taskMetrics.peakOffHeapExecutionMemory) g.writeNumberField("Result Size", taskMetrics.resultSize) g.writeNumberField("JVM GC Time", taskMetrics.jvmGCTime) g.writeNumberField("Result Serialization Time", taskMetrics.resultSerializationTime) @@ -1254,6 +1312,10 @@ private[spark] object JsonProtocol extends JsonUtils { // The "Peak Execution Memory" field was added in Spark 3.0.0: 
metrics.setPeakExecutionMemory( jsonOption(json.get("Peak Execution Memory")).map(_.extractLong).getOrElse(0)) + metrics.setPeakOnHeapExecutionMemory( + jsonOption(json.get("Peak On Heap Execution Memory")).map(_.extractLong).getOrElse(0)) + metrics.setPeakOffHeapExecutionMemory( + jsonOption(json.get("Peak Off Heap Execution Memory")).map(_.extractLong).getOrElse(0)) metrics.setResultSize(json.get("Result Size").extractLong) metrics.setJvmGCTime(json.get("JVM GC Time").extractLong) metrics.setResultSerializationTime(json.get("Result Serialization Time").extractLong) @@ -1379,7 +1441,7 @@ private[spark] object JsonProtocol extends JsonUtils { val accumUpdates = jsonOption(json.get("Accumulator Updates")) .map(_.extractElements.map(accumulableInfoFromJson).toArray.toImmutableArraySeq) .getOrElse(taskMetricsFromJson(json.get("Metrics")).accumulators().map(acc => { - acc.toInfo(Some(acc.value), None) + acc.toInfoUpdate }).toArray.toImmutableArraySeq) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost diff --git a/core/src/main/scala/org/apache/spark/util/LazyTry.scala b/core/src/main/scala/org/apache/spark/util/LazyTry.scala new file mode 100644 index 0000000000000..7edc08672c26b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/LazyTry.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Try + +/** + * Wrapper utility for a lazy val, with two differences compared to scala behavior: + * + * 1. Non-retrying in case of failure. This wrapper stores the exception in a Try, and will re-throw + * it on the access to `get`. + * In scala, when a `lazy val` field initialization throws an exception, the field remains + * uninitialized, and initialization will be re-attempted on the next access. This also can lead + * to performance issues, needlessly computing something towards a failure, and also can lead to + * duplicated side effects. + * + * 2. Resolving locking issues. + * In scala, when a `lazy val` field is initialized, it grabs the synchronized lock on the + * enclosing object instance. This can lead both to performance issues, and deadlocks. + * For example: + * a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object. + * b) Thread 2 get spawned off, and tries to initialize a lazy value on the same parent object + * This causes scala to also try to grab a lock on the parent object. + * c) If thread 1 waits for thread 2 to join, a deadlock occurs. + * This wrapper will only grab a lock on the wrapper itself, and not the parent object. + * + * @param initialize The block of code to initialize the lazy value. + * @tparam T type of the lazy value. 
+ */ +private[spark] class LazyTry[T](initialize: => T) extends Serializable { + private lazy val tryT: Try[T] = Utils.doTryWithCallerStacktrace { initialize } + + /** + * Get the lazy value. If the initialization block threw an exception, it will be re-thrown here. + * The exception will be re-thrown with the current caller's stacktrace. + * An exception with stack trace from when the exception was first thrown can be accessed with + * ``` + * ex.getSuppressed.find { e => + * e.getMessage == org.apache.spark.util.Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE + * } + * ``` + */ + def get: T = Utils.getTryWithCallerStacktrace(tryT) +} + +private[spark] object LazyTry { + /** + * Create a new LazyTry instance. + * + * @param initialize The block of code to initialize the lazy value. + * @tparam T type of the lazy value. + * @return a new LazyTry instance. + */ + def apply[T](initialize: => T): LazyTry[T] = new LazyTry(initialize) +} diff --git a/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala b/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala index 21184d70b386a..7d01facc1e421 100644 --- a/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala +++ b/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.util.concurrent.Callable +import java.util.concurrent.{Callable, TimeUnit} import com.google.common.cache.{Cache, CacheBuilder, CacheLoader, LoadingCache} @@ -68,6 +68,20 @@ private[spark] object NonFateSharingCache { override def load(k: K): V = loadingFunc.apply(k) })) } + + def apply[K, V]( + maximumSize: Long, + expireAfterAccessTime: Long, + expireAfterAccessTimeUnit: TimeUnit): NonFateSharingCache[K, V] = { + val builder = CacheBuilder.newBuilder().asInstanceOf[CacheBuilder[K, V]] + if (maximumSize > 0L) { + builder.maximumSize(maximumSize) + } + if(expireAfterAccessTime > 0) { + builder.expireAfterAccess(expireAfterAccessTime, expireAfterAccessTimeUnit) + } + new NonFateSharingCache(builder.build[K, V]()) + } } private[spark] class NonFateSharingCache[K, V](protected val cache: Cache[K, V]) { diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index c1ea4f929101f..25a6322743baa 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -17,9 +17,11 @@ package org.apache.spark.util +import org.apache.spark.SparkEnv import org.apache.spark.executor.{ExecutorExitCode, KilledByTaskReaperException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.THREAD +import org.apache.spark.internal.config.KILL_ON_FATAL_ERROR_DEPTH /** * The default uncaught exception handler for Spark daemons. It terminates the whole process for @@ -36,6 +38,17 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: val _ = SparkExitCode.OOM } + // The maximum depth to search in the exception cause chain for a fatal error, + // as defined by KILL_ON_FATAL_ERROR_DEPTH. + // + // SPARK-50034: When this handler is called, there is a fatal error in the cause chain within + // the specified depth. We should identify that fatal error and exit with the + // correct exit code. 
+ private val killOnFatalErrorDepth: Int = + // At this point SparkEnv might be None + Option(SparkEnv.get).map(_.conf.get(KILL_ON_FATAL_ERROR_DEPTH)).getOrElse(5) + + override def uncaughtException(thread: Thread, exception: Throwable): Unit = { try { val mdc = MDC(THREAD, thread) @@ -50,19 +63,31 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) if (!ShutdownHookManager.inShutdown()) { - exception match { - case _: OutOfMemoryError => - System.exit(SparkExitCode.OOM) - case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => - // SPARK-24294: This is defensive code, in case that SparkFatalException is - // misused and uncaught. - System.exit(SparkExitCode.OOM) - case _: KilledByTaskReaperException if exitOnUncaughtException => - System.exit(ExecutorExitCode.KILLED_BY_TASK_REAPER) - case _ if exitOnUncaughtException => - System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) - case _ => - // SPARK-30310: Don't System.exit() when exitOnUncaughtException is false + // Traverse the causes up to killOnFatalErrorDepth layers + var currentException: Throwable = exception + var depth = 0 + + while (currentException != null && depth < killOnFatalErrorDepth) { + currentException match { + case _: OutOfMemoryError => + System.exit(SparkExitCode.OOM) + case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => + // SPARK-24294: This is defensive code, in case that SparkFatalException is + // misused and uncaught. + System.exit(SparkExitCode.OOM) + case _: KilledByTaskReaperException if exitOnUncaughtException => + System.exit(ExecutorExitCode.KILLED_BY_TASK_REAPER) + // No match, continue traversing the cause chain + case _ => + } + // Move to the next cause in the chain + currentException = currentException.getCause + depth += 1 + } + + // SPARK-30310: Don't System.exit() when exitOnUncaughtException is false + if (exitOnUncaughtException) { + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/TransientLazy.scala b/core/src/main/scala/org/apache/spark/util/TransientLazy.scala new file mode 100644 index 0000000000000..2833ef93669a6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TransientLazy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * Construct to lazily initialize a variable. + * This may be helpful for avoiding deadlocks in certain scenarios. For example, + * a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object. 
+ * b) Thread 2 gets spawned off, and tries to initialize a lazy value on the same parent object + * (in our case, this was the logger). This causes scala to also try to grab a coarse lock on + * the parent object. + * c) If thread 1 waits for thread 2 to join, a deadlock occurs. + * The main difference between this and [[LazyTry]] is that this does not cache failures. + * + * @note + * Scala 3 uses a different implementation of lazy vals which doesn't have this problem. + * Please refer to Lazy + * Vals Initialization for more details. + */ +private[spark] class TransientLazy[T](initializer: => T) extends Serializable { + + @transient + private[this] lazy val value: T = initializer + + def apply(): T = { + value + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a37aedfcb635a..109db36d40695 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -638,7 +638,7 @@ private[spark] object Utils val is = Channels.newInputStream(source) downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => - val uc = new URL(url).openConnection() + val uc = new URI(url).toURL.openConnection() val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 uc.setConnectTimeout(timeoutMs) @@ -1351,6 +1351,86 @@ private[spark] object Utils } } + val TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE = + "Full stacktrace of original doTryWithCallerStacktrace caller" + + val TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE = + "Stacktrace under doTryWithCallerStacktrace" + + /** + * Use Try with stacktrace substitution for the caller retrieving the error. + * + * Normally in case of failure, the exception would have the stacktrace of the caller that + * originally called doTryWithCallerStacktrace. However, we want to replace the part above + * this function with the stacktrace of the caller who calls getTryWithCallerStacktrace. + * So here we save the part of the stacktrace below doTryWithCallerStacktrace, and + * getTryWithCallerStacktrace will stitch it with the new stack trace of the caller. + * The full original stack trace is kept in ex.getSuppressed. + * + * @param f Code block to be wrapped in Try + * @return Try with Success or Failure of the code block. Use with getTryWithCallerStacktrace. + */ + def doTryWithCallerStacktrace[T](f: => T): Try[T] = { + val t = Try { + f + } + t match { + case Failure(ex) => + // Note: we remove the common suffix instead of e.g. finding the call to this function, to + // account for recursive calls with multiple doTryWithCallerStacktrace on the stack trace. + val origStackTrace = ex.getStackTrace + val currentStackTrace = Thread.currentThread().getStackTrace + val commonSuffixLen = origStackTrace.reverse.zip(currentStackTrace.reverse).takeWhile { + case (exElem, currentElem) => exElem == currentElem + }.length + val belowEx = new Exception(TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE) + belowEx.setStackTrace(origStackTrace.dropRight(commonSuffixLen)) + ex.addSuppressed(belowEx) + + // keep the full original stack trace in a suppressed exception. + val fullEx = new Exception(TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE) + fullEx.setStackTrace(origStackTrace) + ex.addSuppressed(fullEx) + case Success(_) => // nothing + } + t + } + + /** + * Retrieve the result of Try that was created by doTryWithCallerStacktrace. 
+ * + * In case of failure, the resulting exception has a stack trace that combines the stack trace + * below the original doTryWithCallerStacktrace which triggered it, with the caller stack trace + * of the current caller of getTryWithCallerStacktrace. + * + * Full stack trace of the original doTryWithCallerStacktrace caller can be retrieved with + * ``` + * ex.getSuppressed.find { e => + * e.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE + * } + * ``` + * + * + * @param t Try from doTryWithCallerStacktrace + * @return Result of the Try or rethrows the failure exception with modified stacktrace. + */ + def getTryWithCallerStacktrace[T](t: Try[T]): T = t match { + case Failure(ex) => + val belowStacktrace = ex.getSuppressed.find { e => + // added in doTryWithCallerStacktrace + e.getMessage == TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE + }.getOrElse { + // If we don't have the expected stacktrace information, just rethrow + throw ex + }.getStackTrace + // We are modifying and throwing the original exception. It would be better if we could + // return a copy, but we can't easily clone it and preserve. If this is accessed from + // multiple threads that then look at the stack trace, this could break. + ex.setStackTrace(belowStacktrace ++ Thread.currentThread().getStackTrace.drop(1)) + throw ex + case Success(s) => s + } + // A regular expression to match classes of the internal Spark API's // that we want to skip when finding the call site of a method. private val SPARK_CORE_CLASS_REGEX = @@ -1753,12 +1833,6 @@ private[spark] object Utils Files.createSymbolicLink(dst.toPath, src.toPath) } - - /** Return the class name of the given object, removing all dollar signs */ - def getFormattedClassName(obj: AnyRef): String = { - getSimpleName(obj.getClass).replace("$", "") - } - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -2586,6 +2660,32 @@ private[spark] object Utils SignalUtils.registerLogger(log) } + /** + * Utility function to enable or disable structured logging based on system properties. + * This is designed for a code path which we cannot use SparkConf yet, and should be used before + * the first invocation of `Logging.log()`. For example, this should be used before `initDaemon`. + */ + def resetStructuredLogging(): Unit = { + if (System.getProperty(STRUCTURED_LOGGING_ENABLED.key, "false").equals("false")) { + Logging.disableStructuredLogging() + } else { + Logging.enableStructuredLogging() + } + } + + /** + * Utility function to enable or disable structured logging based on SparkConf. + * This is designed for a code path which logging system may be initilized before + * loading SparkConf. + */ + def resetStructuredLogging(sparkConf: SparkConf): Unit = { + if (sparkConf.getBoolean(STRUCTURED_LOGGING_ENABLED.key, defaultValue = true)) { + Logging.enableStructuredLogging() + } else { + Logging.disableStructuredLogging() + } + } + /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN @@ -2814,68 +2914,6 @@ private[spark] object Utils Hex.encodeHexString(secretBytes) } - /** - * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. - * This method mimics scalatest's getSimpleNameOfAnObjectsClass. 
- */ - def getSimpleName(cls: Class[_]): String = { - try { - cls.getSimpleName - } catch { - // TODO: the value returned here isn't even quite right; it returns simple names - // like UtilsSuite$MalformedClassObject$MalformedClass instead of MalformedClass - // The exact value may not matter much as it's used in log statements - case _: InternalError => - stripDollars(stripPackages(cls.getName)) - } - } - - /** - * Remove the packages from full qualified class name - */ - private def stripPackages(fullyQualifiedName: String): String = { - fullyQualifiedName.split("\\.").takeRight(1)(0) - } - - /** - * Remove trailing dollar signs from qualified class name, - * and return the trailing part after the last dollar sign in the middle - */ - @scala.annotation.tailrec - def stripDollars(s: String): String = { - val lastDollarIndex = s.lastIndexOf('$') - if (lastDollarIndex < s.length - 1) { - // The last char is not a dollar sign - if (lastDollarIndex == -1 || !s.contains("$iw")) { - // The name does not have dollar sign or is not an interpreter - // generated class, so we should return the full string - s - } else { - // The class name is interpreter generated, - // return the part after the last dollar sign - // This is the same behavior as getClass.getSimpleName - s.substring(lastDollarIndex + 1) - } - } - else { - // The last char is a dollar sign - // Find last non-dollar char - val lastNonDollarChar = s.findLast(_ != '$') - lastNonDollarChar match { - case None => s - case Some(c) => - val lastNonDollarIndex = s.lastIndexOf(c) - if (lastNonDollarIndex == -1) { - s - } else { - // Strip the trailing dollar signs - // Invoke stripDollars again to get the simple name - stripDollars(s.substring(0, lastNonDollarIndex + 1)) - } - } - } - } - /** * Regular expression matching full width characters. * @@ -2944,9 +2982,7 @@ private[spark] object Utils if (props == null) { return props } - val resultProps = new Properties() - props.forEach((k, v) => resultProps.put(k, v)) - resultProps + props.clone().asInstanceOf[Properties] } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 46e311d8b0476..ec43666898fa7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -208,7 +208,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** * Re-hash a value to deal better with hash functions that don't differ in the lower bits. 
*/ - private def rehash(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() + private def rehash(h: Int): Int = Hashing.murmur3_32_fixed().hashInt(h).asInt() /** Double the table's size and re-hash everything */ protected def growTable(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 6bb5058f5ed14..7245d87a8baba 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -45,7 +45,7 @@ class BitSet(numBits: Int) extends Serializable { def setUntil(bitIndex: Int): Unit = { val wordIndex = bitIndex >> 6 // divide by 64 Arrays.fill(words, 0, wordIndex, -1) - if(wordIndex < words.length) { + if (wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask @@ -58,7 +58,7 @@ class BitSet(numBits: Int) extends Serializable { def clearUntil(bitIndex: Int): Unit = { val wordIndex = bitIndex >> 6 // divide by 64 Arrays.fill(words, 0, wordIndex, 0) - if(wordIndex < words.length) { + if (wordIndex < words.length) { // Clear the remaining bits val mask = -1L << (bitIndex & 0x3f) words(wordIndex) &= mask @@ -75,7 +75,7 @@ class BitSet(numBits: Int) extends Serializable { assert(newBS.numWords >= numWords) assert(newBS.numWords >= other.numWords) var ind = 0 - while( ind < smaller ) { + while (ind < smaller) { newBS.words(ind) = words(ind) & other.words(ind) ind += 1 } @@ -92,15 +92,15 @@ class BitSet(numBits: Int) extends Serializable { assert(newBS.numWords >= other.numWords) val smaller = math.min(numWords, other.numWords) var ind = 0 - while( ind < smaller ) { + while (ind < smaller) { newBS.words(ind) = words(ind) | other.words(ind) ind += 1 } - while( ind < numWords ) { + while (ind < numWords) { newBS.words(ind) = words(ind) ind += 1 } - while( ind < other.numWords ) { + while (ind < other.numWords) { newBS.words(ind) = other.words(ind) ind += 1 } @@ -242,7 +242,7 @@ class BitSet(numBits: Int) extends Serializable { def union(other: BitSet): Unit = { require(this.numWords <= other.numWords) var ind = 0 - while( ind < this.numWords ) { + while (ind < this.numWords) { this.words(ind) = this.words(ind) | other.words(ind) ind += 1 } diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index a42fa9ba6bc85..3d1eb5788c707 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -266,7 +266,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( /** * Re-hash a value to deal better with hash functions that don't differ in the lower bits. 
*/ - private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() + private def hashcode(h: Int): Int = Hashing.murmur3_32_fixed().hashInt(h).asInt() private def nextPowerOf2(n: Int): Int = { if (n == 0) { diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 11bd2b2a3312c..802cb2667cc88 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -960,7 +960,7 @@ public void textFiles() throws IOException { rdd.saveAsTextFile(outputDir); // Read the plain text file and check it's OK File outputFile = new File(outputDir, "part-00000"); - String content = Files.toString(outputFile, StandardCharsets.UTF_8); + String content = Files.asCharSource(outputFile, StandardCharsets.UTF_8).read(); assertEquals("1\n2\n3\n4\n", content); // Also try reading it in as a text file RDD List expected = Arrays.asList("1", "2", "3", "4"); diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 9b70ccdf07e1b..a9d7e8a0f2eda 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -147,7 +147,7 @@ private[spark] object AccumulatorSuite { * Make an `AccumulableInfo` out of an `AccumulatorV2` with the intent to use the * info as an accumulator update. */ - def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) + def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfoUpdate /** * Run one or more Spark jobs and verify that in at least one job the peak execution memory diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5651dc9b2dbdc..5f9912cbd021d 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -334,8 +334,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until 8) { val tempFile = new File(tempDir, s"part-0000$i") - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, - StandardCharsets.UTF_8) + Files.asCharSink(tempFile, StandardCharsets.UTF_8) + .write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1") } for (p <- Seq(1, 2, 8)) { diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 934517a5b5606..ca51e61f5ed44 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -25,6 +25,7 @@ import scala.concurrent.{ExecutionContext, Future} // scalastyle:off executioncontextglobal import scala.concurrent.ExecutionContext.Implicits.global // scalastyle:on executioncontextglobal +import scala.concurrent.Promise import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter @@ -33,7 +34,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Deploy._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.scheduler.{JobFailed, SparkListener, 
SparkListenerExecutorRemoved, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -207,7 +208,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => checkError( exception = e, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "0", @@ -221,7 +222,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.setJobGroup(jobGroupName, "") sc.parallelize(1 to 100).count() }, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "1", @@ -257,7 +258,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => checkError( exception = e, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "0", @@ -287,7 +288,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire(1) sc.cancelJobGroupAndFutureJobs(s"job-group-$idx") ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => - assert(e.getErrorClass == "SPARK_JOB_CANCELLED") + assert(e.getCondition == "SPARK_JOB_CANCELLED") } } // submit a job with the 0 job group that was evicted from cancelledJobGroups set, it should run @@ -613,6 +614,40 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft f2.get() } + test("cancel FutureAction with custom reason") { + + val cancellationPromise = Promise[Unit]() + + // listener to capture job end events and their reasons + var failureReason: Option[String] = None + + sc = new SparkContext("local[2]", "test") + sc.addSparkListener(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobEnd.jobResult match { + case jobFailed: JobFailed => + failureReason = Some(jobFailed.exception.getMessage) + case _ => // do nothing + } + } + }) + + val rdd = sc.parallelize(1 to 100, 2).map(_ * 2) + val asyncAction = rdd.collectAsync() + val reason = "custom cancel reason" + + Future { + asyncAction.cancel(Option(reason)) + cancellationPromise.success(()) + } + + // wait for the cancellation to complete and check the reason + cancellationPromise.future.map { _ => + Thread.sleep(1000) + assert(failureReason.contains(reason)) + } + } + test("interruptible iterator of shuffle reader") { // In this test case, we create a Spark job of two stages. 
The second stage is cancelled during // execution and a counter is used to make sure that the corresponding tasks are indeed diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 7106a780b3256..22c6280198c9a 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -27,7 +27,10 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel def sc: SparkContext = _sc - val conf = new SparkConf(false) + // SPARK-49647: use `SparkConf()` instead of `SparkConf(false)` because we want to + // load defaults from system properties and the classpath, including default test + // settings specified in the SBT and Maven build definitions. + val conf: SparkConf = new SparkConf() /** * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 12f9d2f83c777..dd42549e46d93 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -119,8 +119,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val absolutePath2 = file2.getAbsolutePath try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords2", file2, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("somewords1") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("somewords2") val length1 = file1.length() val length2 = file2.length() @@ -178,10 +178,10 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu s"${jarFile.getParent}/../${jarFile.getParentFile.getName}/${jarFile.getName}#zoo" try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords22", file2, StandardCharsets.UTF_8) - Files.write("somewords333", file3, StandardCharsets.UTF_8) - Files.write("somewords4444", file4, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("somewords1") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("somewords22") + Files.asCharSink(file3, StandardCharsets.UTF_8).write("somewords333") + Files.asCharSink(file4, StandardCharsets.UTF_8).write("somewords4444") val length1 = file1.length() val length2 = file2.length() val length3 = file1.length() @@ -373,8 +373,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(subdir2.mkdir()) val file1 = new File(subdir1, "file") val file2 = new File(subdir2, "file") - Files.write("old", file1, StandardCharsets.UTF_8) - Files.write("new", file2, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("old") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("new") sc = new SparkContext("local-cluster[1,1,1024]", "test") sc.addFile(file1.getAbsolutePath) def getAddedFileContents(): String = { @@ -503,12 +503,15 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu try { // Create 5 text files. 
- Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, - StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) - Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) - Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8) + .write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1") + Files.asCharSink(file2, StandardCharsets.UTF_8) + .write("someline1 in file2\nsomeline2 in file2") + Files.asCharSink(file3, StandardCharsets.UTF_8).write("someline1 in file3") + Files.asCharSink(file4, StandardCharsets.UTF_8) + .write("someline1 in file4\nsomeline2 in file4") + Files.asCharSink(file5, StandardCharsets.UTF_8) + .write("someline1 in file2\nsomeline2 in file5") sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) @@ -1420,6 +1423,43 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(conf) sc.stop() } + + test("SPARK-50247: BLOCK_MANAGER_REREGISTRATION_FAILED should be counted as network failure") { + // This test case follows the test structure of HEARTBEAT_FAILURE error code (SPARK-39957) + val conf = new SparkConf().set(TASK_MAX_FAILURES, 1) + val sc = new SparkContext("local-cluster[1, 1, 1024]", "test-exit-code", conf) + val result = sc.parallelize(1 to 10, 1).map { x => + val context = org.apache.spark.TaskContext.get() + if (context.taskAttemptId() == 0) { + System.exit(ExecutorExitCode.BLOCK_MANAGER_REREGISTRATION_FAILED) + } else { + x + } + }.count() + assert(result == 10L) + sc.stop() + } + + test("SPARK-50247: BLOCK_MANAGER_REREGISTRATION_FAILED will be counted as task failure when " + + "EXECUTOR_REMOVE_DELAY is disabled") { + // This test case follows the test structure of HEARTBEAT_FAILURE error code (SPARK-39957) + val conf = new SparkConf().set(TASK_MAX_FAILURES, 1).set(EXECUTOR_REMOVE_DELAY.key, "0s") + val sc = new SparkContext("local-cluster[1, 1, 1024]", "test-exit-code", conf) + eventually(timeout(30.seconds), interval(1.seconds)) { + val e = intercept[SparkException] { + sc.parallelize(1 to 10, 1).map { x => + val context = org.apache.spark.TaskContext.get() + if (context.taskAttemptId() == 0) { + System.exit(ExecutorExitCode.BLOCK_MANAGER_REREGISTRATION_FAILED) + } else { + x + } + }.count() + } + assert(e.getMessage.contains("Remote RPC client disassociated")) + } + sc.stop() + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 518c0592488fc..e38efc27b78f9 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -327,9 +327,9 @@ abstract class SparkFunSuite } /** - * Checks an exception with an error class against expected results. + * Checks an exception with an error condition against expected results. * @param exception The exception to check - * @param errorClass The expected error class identifying the error + * @param condition The expected error condition identifying the error * @param sqlState Optional the expected SQLSTATE, not verified if not supplied * @param parameters A map of parameter names and values. The names are as defined * in the error-classes file. 
@@ -338,12 +338,12 @@ abstract class SparkFunSuite */ protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getErrorClass === errorClass) + assert(exception.getCondition === condition) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala if (matchPVals) { @@ -380,63 +380,65 @@ abstract class SparkFunSuite } else if (actual.contextType() == QueryContextType.DataFrame) { assert(actual.fragment() === expected.fragment, "Invalid code fragment of a query context. Actual:" + actual.toString) - assert(actual.callSite().matches(expected.callSitePattern), - "Invalid callSite of a query context. Actual:" + actual.toString) + if (expected.callSitePattern.nonEmpty) { + assert(actual.callSite().matches(expected.callSitePattern), + "Invalid callSite of a query context. Actual:" + actual.toString) + } } } } protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, parameters: Map[String, String]): Unit = - checkError(exception, errorClass, Some(sqlState), parameters) + checkError(exception, condition, Some(sqlState), parameters) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) + checkError(exception, condition, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, None, parameters, false, Array(context)) + checkError(exception, condition, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, context: ExpectedContext): Unit = - checkError(exception, errorClass, None, Map.empty, false, Array(context)) + checkError(exception, condition, Some(sqlState), Map.empty, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String], parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, sqlState, parameters, + checkError(exception, condition, sqlState, parameters, false, Array(context)) protected def checkErrorMatchPVals( exception: SparkThrowable, - errorClass: String, + condition: String, parameters: Map[String, String]): Unit = - checkError(exception, errorClass, None, parameters, matchPVals = true) + checkError(exception, condition, None, parameters, matchPVals = true) protected def checkErrorMatchPVals( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String], parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, sqlState, parameters, + checkError(exception, condition, sqlState, parameters, matchPVals = true, Array(context)) protected def checkErrorTableNotFound( @@ -444,7 +446,7 @@ abstract class SparkFunSuite tableName: String, queryContext: ExpectedContext): Unit = checkError(exception = exception, - errorClass = 
"TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> tableName), queryContext = Array(queryContext)) @@ -452,13 +454,13 @@ abstract class SparkFunSuite exception: SparkThrowable, tableName: String): Unit = checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> tableName)) protected def checkErrorTableAlreadyExists(exception: SparkThrowable, tableName: String): Unit = checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> tableName)) case class ExpectedContext( diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 231cfdc3f32fc..ea845c0f93a4b 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -43,16 +43,13 @@ class SparkThrowableSuite extends SparkFunSuite { /* Used to regenerate the error class file. Run: {{{ SPARK_GENERATE_GOLDEN_FILES=1 build/sbt \ - "core/testOnly *SparkThrowableSuite -- -t \"Error classes are correctly formatted\"" + "core/testOnly *SparkThrowableSuite -- -t \"Error conditions are correctly formatted\"" }}} */ private val regenerateCommand = "SPARK_GENERATE_GOLDEN_FILES=1 build/sbt " + - "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error classes match with document\\\"\"" + "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error conditions are correctly formatted\\\"\"" private val errorJsonFilePath = getWorkspaceFilePath( - // Note that though we call them "error classes" here, the proper name is "error conditions", - // hence why the name of the JSON file is different. 
We will address this inconsistency as part - // of this ticket: https://issues.apache.org/jira/browse/SPARK-47429 "common", "utils", "src", "main", "resources", "error", "error-conditions.json") private val errorReader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL)) @@ -81,8 +78,8 @@ class SparkThrowableSuite extends SparkFunSuite { mapper.readValue(errorJsonFilePath.toUri.toURL, new TypeReference[Map[String, ErrorInfo]]() {}) } - test("Error classes are correctly formatted") { - val errorClassFileContents = + test("Error conditions are correctly formatted") { + val errorConditionFileContents = IOUtils.toString(errorJsonFilePath.toUri.toURL.openStream(), StandardCharsets.UTF_8) val mapper = JsonMapper.builder() .addModule(DefaultScalaModule) @@ -96,33 +93,30 @@ class SparkThrowableSuite extends SparkFunSuite { .writeValueAsString(errorReader.errorInfoMap) if (regenerateGoldenFiles) { - if (rewrittenString.trim != errorClassFileContents.trim) { - val errorClassesFile = errorJsonFilePath.toFile - logInfo(s"Regenerating error class file $errorClassesFile") - Files.delete(errorClassesFile.toPath) + if (rewrittenString.trim != errorConditionFileContents.trim) { + val errorConditionsFile = errorJsonFilePath.toFile + logInfo(s"Regenerating error conditions file $errorConditionsFile") + Files.delete(errorConditionsFile.toPath) FileUtils.writeStringToFile( - errorClassesFile, + errorConditionsFile, rewrittenString + lineSeparator, StandardCharsets.UTF_8) } } else { - assert(rewrittenString.trim == errorClassFileContents.trim) + assert(rewrittenString.trim == errorConditionFileContents.trim) } } test("SQLSTATE is mandatory") { - val errorClassesNoSqlState = errorReader.errorInfoMap.filter { + val errorConditionsNoSqlState = errorReader.errorInfoMap.filter { case (error: String, info: ErrorInfo) => !error.startsWith("_LEGACY_ERROR_TEMP") && info.sqlState.isEmpty }.keys.toSeq - assert(errorClassesNoSqlState.isEmpty, - s"Error classes without SQLSTATE: ${errorClassesNoSqlState.mkString(", ")}") + assert(errorConditionsNoSqlState.isEmpty, + s"Error classes without SQLSTATE: ${errorConditionsNoSqlState.mkString(", ")}") } test("Error class and error state / SQLSTATE invariants") { - // Unlike in the rest of the codebase, the term "error class" is used here as it is in our - // documentation as well as in the SQL standard. 
We can remove this comment as part of this - // ticket: https://issues.apache.org/jira/browse/SPARK-47429 val errorClassesJson = Utils.getSparkClassLoader.getResource("error/error-classes.json") val errorStatesJson = Utils.getSparkClassLoader.getResource("error/error-states.json") val mapper = JsonMapper.builder() @@ -171,9 +165,9 @@ class SparkThrowableSuite extends SparkFunSuite { .enable(SerializationFeature.INDENT_OUTPUT) .build() mapper.writeValue(tmpFile, errorReader.errorInfoMap) - val rereadErrorClassToInfoMap = mapper.readValue( + val rereadErrorConditionToInfoMap = mapper.readValue( tmpFile, new TypeReference[Map[String, ErrorInfo]]() {}) - assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap) + assert(rereadErrorConditionToInfoMap == errorReader.errorInfoMap) } test("Error class names should contain only capital letters, numbers and underscores") { @@ -205,15 +199,8 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { getMessage("UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map.empty[String, String]) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessageParameters().get("message").contains("Undefined error message parameter")) - - // Does not fail with too many args (expects 0 args) - assert(getMessage("DIVIDE_BY_ZERO", Map("config" -> "foo", "a" -> "bar")) == - "[DIVIDE_BY_ZERO] Division by zero. " + - "Use `try_divide` to tolerate divisor being 0 and return NULL instead. " + - "If necessary set foo to \"false\" " + - "to bypass this error. SQLSTATE: 22012") } test("Error message is formatted") { @@ -258,7 +245,8 @@ class SparkThrowableSuite extends SparkFunSuite { throw new SparkException("Arbitrary legacy message") } catch { case e: SparkThrowable => - assert(e.getErrorClass == null) + assert(e.getCondition == null) + assert(!e.isInternalError) assert(e.getSqlState == null) case _: Throwable => // Should not end up here @@ -274,7 +262,8 @@ class SparkThrowableSuite extends SparkFunSuite { cause = null) } catch { case e: SparkThrowable => - assert(e.getErrorClass == "CANNOT_PARSE_DECIMAL") + assert(e.getCondition == "CANNOT_PARSE_DECIMAL") + assert(!e.isInternalError) assert(e.getSqlState == "22018") case _: Throwable => // Should not end up here @@ -368,7 +357,7 @@ class SparkThrowableSuite extends SparkFunSuite { |}""".stripMargin) // Legacy mode when an exception does not have any error class class LegacyException extends Throwable with SparkThrowable { - override def getErrorClass: String = null + override def getCondition: String = null override def getMessage: String = "Test message" } val e3 = new LegacyException @@ -463,7 +452,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessage.contains("DIVIDE.BY_ZERO")) } @@ -489,8 +478,54 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessage.contains("BY.ZERO")) } } + + test("handle null values in message parameters") { + withTempDir { dir => + val json = new File(dir, "errors.json") + FileUtils.writeStringToFile(json, + """ + |{ + | "MISSING_PARAMETER" : { + | "message" : [ + | 
"Parameter is missing." + | ] + | } + |} + |""".stripMargin, StandardCharsets.UTF_8) + + val reader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) + // Attempt to get the error message with a null parameter + val errorMessage = reader.getErrorMessage("MISSING_PARAMETER", Map("param" -> null)) + + assert(errorMessage.contains("Parameter null is missing.")) + } + } + + test("detect unused message parameters") { + checkError( + exception = intercept[SparkException] { + SparkThrowableHelper.getMessage( + errorClass = "CANNOT_UP_CAST_DATATYPE", + messageParameters = Map( + "expression" -> "CAST('aaa' AS LONG)", + "sourceType" -> "STRING", + "targetType" -> "LONG", + "op" -> "CAST", // unused parameter + "details" -> "implicit cast" + )) + }, + condition = "INTERNAL_ERROR", + parameters = Map( + "message" -> + ("Found unused message parameters of the error class 'CANNOT_UP_CAST_DATATYPE'. " + + "Its error message format has 4 placeholders, but the passed message parameters map " + + "has 5 items. Consider to add placeholders to the error format or " + + "remove unused message parameters.") + ) + ) + } } diff --git a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala index e7315d6119be0..7e88c7ee684bd 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala @@ -94,9 +94,11 @@ private[spark] class Benchmark( /** * Runs the benchmark and outputs the results to stdout. This should be copied and added as * a comment with the benchmark. Although the results vary from machine to machine, it should - * provide some baseline. + * provide some baseline. If `relativeTime` is set to `true`, the `Relative` column will be + * the relative time of each case relative to the first case (less is better). Otherwise, it + * will be the relative execution speed of each case relative to the first case (more is better). 
*/ - def run(): Unit = { + def run(relativeTime: Boolean = false): Unit = { require(benchmarks.nonEmpty) // scalastyle:off println("Running benchmark: " + name) @@ -112,10 +114,12 @@ private[spark] class Benchmark( out.println(Benchmark.getJVMOSInfo()) out.println(Benchmark.getProcessorName()) val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max)) + val relativeHeader = if (relativeTime) "Relative time" else "Relative" out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n", - name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", "Relative") + name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", relativeHeader) out.println("-" * (nameLen + 80)) results.zip(benchmarks).foreach { case (result, benchmark) => + val relative = if (relativeTime) result.bestMs / firstBest else firstBest / result.bestMs out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n", benchmark.name, "%5.0f" format result.bestMs, @@ -123,7 +127,7 @@ private[spark] class Benchmark( "%5.0f" format result.stdevMs, "%10.1f" format result.bestRate, "%6.1f" format (1000 / result.bestRate), - "%3.1fX" format (firstBest / result.bestMs)) + "%3.1fX" format relative) } out.println() // scalastyle:on diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 1efef3383b821..b0f36b9744fa8 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -317,13 +317,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio // Instead, crash the driver by directly accessing the broadcast value. 
val e1 = intercept[SparkException] { broadcast.value } assert(e1.isInternalError) - assert(e1.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e1.getCondition == "INTERNAL_ERROR_BROADCAST") val e2 = intercept[SparkException] { broadcast.unpersist(blocking = true) } assert(e2.isInternalError) - assert(e2.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e2.getCondition == "INTERNAL_ERROR_BROADCAST") val e3 = intercept[SparkException] { broadcast.destroy(blocking = true) } assert(e3.isInternalError) - assert(e3.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e3.getCondition == "INTERNAL_ERROR_BROADCAST") } else { val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) @@ -339,7 +339,7 @@ package object testPackage extends Assertions { val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) assert(thrown.isInternalError) - assert(thrown.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(thrown.getCondition == "INTERNAL_ERROR_BROADCAST") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 98f475e2d9ce8..015a38a631461 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.net.URL +import java.net.URI import scala.collection.mutable import scala.io.Source @@ -65,7 +65,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) info.logUrlMap.values.foreach { logUrl => - assert(new URL(logUrl).getHost === SPARK_PUBLIC_DNS) + assert(new URI(logUrl).toURL.getHost === SPARK_PUBLIC_DNS) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 40d8eae644a07..ca81283e073ac 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -1802,6 +1802,23 @@ class SparkSubmitSuite val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs) assert(classpath.contains(".")) } + + // Requires Python dependencies for Spark Connect. Should be enabled by default. 
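+ // When re-enabled, this submission is roughly equivalent to running (path illustrative):
+ //   bin/spark-submit --name testPyApp --remote local /path/to/remote_test.py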
+ ignore("Spark Connect application submission (Python)") { + val pyFile = File.createTempFile("remote_test", ".py") + pyFile.deleteOnExit() + val content = + "from pyspark.sql import SparkSession;" + + "spark = SparkSession.builder.getOrCreate();" + + "assert 'connect' in str(type(spark));" + + "assert spark.range(1).first()[0] == 0" + FileUtils.write(pyFile, content, StandardCharsets.UTF_8) + val args = Seq( + "--name", "testPyApp", + "--remote", "local", + pyFile.getAbsolutePath) + runSparkSubmit(args) + } } object JarCreationTest extends Logging { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala index f34f792881f90..7501a98a1a573 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala @@ -221,7 +221,7 @@ class SingleFileEventLogFileReaderSuite extends EventLogFileReadersSuite { val entry = is.getNextEntry assert(entry != null) val actual = new String(ByteStreams.toByteArray(is), StandardCharsets.UTF_8) - val expected = Files.toString(new File(logPath.toString), StandardCharsets.UTF_8) + val expected = Files.asCharSource(new File(logPath.toString), StandardCharsets.UTF_8).read() assert(actual === expected) assert(is.getNextEntry === null) } @@ -368,8 +368,8 @@ class RollingEventLogFilesReaderSuite extends EventLogFileReadersSuite { assert(allFileNames.contains(fileName)) val actual = new String(ByteStreams.toByteArray(is), StandardCharsets.UTF_8) - val expected = Files.toString(new File(logPath.toString, fileName), - StandardCharsets.UTF_8) + val expected = Files.asCharSource( + new File(logPath.toString, fileName), StandardCharsets.UTF_8).read() assert(actual === expected) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 3013a5bf4a294..852f94bda870d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -708,7 +708,8 @@ abstract class FsHistoryProviderSuite extends SparkFunSuite with Matchers with P while (entry != null) { val actual = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8) val expected = - Files.toString(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) + Files.asCharSource(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) + .read() actual should be (expected) totalEntries += 1 entry = inputStream.getNextEntry diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index 2b9b110a41424..807e5ec3e823e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -45,8 +45,8 @@ class HistoryServerArgumentsSuite extends SparkFunSuite { test("Properties File Arguments Parsing --properties-file") { withTempDir { tmpDir => val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) - Files.write("spark.test.CustomPropertyA blah\n" + - "spark.test.CustomPropertyB notblah\n", outFile, UTF_8) + Files.asCharSink(outFile, 
UTF_8).write("spark.test.CustomPropertyA blah\n" + + "spark.test.CustomPropertyB notblah\n") val argStrings = Array("--properties-file", outFile.getAbsolutePath) val hsa = new HistoryServerArguments(conf, argStrings) assert(conf.get("spark.test.CustomPropertyA") === "blah") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala index 83ef300460b95..100145a2f4833 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerPageSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.history -import java.net.URL +import java.net.URI import jakarta.servlet.http.HttpServletResponse import org.json4s.DefaultFormats @@ -43,12 +43,13 @@ class HistoryServerPageSuite extends SparkFunSuite with BeforeAndAfter { private val localhost: String = Utils.localHostNameForURI() private var port: Int = -1 - private def startHistoryServer(logDir: String): Unit = { + private def startHistoryServer(logDir: String, title: Option[String] = None): Unit = { assert(server.isEmpty) val conf = new SparkConf() .set(HISTORY_LOG_DIR, logDir) .set(UPDATE_INTERVAL_S.key, "0") .set(IS_TESTING, true) + title.foreach(conf.set(HISTORY_SERVER_UI_TITLE.key, _)) val provider = new FsHistoryProvider(conf) provider.checkForLogs() val securityManager = HistoryServer.createSecurityManager(conf) @@ -71,7 +72,7 @@ class HistoryServerPageSuite extends SparkFunSuite with BeforeAndAfter { ApplicationStatus.COMPLETED.toString.toLowerCase() } val (code, jsonOpt, errOpt) = HistoryServerSuite.getContentAndCode( - new URL(s"http://$localhost:$port/api/v1/applications?status=$param") + new URI(s"http://$localhost:$port/api/v1/applications?status=$param").toURL ) assert(code == HttpServletResponse.SC_OK) assert(jsonOpt.isDefined) @@ -100,4 +101,19 @@ class HistoryServerPageSuite extends SparkFunSuite with BeforeAndAfter { stopHistoryServer() } } + + test("SPARK-49128: Support custom History Server UI title") { + Seq(None, Some("Custom History Server Title")).foreach { title => + startHistoryServer(logDirs.head, title) + val page = new HistoryPage(server.get) + val (code, htmlOpt, errOpt) = HistoryServerSuite.getContentAndCode( + new URI(s"http://$localhost:$port/").toURL + ) + assert(code == HttpServletResponse.SC_OK) + val expected = title.getOrElse("History Server") + assert(htmlOpt.isDefined && htmlOpt.get.contains(s"$expected")) + assert(errOpt.isEmpty) + stopHistoryServer() + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 91a93bbe01d7f..6b2bd90cd4314 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} -import java.net.{HttpURLConnection, URL} +import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream @@ -261,9 +261,9 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with val url = attemptId match { case Some(id) => - new URL(s"${generateURL(s"applications/$appId")}/$id/logs") + new 
URI(s"${generateURL(s"applications/$appId")}/$id/logs").toURL case None => - new URL(s"${generateURL(s"applications/$appId")}/logs") + new URI(s"${generateURL(s"applications/$appId")}/logs").toURL } val (code, inputStream, error) = HistoryServerSuite.connectAndGetInputStream(url) @@ -283,7 +283,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with val expectedFile = { new File(logDir, entry.getName) } - val expected = Files.toString(expectedFile, StandardCharsets.UTF_8) + val expected = Files.asCharSource(expectedFile, StandardCharsets.UTF_8).read() val actual = new String(ByteStreams.toByteArray(zipStream), StandardCharsets.UTF_8) actual should be (expected) filesCompared += 1 @@ -433,12 +433,12 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with // build a URL for an app or app/attempt plus a page underneath def buildURL(appId: String, suffix: String): URL = { - new URL(s"http://$localhost:$port/history/$appId$suffix") + new URI(s"http://$localhost:$port/history/$appId$suffix").toURL } // build a rest URL for the application and suffix. def applications(appId: String, suffix: String): URL = { - new URL(s"http://$localhost:$port/api/v1/applications/$appId$suffix") + new URI(s"http://$localhost:$port/api/v1/applications/$appId$suffix").toURL } // start initial job @@ -601,7 +601,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with tests.foreach { case (user, expectedCode) => testUrls.foreach { url => val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil - val sc = TestUtils.httpResponseCode(new URL(url), headers = headers) + val sc = TestUtils.httpResponseCode(new URI(url).toURL, headers = headers) assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)") } } @@ -620,7 +620,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with s"http://$localhost:$port/api/v1/applications/$appId/2/logs") testUrls.foreach { url => - TestUtils.httpResponseCode(new URL(url)) + TestUtils.httpResponseCode(new URI(url).toURL) } assert(server.cacheMetrics.loadCount.getCount === 0, "downloading event log shouldn't load ui") } @@ -642,7 +642,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with case _ => assert(location.stripSuffix("/") === url.toString) } - HistoryServerSuite.getUrl(new URL(location)) + HistoryServerSuite.getUrl(new URI(location).toURL) } } @@ -676,14 +676,14 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with def buildPageAttemptUrl(appId: String, attemptId: Option[Int]): URL = { attemptId match { case Some(id) => - new URL(s"http://$localhost:$port/history/$appId/$id") + new URI(s"http://$localhost:$port/history/$appId/$id").toURL case None => - new URL(s"http://$localhost:$port/history/$appId") + new URI(s"http://$localhost:$port/history/$appId").toURL } } def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { - HistoryServerSuite.getContentAndCode(new URL(s"http://$localhost:$port/api/v1/$path")) + HistoryServerSuite.getContentAndCode(new URI(s"http://$localhost:$port/api/v1/$path").toURL) } def getUrl(path: String): String = { @@ -691,7 +691,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with } def generateURL(path: String): URL = { - new URL(s"http://$localhost:$port/api/v1/$path") + new URI(s"http://$localhost:$port/api/v1/$path").toURL } def generateExpectation(name: 
String, path: String): Unit = { @@ -706,7 +706,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with test("SPARK-31697: HistoryServer should set Content-Type") { val port = server.boundPort val nonExistenceAppId = "local-non-existence" - val url = new URL(s"http://$localhost:$port/history/$nonExistenceAppId") + val url = new URI(s"http://$localhost:$port/history/$nonExistenceAppId").toURL val conn = url.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("GET") conn.connect() @@ -717,7 +717,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with test("Redirect to the root page when accessed to /history/") { val port = server.boundPort - val url = new URL(s"http://$localhost:$port/history/") + val url = new URI(s"http://$localhost:$port/history/").toURL val conn = url.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("GET") conn.setUseCaches(false) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterDecommisionSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterDecommisionSuite.scala index 8c17324d2e38c..19ddb633c9f5c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterDecommisionSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterDecommisionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import java.net.{HttpURLConnection, URL} +import java.net.{HttpURLConnection, URI} import scala.concurrent.duration._ @@ -38,7 +38,7 @@ class MasterDecommisionSuite extends MasterSuiteBase { val masterUrl = s"http://${Utils.localHostNameForURI()}:${localCluster.masterWebUIPort}" try { eventually(timeout(30.seconds), interval(100.milliseconds)) { - val url = new URL(s"$masterUrl/workers/kill/?host=${Utils.localHostNameForURI()}") + val url = new URI(s"$masterUrl/workers/kill/?host=${Utils.localHostNameForURI()}").toURL val conn = url.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("POST") assert(conn.getResponseCode === 405) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index e75c4ca88069d..e64bc724cfba0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -183,6 +183,28 @@ class MasterSuite extends MasterSuiteBase { assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00001") } + test("SPARK-50208: Use driverID as appName in javaOpts") { + val master = makeMaster(new SparkConf().set(MASTER_USE_DRIVER_ID_AS_APP_NAME, true)) + val command = Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", "{{USER_JAR}}", "mainClass"), + Map.empty, Seq.empty, Seq.empty, Seq.empty) + val desc = DriverDescription("", 1, 1, false, command) + val result = master.invokePrivate(_maybeUpdateAppName(desc, "driver-id")) + assert(result.command.javaOpts.contains("-Dspark.app.name=driver-id")) + } + + test("SPARK-50208: Use driverID as appName in arguments") { + val master = makeMaster(new SparkConf().set(MASTER_USE_DRIVER_ID_AS_APP_NAME, true)) + val command = Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", "{{USER_JAR}}", "org.apache.spark.deploy.SparkSubmit", "pi.py"), + Map.empty, Seq.empty, Seq.empty, Seq.empty) + val desc = DriverDescription("", 1, 1, false, command) + val result = 
master.invokePrivate(_maybeUpdateAppName(desc, "driver-id")) + assert(result.command.arguments.contains("spark.app.name=driver-id")) + } + test("SPARK-45753: Prevent invalid driver id patterns") { val m = intercept[IllegalArgumentException] { makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my driver")) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuiteBase.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuiteBase.scala index 629112a27463c..2e159b8288845 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuiteBase.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuiteBase.scala @@ -443,6 +443,8 @@ trait MasterSuiteBase extends SparkFunSuite private val _state = PrivateMethod[RecoveryState.Value](Symbol("state")) protected val _newDriverId = PrivateMethod[String](Symbol("newDriverId")) protected val _newApplicationId = PrivateMethod[String](Symbol("newApplicationId")) + protected val _maybeUpdateAppName = + PrivateMethod[DriverDescription](Symbol("maybeUpdateAppName")) protected val _createApplication = PrivateMethod[ApplicationInfo](Symbol("createApplication")) protected val _persistenceEngine = PrivateMethod[PersistenceEngine](Symbol("persistenceEngine")) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterWorkerUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterWorkerUISuite.scala index 428539068a107..5f206f611fe6b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterWorkerUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterWorkerUISuite.scala @@ -140,4 +140,22 @@ class MasterWorkerUISuite extends MasterSuiteBase { System.getProperties().remove("spark.ui.proxyBase") } } + + test("SPARK-49007: Support custom master web ui title") { + implicit val formats = org.json4s.DefaultFormats + val title = "Spark Custom Title" + val conf = new SparkConf().set(MASTER_UI_TITLE, title) + val localCluster = LocalSparkCluster(2, 2, 512, conf) + localCluster.start() + val masterUrl = s"http://${Utils.localHostNameForURI()}:${localCluster.masterWebUIPort}" + try { + eventually(timeout(50.seconds), interval(100.milliseconds)) { + val html = Utils + .tryWithResource(Source.fromURL(s"$masterUrl/"))(_.getLines().mkString("\n")) + html should include (title) + } + } finally { + localCluster.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala index a9b96f85808d0..ccfc4ee1600a5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala @@ -36,13 +36,16 @@ class ApplicationPageSuite extends SparkFunSuite { private val rp = new ResourceProfile(Map.empty, Map.empty) private val desc = ApplicationDescription("name", Some(4), null, "appUiUrl", rp) + private val descWithoutUI = ApplicationDescription("name", Some(4), null, "", rp) private val appFinished = new ApplicationInfo(0, "app-finished", desc, new Date, null, 1) appFinished.markFinished(ApplicationState.FINISHED) private val appLive = new ApplicationInfo(0, "app-live", desc, new Date, null, 1) + private val appLiveWithoutUI = + new ApplicationInfo(0, "app-live-without-ui", descWithoutUI, new Date, null, 1) private val state = mock(classOf[MasterStateResponse]) when(state.completedApps).thenReturn(Array(appFinished)) - 
when(state.activeApps).thenReturn(Array(appLive)) + when(state.activeApps).thenReturn(Array(appLive, appLiveWithoutUI)) private val rpc = mock(classOf[RpcEndpointRef]) when(rpc.askSync[MasterStateResponse](RequestMasterState)).thenReturn(state) @@ -61,6 +64,16 @@ class ApplicationPageSuite extends SparkFunSuite { assert(!result.contains(master.historyServerUrl.get)) } + test("SPARK-50021: Application Detail UI is empty when spark.ui.enabled=false") { + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("appId")).thenReturn("app-live-without-ui") + + val result = new ApplicationPage(masterWebUI).render(request).toString() + assert(result.contains("Application UI: Disabled")) + assert(!result.contains("Application History UI")) + assert(!result.contains(master.historyServerUrl.get)) + } + test("SPARK-45774: Application History UI") { val request = mock(classOf[HttpServletRequest]) when(request.getParameter("appId")).thenReturn("app-finished") diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index caefab9e0746b..5e75d1c424eab 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master.ui import java.io.DataOutputStream -import java.net.{HttpURLConnection, URL} +import java.net.{HttpURLConnection, URI} import java.nio.charset.StandardCharsets import java.util.Date @@ -125,7 +125,7 @@ object MasterWebUISuite { url: String, method: String, body: String = ""): HttpURLConnection = { - val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + val conn = new URI(url).toURL.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod(method) if (body.nonEmpty) { conn.setDoOutput(true) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala index 20ff932eb01a3..2679349bfe028 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala @@ -17,24 +17,31 @@ package org.apache.spark.deploy.master.ui +import java.util.Date + import scala.io.Source import jakarta.servlet.http.HttpServletResponse.{SC_METHOD_NOT_ALLOWED, SC_OK} import org.mockito.Mockito.{mock, when} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.deploy.master.ui.MasterWebUISuite._ import org.apache.spark.internal.config.DECOMMISSION_ENABLED +import org.apache.spark.internal.config.UI.MASTER_UI_VISIBLE_ENV_VAR_PREFIXES import org.apache.spark.internal.config.UI.UI_KILL_ENABLED import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.util.Utils class ReadOnlyMasterWebUISuite extends SparkFunSuite { + import org.apache.spark.deploy.DeployTestUtils._ + val conf = new SparkConf() .set(UI_KILL_ENABLED, false) .set(DECOMMISSION_ENABLED, false) + .set(MASTER_UI_VISIBLE_ENV_VAR_PREFIXES.key, "SPARK_SCALA_") val securityMgr = new SecurityManager(conf) val rpcEnv = mock(classOf[RpcEnv]) val master = mock(classOf[Master]) @@ -43,6 +50,14 @@ class ReadOnlyMasterWebUISuite 
extends SparkFunSuite { when(master.conf).thenReturn(conf) when(master.rpcEnv).thenReturn(rpcEnv) when(master.self).thenReturn(masterEndpointRef) + val desc1 = createAppDesc().copy(name = "WithUI") + val desc2 = desc1.copy(name = "WithoutUI", appUiUrl = "") + val app1 = new ApplicationInfo(new Date().getTime, "app1", desc1, new Date(), null, Int.MaxValue) + val app2 = new ApplicationInfo(new Date().getTime, "app2", desc2, new Date(), null, Int.MaxValue) + val state = new MasterStateResponse( + "host", 8080, None, Array.empty, Array(app1, app2), Array.empty, + Array.empty, Array.empty, RecoveryState.ALIVE) + when(masterEndpointRef.askSync[MasterStateResponse](RequestMasterState)).thenReturn(state) val masterWebUI = new MasterWebUI(master, 0) override def beforeAll(): Unit = { @@ -86,4 +101,22 @@ class ReadOnlyMasterWebUISuite extends SparkFunSuite { assert(result.contains("Spark Properties")) assert(result.contains("Hadoop Properties")) } + + test("SPARK-49206: Add 'Environment Variables' table to Master 'EnvironmentPage'") { + val url = s"http://${Utils.localHostNameForURI()}:${masterWebUI.boundPort}/environment/" + val conn = sendHttpRequest(url, "GET", "") + assert(conn.getResponseCode === SC_OK) + val result = Source.fromInputStream(conn.getInputStream).mkString + assert(result.contains("Environment Variables")) + assert(result.contains("
SPARK_SCALA_VERSION2.1")) + } + + test("SPARK-50022: Fix 'MasterPage' to hide App UI links when UI is disabled") { + val url = s"http://${Utils.localHostNameForURI()}:${masterWebUI.boundPort}/" + val conn = sendHttpRequest(url, "GET") + assert(conn.getResponseCode === SC_OK) + val result = Source.fromInputStream(conn.getInputStream).mkString + assert(result.contains("WithUI")) + assert(result.contains(" WithoutUI\n")) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 43a5f7a784f53..a155e4cc3ac90 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.deploy.rest import java.io.DataOutputStream -import java.net.{HttpURLConnection, URL} +import java.net.{HttpURLConnection, URI} import java.nio.charset.StandardCharsets +import java.util.Base64 import scala.collection.mutable import jakarta.servlet.http.HttpServletResponse +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.ThreadPool.SizedThreadPool import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -32,6 +35,7 @@ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState._ import org.apache.spark.deploy.master.RecoveryState +import org.apache.spark.internal.config.{MASTER_REST_SERVER_FILTERS, MASTER_REST_SERVER_MAX_THREADS, MASTER_REST_SERVER_VIRTUAL_THREADS} import org.apache.spark.rpc._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -92,7 +96,6 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { val RANDOM_PORT = 9000 val allMasters = s"$masterUrl,${Utils.localHostName()}:$RANDOM_PORT" conf.set("spark.master", allMasters) - conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this val response = new RestSubmissionClientApp().run("app-resource", "main-class", appArgs, conf) @@ -110,7 +113,6 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) val conf = new SparkConf(loadDefaults = false) conf.set("spark.master", masterUrl) - conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this val response = new RestSubmissionClientApp().run("app-resource", "main-class", appArgs, conf) @@ -445,6 +447,30 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { assert(filteredVariables == Map("SPARK_VAR" -> "1")) } + test("SPARK-49033: Support server-side environment variable replacement in REST Submission API") { + val request = new CreateSubmissionRequest + request.appResource = "" + request.mainClass = "" + request.appArgs = Array.empty[String] + request.sparkProperties = Map.empty[String, String] + request.environmentVariables = Map("AWS_ENDPOINT_URL" -> "{{SPARK_SCALA_VERSION}}") + val servlet = new StandaloneSubmitRequestServlet(null, null, null) + val desc = servlet.buildDriverDescription(request, "spark://master:7077", 6066) + assert(desc.command.environment.get("AWS_ENDPOINT_URL") === Some("2.13")) + } + + test("SPARK-49034: Support server-side sparkProperties replacement in REST Submission API") { + val request = new CreateSubmissionRequest 
+ request.appResource = "" + request.mainClass = "" + request.appArgs = Array.empty[String] + request.sparkProperties = Map("spark.hadoop.fs.s3a.endpoint" -> "{{SPARK_SCALA_VERSION}}") + request.environmentVariables = Map.empty[String, String] + val servlet = new StandaloneSubmitRequestServlet(null, null, null) + val desc = servlet.buildDriverDescription(request, "spark://master:7077", 6066) + assert(desc.command.javaOpts.exists(_.contains("-Dspark.hadoop.fs.s3a.endpoint=2.13"))) + } + test("SPARK-45197: Make StandaloneRestServer add JavaModuleOptions to drivers") { val request = new CreateSubmissionRequest request.appResource = "" @@ -457,6 +483,100 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { assert(desc.command.javaOpts.exists(_.startsWith("--add-opens"))) } + test("SPARK-49103: `spark.master.rest.filters` loads filters successfully") { + val conf = new SparkConf() + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager)) + val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get)) + + // Causes exceptions in order to verify new configuration loads filters successfully + conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter") + server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")) + val m = intercept[IllegalArgumentException] { + server.get.start() + }.getMessage() + assert(m.contains("Decode argument cannot be null")) + } + + private val TEST_KEY = Base64.getUrlEncoder.encodeToString( + "Visit https://spark.apache.org to download Apache Spark.".getBytes()) + + test("SPARK-49103: REST server stars successfully with `spark.master.rest.filters`") { + val conf = new SparkConf() + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager)) + val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get)) + conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter") + conf.set("spark.org.apache.spark.ui.JWSFilter.param.secretKey", TEST_KEY) + server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")) + server.get.start() + } + + test("SPARK-49103: JWSFilter successfully protects REST API via configurations") { + val conf = new SparkConf() + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager)) + val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get)) + conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter") + conf.set("spark.org.apache.spark.ui.JWSFilter.param.secretKey", TEST_KEY) + server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")) + val port = server.get.start() + val masterUrl = s"spark://$localhost:$port" + val json = constructSubmitRequest(masterUrl).toJson + val httpUrl = masterUrl.replace("spark://", "http://") + val submitRequestPath = s"$httpUrl/${RestSubmissionServer.PROTOCOL_VERSION}/submissions/create" + val conn = sendHttpRequest(submitRequestPath, "POST", json) + assert(conn.getResponseCode === HttpServletResponse.SC_FORBIDDEN) + } + + test("SPARK-50195: Fix StandaloneRestServer to propagate app name to SparkSubmit properly") { + 
Seq((classOf[SparkSubmit].getName, Seq("-c", "spark.app.name=app1")), + ("", Seq.empty)).foreach { case (mainClass, expectedArguments) => + val request = new CreateSubmissionRequest + request.appResource = "" + request.mainClass = mainClass + request.appArgs = Array.empty[String] + request.sparkProperties = Map("spark.app.name" -> "app1") + request.environmentVariables = Map.empty[String, String] + val servlet = new StandaloneSubmitRequestServlet(null, null, null) + val desc = servlet.buildDriverDescription(request, "spark://master:7077", 6066) + assert(desc.command.arguments.slice(3, 5) === expectedArguments) + } + } + + test("SPARK-50381: Support spark.master.rest.maxThreads") { + val conf = new SparkConf() + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + rpcEnv = Some(RpcEnv.create("rest-with-maxThreads", localhost, 0, conf, securityManager)) + val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get)) + conf.set(MASTER_REST_SERVER_MAX_THREADS, 2000) + server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")) + server.get.start() + val pool = server.get._server.get.getThreadPool.asInstanceOf[SizedThreadPool] + assert(pool.getMaxThreads === 2000) + } + + test("SPARK-50383: Support spark.master.rest.virtualThread.enabled") { + val conf = new SparkConf() + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + rpcEnv = Some(RpcEnv.create("rest-with-virtualThreads", localhost, 0, conf, securityManager)) + val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get)) + conf.set(MASTER_REST_SERVER_VIRTUAL_THREADS, true) + server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")) + server.get.start() + val pool = server.get._server.get.getThreadPool.asInstanceOf[QueuedThreadPool] + if (Utils.isJavaVersionAtLeast21) { + assert(pool.getVirtualThreadsExecutor != null) + } else { + assert(pool.getVirtualThreadsExecutor == null) + } + } + /* --------------------- * | Helper methods | * --------------------- */ @@ -599,7 +719,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { url: String, method: String, body: String = ""): HttpURLConnection = { - val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + val conn = new URI(url).toURL.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod(method) if (body.nonEmpty) { conn.setDoOutput(true) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 9eb5172583120..f2807f258f2d1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -87,8 +87,6 @@ class SubmitRestProtocolSuite extends SparkFunSuite { message.clientSparkVersion = "1.2.3" message.appResource = "honey-walnut-cherry.jar" message.mainClass = "org.apache.spark.examples.SparkPie" - message.appArgs = Array("two slices") - message.environmentVariables = Map("PATH" -> "/dev/null") val conf = new SparkConf(false) conf.set("spark.app.name", "SparkPie") message.sparkProperties = conf.getAll.toMap diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 
7bd84c810c42e..8b98df103c014 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -37,7 +37,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.TestUtils._ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} -import org.apache.spark.internal.config.PLUGINS +import org.apache.spark.internal.config.{EXECUTOR_MEMORY, PLUGINS} import org.apache.spark.resource._ import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.resource.TestResourceIDs._ @@ -581,7 +581,8 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite */ test("SPARK-40320 Executor should exit when initialization failed for fatal error") { val conf = createSparkConf() - .setMaster("local-cluster[1, 1, 1024]") + .setMaster("local-cluster[1, 1, 512]") + .set(EXECUTOR_MEMORY.key, "512m") .set(PLUGINS, Seq(classOf[TestFatalErrorPlugin].getName)) .setAppName("test") sc = new SparkContext(conf) @@ -599,7 +600,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite } try { sc.addSparkListener(listener) - eventually(timeout(15.seconds)) { + eventually(timeout(30.seconds)) { assert(executorAddCounter.get() >= 2) assert(executorRemovedCounter.get() >= 2) } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 805e7ca467497..fa13092dc47aa 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.lang.Thread.UncaughtExceptionHandler import java.net.URL import java.nio.ByteBuffer -import java.util.Properties +import java.util.{HashMap, Properties} import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean @@ -522,7 +522,13 @@ class ExecutorSuite extends SparkFunSuite testThrowable(new OutOfMemoryError(), depthToCheck, isFatal = true) testThrowable(new InterruptedException(), depthToCheck, isFatal = false) testThrowable(new RuntimeException("test"), depthToCheck, isFatal = false) - testThrowable(new SparkOutOfMemoryError("test"), depthToCheck, isFatal = false) + testThrowable( + new SparkOutOfMemoryError( + "_LEGACY_ERROR_USER_RAISED_EXCEPTION", + new HashMap[String, String]() { + put("errorMessage", "test") + }), + depthToCheck, isFatal = false) } // Verify we can handle the cycle in the exception chain diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala index 79fa8d21bf3f1..fc8f48df2cb7d 100644 --- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -383,7 +383,7 @@ object NonLocalModeSparkPlugin { resources: Map[String, ResourceInformation]): Unit = { val path = conf.get(TEST_PATH_CONF) val strToWrite = createFileStringWithGpuAddrs(id, resources) - Files.write(strToWrite, new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8) + Files.asCharSink(new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8).write(strToWrite) } def reset(): Unit = { diff --git 
a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 5c09a1f965b9e..ff971b72d8910 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -132,7 +132,7 @@ class CompressionCodecSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") }, - errorClass = "CODEC_NOT_AVAILABLE.WITH_CONF_SUGGESTION", + condition = "CODEC_NOT_AVAILABLE.WITH_CONF_SUGGESTION", parameters = Map( "codecName" -> "foobar", "configKey" -> "\"spark.io.compression.codec\"", @@ -171,7 +171,7 @@ class CompressionCodecSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { CompressionCodec.getShortName(codecClass.toUpperCase(Locale.ROOT)) }, - errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + condition = "CODEC_SHORT_NAME_NOT_FOUND", parameters = Map("codecName" -> codecClass.toUpperCase(Locale.ROOT))) } } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index a6f1707a1aabf..1f40ef944a843 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -335,6 +335,24 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { tMemManager.releaseExecutionMemory(500L, c) assert(tMemManager.getMemoryConsumptionForThisTask === 0L) } + + test("task peak execution memory usage") { + val memoryManager = createMemoryManager( + maxOnHeapExecutionMemory = 1000L, + maxOffHeapExecutionMemory = 1000L) + + val tMemManager = new TaskMemoryManager(memoryManager, 1) + val offHeapConsumer = new TestMemoryConsumer(tMemManager, MemoryMode.OFF_HEAP) + val onHeapConsumer = new TestMemoryConsumer(tMemManager, MemoryMode.ON_HEAP) + + val result1 = tMemManager.acquireExecutionMemory(500L, offHeapConsumer) + val result2 = tMemManager.acquireExecutionMemory(400L, onHeapConsumer) + assert(result1 === 500L) + assert(result2 === 400L) + assert(tMemManager.getMemoryConsumptionForThisTask === 900L) + assert(tMemManager.getPeakOnHeapExecutionMemory === 400L) + assert(tMemManager.getPeakOffHeapExecutionMemory === 500L) + } } private object MemoryManagerSuite { diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala index 55d82aed5c3f2..817d660763361 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala @@ -88,7 +88,7 @@ class GraphiteSinkSuite extends SparkFunSuite { val e = intercept[SparkException] { new GraphiteSink(props, registry) } - checkError(e, errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + checkError(e, condition = "GRAPHITE_SINK_PROPERTY_MISSING", parameters = Map("property" -> "host")) } @@ -100,7 +100,7 @@ class GraphiteSinkSuite extends SparkFunSuite { val e = intercept[SparkException] { new GraphiteSink(props, registry) } - checkError(e, errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + checkError(e, condition = "GRAPHITE_SINK_PROPERTY_MISSING", parameters = Map("property" -> "port")) } @@ -115,7 +115,7 @@ class GraphiteSinkSuite extends SparkFunSuite { exception = intercept[SparkException] { new GraphiteSink(props, registry) }, - 
errorClass = "GRAPHITE_SINK_INVALID_PROTOCOL", + condition = "GRAPHITE_SINK_INVALID_PROTOCOL", parameters = Map("protocol" -> "http") ) } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 4239180ba6c37..fb2bb83cb7fc4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -201,10 +201,10 @@ class AsyncRDDActionsSuite extends SparkFunSuite with TimeLimits { test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) - .mapPartitions(itr => { Thread.sleep(20); itr }) + .mapPartitions(itr => { Thread.sleep(200); itr }) .countAsync() intercept[TimeoutException] { - ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) + ThreadUtils.awaitResult(f, Duration(2, "milliseconds")) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 7c5db914cd5ba..aecb8b99d0e31 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -792,7 +792,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("randomSplit") { val n = 600 val data = sc.parallelize(1 to n, 2) - for(seed <- 1 to 5) { + for (seed <- 1 to 5) { val splits = data.randomSplit(Array(1.0, 2.0, 3.0), seed) assert(splits.length == 3, "wrong number of splits") assert(splits.flatMap(_.collect()).sorted.toList == data.collect().toList, @@ -922,7 +922,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { exception = intercept[SparkIllegalArgumentException] { rdd1.cartesian(rdd2).partitions }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", sqlState = "54000", parameters = Map( "numberOfElements" -> (numSlices.toLong * numSlices.toLong).toString, diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala index ff7d680352177..edf138df9e207 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala @@ -148,7 +148,7 @@ object TestResourceDiscoveryPlugin { def writeFile(conf: SparkConf, id: String): Unit = { val path = conf.get(TEST_PATH_CONF) val fileName = s"$id - ${UUID.randomUUID.toString}" - Files.write(id, new File(path, fileName), StandardCharsets.UTF_8) + Files.asCharSink(new File(path, fileName), StandardCharsets.UTF_8).write(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 3ef382573517b..66b1ee7b58ac8 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -868,23 +868,23 @@ abstract class RpcEnvSuite extends SparkFunSuite { val conf = createSparkConf() val file = new File(tempDir, "file") - Files.write(UUID.randomUUID().toString(), file, UTF_8) + Files.asCharSink(file, UTF_8).write(UUID.randomUUID().toString) val fileWithSpecialChars = new File(tempDir, "file name") - Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) + Files.asCharSink(fileWithSpecialChars, UTF_8).write(UUID.randomUUID().toString) val 
empty = new File(tempDir, "empty") - Files.write("", empty, UTF_8); + Files.asCharSink(empty, UTF_8).write("") val jar = new File(tempDir, "jar") - Files.write(UUID.randomUUID().toString(), jar, UTF_8) + Files.asCharSink(jar, UTF_8).write(UUID.randomUUID().toString) val dir1 = new File(tempDir, "dir1") assert(dir1.mkdir()) val subFile1 = new File(dir1, "file1") - Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + Files.asCharSink(subFile1, UTF_8).write(UUID.randomUUID().toString) val dir2 = new File(tempDir, "dir2") assert(dir2.mkdir()) val subFile2 = new File(dir2, "file2") - Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + Files.asCharSink(subFile2, UTF_8).write(UUID.randomUUID().toString) val fileUri = env.fileServer.addFile(file) val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 849832c57edaa..f00fb0d2cfa3f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -101,7 +101,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) + Thread.sleep(Random.nextInt(500)) context.barrier() val time1 = System.currentTimeMillis() // Sleep for a random time before global sync. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 978ceb16b376c..243d33fe55a79 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -779,7 +779,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(failureReason.isDefined) checkError( exception = failureReason.get.asInstanceOf[SparkException], - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map("jobId" -> "0", "reason" -> "") ) @@ -901,7 +901,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti cancel(jobId) checkError( exception = failure.asInstanceOf[SparkException], - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map("jobId" -> jobId.toString, "reason" -> "") ) diff --git a/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala index e7a57c22ef66e..478e578130fcb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala @@ -441,6 +441,23 @@ class HealthTrackerSuite extends SparkFunSuite with MockitoSugar with LocalSpark assert(1000 === HealthTracker.getExcludeOnFailureTimeout(conf)) } + test("SPARK-49252: check exclusion enabling config on the application level") { + val conf = new SparkConf().setMaster("local") + assert(!HealthTracker.isExcludeOnFailureEnabled(conf)) + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED, true) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn off 
taskset level exclusion, application level healthtracker should still be enabled. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn off the application level exclusion specifically, this overrides the global setting. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, false) + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + assert(!HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn on application level exclusion, health tracker should be enabled. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, true) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + } + test("check exclude configuration invariants") { val conf = new SparkConf().setMaster("yarn").set(config.SUBMIT_DEPLOY_MODE, "cluster") Seq( diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ab2c00e368468..7607d4d9fe6d9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -2725,6 +2725,39 @@ class TaskSetManagerSuite assert(executorMonitor.isExecutorIdle("exec2")) } + test("SPARK-49252: TaskSetExcludeList can be created without HealthTracker") { + // When the excludeOnFailure.enabled is set to true, the TaskSetManager should create a + // TaskSetExcludelist even if the application level HealthTracker is not defined. + val conf = new SparkConf().set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, true) + + // Create a task set with a single task. + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(1) + + val taskSetManager = new TaskSetManager(sched, taskSet, 1, + // No application level HealthTracker. + healthTracker = None) + assert(taskSetManager.taskSetExcludelistHelperOpt.isDefined) + } + + test("SPARK-49252: TaskSetExcludeList will be running in dry run mode when " + + "excludeOnFailure at taskset level is disabled but health tracker is enabled") { + // Disable the excludeOnFailure.enabled at taskset level. + val conf = new SparkConf().set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + + // Create a task set with a single task. + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(1) + + val taskSetManager = new TaskSetManager(sched, taskSet, 1, + // Enable the application level HealthTracker. + healthTracker = Some(new HealthTracker(sc, None))) + assert(taskSetManager.taskSetExcludelistHelperOpt.isDefined) + assert(taskSetManager.taskSetExcludelistHelperOpt.get.isDryRun) + } + } class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/ChecksumBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/ChecksumBenchmark.scala new file mode 100644 index 0000000000000..16a50fabb7ffd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ChecksumBenchmark.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.util.zip.{Adler32, CRC32, CRC32C} + +import org.apache.hadoop.util.PureJavaCrc32C + +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} + +/** + * Benchmark for Checksum Algorithms used by shuffle. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "core/Test/runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/Test/runMain " + * Results will be written to "benchmarks/ChecksumBenchmark-results.txt". + * }}} + */ +object ChecksumBenchmark extends BenchmarkBase { + + val N = 1024 + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Benchmark Checksum Algorithms") { + val data: Array[Byte] = (1 until 32 * 1024 * 1024).map(_.toByte).toArray + val benchmark = new Benchmark("Checksum Algorithms", N, 3, output = output) + benchmark.addCase("CRC32") { _ => + (1 to N).foreach(_ => new CRC32().update(data)) + } + benchmark.addCase(s"CRC32C") { _ => + (1 to N).foreach(_ => new CRC32C().update(data)) + } + benchmark.addCase(s"Adler32") { _ => + (1 to N).foreach(_ => new Adler32().update(data)) + } + benchmark.addCase(s"hadoop PureJavaCrc32C") { _ => + (1 to N).foreach(_ => new PureJavaCrc32C().update(data)) + } + benchmark.run() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala index 5b6fb31d598ac..aad649b7b2612 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -111,7 +111,7 @@ class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext wi exception = intercept[SparkOutOfMemoryError] { sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) }, - errorClass = "UNABLE_TO_ACQUIRE_MEMORY", + condition = "UNABLE_TO_ACQUIRE_MEMORY", parameters = Map("requestedBytes" -> "800", "receivedBytes" -> "400")) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index febe1ac4bb4cf..1c4c00c03a470 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.File +import java.nio.file.{Files, Paths} import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Semaphore, TimeUnit} import scala.collection.mutable.ArrayBuffer @@ -377,20 +378,22 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED, true) sc = new 
SparkContext(conf) TestUtils.waitUntilExecutorsUp(sc, 2, 60000) - val shuffleBlockUpdates = new ArrayBuffer[BlockId]() - var isDecommissionedExecutorRemoved = false + val shuffleBlockUpdates = new ConcurrentLinkedQueue[BlockId]() val execToDecommission = sc.getExecutorIds().head + val decommissionedExecutorLocalDir = sc.parallelize(1 to 100, 10).flatMap { _ => + if (SparkEnv.get.executorId == execToDecommission) { + SparkEnv.get.blockManager.getLocalDiskDirs + } else { + Array.empty[String] + } + }.collect().toSet + assert(decommissionedExecutorLocalDir.size == 1) sc.addSparkListener(new SparkListener { override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { if (blockUpdated.blockUpdatedInfo.blockId.isShuffle) { - shuffleBlockUpdates += blockUpdated.blockUpdatedInfo.blockId + shuffleBlockUpdates.add(blockUpdated.blockUpdatedInfo.blockId) } } - - override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { - assert(execToDecommission === executorRemoved.executorId) - isDecommissionedExecutorRemoved = true - } }) // Run a job to create shuffle data @@ -409,12 +412,13 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS ) eventually(timeout(1.minute), interval(10.milliseconds)) { - assert(isDecommissionedExecutorRemoved) + assert(Files.notExists(Paths.get(decommissionedExecutorLocalDir.head))) // Ensure there are shuffle data have been migrated assert(shuffleBlockUpdates.size >= 2) } val shuffleId = shuffleBlockUpdates + .asScala .find(_.isInstanceOf[ShuffleIndexBlockId]) .map(_.asInstanceOf[ShuffleIndexBlockId].shuffleId) .get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1fbc900727c4c..f5fca56e5ef77 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -38,6 +38,8 @@ import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.shuffle.ExternalBlockStoreClient +import org.apache.spark.network.util.{MapConfigProvider, TransportConf} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} @@ -296,6 +298,41 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } + test("Test block location after replication with SHUFFLE_SERVICE_FETCH_RDD_ENABLED enabled") { + val newConf = conf.clone() + newConf.set(SHUFFLE_SERVICE_ENABLED, true) + newConf.set(SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) + newConf.set(Tests.TEST_SKIP_ESS_REGISTER, true) + val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() + val shuffleClient = Some(new ExternalBlockStoreClient( + new TransportConf("shuffle", MapConfigProvider.EMPTY), + null, false, 5000)) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager-2", + new BlockManagerMasterEndpoint(rpcEnv, true, newConf, + new LiveListenerBus(newConf), shuffleClient, blockManagerInfo, mapOutputTracker, + sc.env.shuffleManager, isDriver = true)), + rpcEnv.setupEndpoint("blockmanagerHeartbeat-2", + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), newConf, true) + + val shuffleServicePort = 
newConf.get(SHUFFLE_SERVICE_PORT) + val store1 = makeBlockManager(10000, "host-1") + val store2 = makeBlockManager(10000, "host-2") + assert(master.getPeers(store1.blockManagerId).toSet === Set(store2.blockManagerId)) + + val blockId = RDDBlockId(1, 2) + val message = new Array[Byte](1000) + + // if SHUFFLE_SERVICE_FETCH_RDD_ENABLED is enabled, then shuffle port should be present. + store1.putSingle(blockId, message, StorageLevel.DISK_ONLY) + assert(master.getLocations(blockId).contains( + BlockManagerId("host-1", "localhost", shuffleServicePort, None))) + + // after block is removed, shuffle port should be removed. + store1.removeBlock(blockId, true) + assert(!master.getLocations(blockId).contains( + BlockManagerId("host-1", "localhost", shuffleServicePort, None))) + } + test("block replication - addition and deletion of block managers") { val blockSize = 1000 val storeSize = 10000 diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 17dff20dd993b..9fbe15402c8b3 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,7 +33,7 @@ import scala.reflect.classTag import com.esotericsoftware.kryo.KryoException import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc} -import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when} +import org.mockito.Mockito.{atLeastOnce, doAnswer, mock, never, spy, times, verify, when} import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ @@ -698,7 +698,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe removedFromMemory: Boolean, removedFromDisk: Boolean): Unit = { def assertSizeReported(captor: ArgumentCaptor[Long], expectRemoved: Boolean): Unit = { - assert(captor.getAllValues().size() === 1) + assert(captor.getAllValues().size() >= 1) if (expectRemoved) { assert(captor.getValue() > 0) } else { @@ -708,15 +708,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe val memSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] val diskSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] - verify(master).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), - mc.eq(StorageLevel.NONE), memSizeCaptor.capture(), diskSizeCaptor.capture()) + val storageLevelCaptor = + ArgumentCaptor.forClass(classOf[StorageLevel]).asInstanceOf[ArgumentCaptor[StorageLevel]] + verify(master, atLeastOnce()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), + storageLevelCaptor.capture(), memSizeCaptor.capture(), diskSizeCaptor.capture()) assertSizeReported(memSizeCaptor, removedFromMemory) assertSizeReported(diskSizeCaptor, removedFromDisk) + assert(storageLevelCaptor.getValue.replication == 0) } private def assertUpdateBlockInfoNotReported(store: BlockManager, blockId: BlockId): Unit = { verify(master, never()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), - mc.eq(StorageLevel.NONE), mc.anyInt(), mc.anyInt()) + mc.any[StorageLevel](), mc.anyInt(), mc.anyInt()) } test("reregistration on heart beat") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index 
9ce4acc75ec43..8f15e8cf19412 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -128,7 +128,7 @@ class TopologyAwareBlockReplicationPolicyBehavior extends RandomBlockReplication assert(prioritizedPeers.toSet.size == numReplicas) val priorityPeers = prioritizedPeers.take(2) assert(priorityPeers.forall(p => p.host != blockManager.host)) - if(numReplicas > 1) { + if (numReplicas > 1) { // both these conditions should be satisfied when numReplicas > 1 assert(priorityPeers.exists(p => p.topologyInfo == blockManager.topologyInfo)) assert(priorityPeers.exists(p => p.topologyInfo != blockManager.topologyInfo)) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ca77d2c7b7097..033bd9d244cff 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1943,4 +1943,31 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(err2.getMessage.contains("corrupt at reset")) } + + test("SPARK-43242: Fix throw 'Unexpected type of BlockId' in shuffle corruption diagnose") { + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockBatchId(0, 0, 0, 3) -> createMockManagedBuffer()) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + listener.onBlockFetchSuccess(ShuffleBlockBatchId(0, 0, 0, 3).toString, mockCorruptBuffer()) + } + + val logAppender = new LogAppender("diagnose corruption") + withLogAppender(logAppender) { + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100) + ) + intercept[FetchFailedException](iterator.next()) + verify(transfer, times(2)) + .fetchBlocks(any(), any(), any(), any(), any(), any()) + assert(logAppender.loggingEvents.count( + _.getMessage.getFormattedMessage.contains("Start corruption diagnosis")) === 1) + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("shuffle_0_0_0_3 is corrupted " + + "but corruption diagnosis is skipped due to lack of " + + "shuffle checksum support for ShuffleBlockBatchId"))) + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/JWSFilterSuite.scala b/core/src/test/scala/org/apache/spark/ui/JWSFilterSuite.scala new file mode 100644 index 0000000000000..7338054c158c6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/JWSFilterSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.{Base64, HashMap => JHashMap} + +import scala.jdk.CollectionConverters._ + +import jakarta.servlet.{FilterChain, FilterConfig, ServletContext} +import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse} +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{mock, times, verify, when} + +import org.apache.spark._ + +class JWSFilterSuite extends SparkFunSuite { + // {"alg":"HS256","typ":"JWT"} => eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9, {} => e30 + private val TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.4EKWlOkobpaAPR0J4BE0cPQ-ZD1tRQKLZp1vtE7upPw" + + private val TEST_KEY = Base64.getUrlEncoder.encodeToString( + "Visit https://spark.apache.org to download Apache Spark.".getBytes()) + + test("Should fail when a parameter is missing") { + val filter = new JWSFilter() + val params = new JHashMap[String, String] + val m = intercept[IllegalArgumentException] { + filter.init(new DummyFilterConfig(params)) + }.getMessage() + assert(m.contains("Decode argument cannot be null")) + } + + test("Succeed to initialize") { + val filter = new JWSFilter() + val params = new JHashMap[String, String] + params.put("secretKey", TEST_KEY) + filter.init(new DummyFilterConfig(params)) + } + + test("Should response with SC_FORBIDDEN when it cannot verify JWS") { + val req = mockRequest() + val res = mock(classOf[HttpServletResponse]) + val chain = mock(classOf[FilterChain]) + + val filter = new JWSFilter() + val params = new JHashMap[String, String] + params.put("secretKey", TEST_KEY) + val conf = new DummyFilterConfig(params) + filter.init(conf) + + // 'Authorization' header is missing + filter.doFilter(req, res, chain) + verify(res).sendError(meq(HttpServletResponse.SC_FORBIDDEN), + meq("Authorization header is missing.")) + verify(chain, times(0)).doFilter(any(), any()) + + // The value of Authorization field is not 'Bearer ' style. 
+ when(req.getHeader("Authorization")).thenReturn("Invalid") + filter.doFilter(req, res, chain) + verify(res).sendError(meq(HttpServletResponse.SC_FORBIDDEN), + meq("Malformed Authorization header.")) + verify(chain, times(0)).doFilter(any(), any()) + } + + test("Should succeed on valid JWS") { + val req = mockRequest() + val res = mock(classOf[HttpServletResponse]) + val chain = mock(classOf[FilterChain]) + + val filter = new JWSFilter() + val params = new JHashMap[String, String] + params.put("secretKey", TEST_KEY) + val conf = new DummyFilterConfig(params) + filter.init(conf) + + when(req.getHeader("Authorization")).thenReturn(s"Bearer $TOKEN") + filter.doFilter(req, res, chain) + verify(chain, times(1)).doFilter(any(), any()) + } + + private def mockRequest(params: Map[String, Array[String]] = Map()): HttpServletRequest = { + val req = mock(classOf[HttpServletRequest]) + when(req.getParameterMap()).thenReturn(params.asJava) + req + } + + class DummyFilterConfig (val map: java.util.Map[String, String]) extends FilterConfig { + override def getFilterName: String = "dummy" + + override def getInitParameter(arg0: String): String = map.get(arg0) + + override def getInitParameterNames: java.util.Enumeration[String] = + java.util.Collections.enumeration(map.keySet) + + override def getServletContext: ServletContext = null + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 327715b5fa2b9..18a8b50e1df19 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui -import java.net.URL +import java.net.{URI, URL} import java.util.Locale import scala.io.Source @@ -544,8 +544,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers { withSpark(newSparkContext(killEnabled = true)) { sc => sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5.seconds), interval(50.milliseconds)) { - val url = new URL( - sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0") + val url = new URI( + sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0").toURL // SPARK-6846: should be POST only but YARN AM doesn't proxy POST TestUtils.httpResponseCode(url, "GET") should be (200) TestUtils.httpResponseCode(url, "POST") should be (200) @@ -557,8 +557,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers { withSpark(newSparkContext(killEnabled = true)) { sc => sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5.seconds), interval(50.milliseconds)) { - val url = new URL( - sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") + val url = new URI( + sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0").toURL // SPARK-6846: should be POST only but YARN AM doesn't proxy POST TestUtils.httpResponseCode(url, "GET") should be (200) TestUtils.httpResponseCode(url, "POST") should be (200) @@ -692,8 +692,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers { test("live UI json application list") { withSpark(newSparkContext()) { sc => - val appListRawJson = HistoryServerSuite.getUrl(new URL( - sc.ui.get.webUrl + "/api/v1/applications")) + val appListRawJson = HistoryServerSuite.getUrl(new URI( + sc.ui.get.webUrl + "/api/v1/applications").toURL) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should 
be (1) val attempts = (appListJsonAst.children.head \ "attempts").children @@ -918,6 +918,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers { } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) + new URI(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path).toURL } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 10681f22fa624..1b68ed301fb92 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -251,7 +251,7 @@ class UISuite extends SparkFunSuite { val serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, conf) try { val path = "/test" - val url = new URL(s"http://$localhost:${serverInfo.boundPort}$path/root") + val url = new URI(s"http://$localhost:${serverInfo.boundPort}$path/root").toURL assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_NOT_FOUND) @@ -261,8 +261,10 @@ class UISuite extends SparkFunSuite { // Try a request with bad content in a parameter to make sure the security filter // is being added to new handlers. + // scalastyle:off URLConstructor val badRequest = new URL( s"http://$localhost:${serverInfo.boundPort}$path/root?bypass&invalid<=foo") + // scalastyle:on URLConstructor assert(TestUtils.httpResponseCode(badRequest) === HttpServletResponse.SC_OK) assert(servlet.lastRequest.getParameter("invalid<") === null) assert(servlet.lastRequest.getParameter("invalid<") !== null) @@ -283,7 +285,7 @@ class UISuite extends SparkFunSuite { val (_, ctx) = newContext("/ctx1") serverInfo.addHandler(ctx, securityMgr) - TestUtils.withHttpConnection(new URL(s"$serverAddr/ctx%281%29?a%5B0%5D=b")) { conn => + TestUtils.withHttpConnection(new URI(s"$serverAddr/ctx%281%29?a%5B0%5D=b").toURL) { conn => assert(conn.getResponseCode() === HttpServletResponse.SC_FOUND) val location = Option(conn.getHeaderFields().get("Location")) .map(_.get(0)).orNull @@ -319,7 +321,7 @@ class UISuite extends SparkFunSuite { s"$scheme://$localhost:$port/test1/root", s"$scheme://$localhost:$port/test2/root") urls.foreach { url => - val rc = TestUtils.httpResponseCode(new URL(url)) + val rc = TestUtils.httpResponseCode(new URI(url).toURL) assert(rc === expected, s"Unexpected status $rc for $url") } } @@ -366,7 +368,7 @@ class UISuite extends SparkFunSuite { serverInfo.addHandler(redirect, securityMgr) // Test Jetty's built-in redirect to add the trailing slash to the context path. - TestUtils.withHttpConnection(new URL(s"$serverAddr/ctx1")) { conn => + TestUtils.withHttpConnection(new URI(s"$serverAddr/ctx1").toURL) { conn => assert(conn.getResponseCode() === HttpServletResponse.SC_FOUND) val location = Option(conn.getHeaderFields().get("Location")) .map(_.get(0)).orNull @@ -376,7 +378,7 @@ class UISuite extends SparkFunSuite { // Test with a URL handled by the added redirect handler, and also including a path prefix. val headers = Seq("X-Forwarded-Context" -> "/prefix") TestUtils.withHttpConnection( - new URL(s"$serverAddr/src/"), + new URI(s"$serverAddr/src/").toURL, headers = headers) { conn => assert(conn.getResponseCode() === HttpServletResponse.SC_FOUND) val location = Option(conn.getHeaderFields().get("Location")) @@ -387,7 +389,7 @@ class UISuite extends SparkFunSuite { // Not really used by Spark, but test with a relative redirect. 
val relative = JettyUtils.createRedirectHandler("/rel", "root") serverInfo.addHandler(relative, securityMgr) - TestUtils.withHttpConnection(new URL(s"$serverAddr/rel/")) { conn => + TestUtils.withHttpConnection(new URI(s"$serverAddr/rel/").toURL) { conn => assert(conn.getResponseCode() === HttpServletResponse.SC_FOUND) val location = Option(conn.getHeaderFields().get("Location")) .map(_.get(0)).orNull @@ -410,12 +412,12 @@ class UISuite extends SparkFunSuite { serverInfo.addHandler(ctx, securityMgr) val urlStr = s"http://$localhost:${serverInfo.boundPort}/ctx" - assert(TestUtils.httpResponseCode(new URL(urlStr + "/")) === HttpServletResponse.SC_OK) + assert(TestUtils.httpResponseCode(new URI(urlStr + "/").toURL) === HttpServletResponse.SC_OK) // In the case of trailing slash, // 302 should be return and the redirect URL shouuld be part of the header. - assert(TestUtils.redirectUrl(new URL(urlStr)) === proxyRoot + "/ctx/"); - assert(TestUtils.httpResponseCode(new URL(urlStr)) === HttpServletResponse.SC_FOUND) + assert(TestUtils.redirectUrl(new URI(urlStr).toURL) === proxyRoot + "/ctx/"); + assert(TestUtils.httpResponseCode(new URI(urlStr).toURL) === HttpServletResponse.SC_FOUND) } finally { stopServer(serverInfo) } @@ -452,10 +454,11 @@ class UISuite extends SparkFunSuite { val sparkUI = SparkUI.create(Some(sc), sc.statusStore, sc.conf, sc.env.securityManager, sc.appName, "", sc.startTime) sparkUI.bind() - assert(TestUtils.httpResponseMessage(new URL(sparkUI.webUrl + "/jobs")) + val url = new URI(sparkUI.webUrl + "/jobs").toURL + assert(TestUtils.httpResponseMessage(url) === "Spark is starting up. Please wait a while until it's ready.") sparkUI.attachAllHandlers() - assert(TestUtils.httpResponseMessage(new URL(sparkUI.webUrl + "/jobs")).contains(sc.appName)) + assert(TestUtils.httpResponseMessage(url).contains(sc.appName)) sparkUI.stop() } } diff --git a/core/src/test/scala/org/apache/spark/util/DirectByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/DirectByteBufferOutputStreamSuite.scala new file mode 100644 index 0000000000000..7fd9d1fc05c9c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/DirectByteBufferOutputStreamSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.{SparkException, SparkFunSuite} + +class DirectByteBufferOutputStreamSuite extends SparkFunSuite { + test("use after close") { + val o = new DirectByteBufferOutputStream() + val size = 1000 + o.write(new Array[Byte](size), 0, size) + val b = o.toByteBuffer + o.close() + + // Using `o` after close should throw an exception rather than crashing. 
+ assertThrows[SparkException] { o.write(123) } + assertThrows[SparkException] { o.write(new Array[Byte](size), 0, size) } + assertThrows[SparkException] { o.reset() } + assertThrows[SparkException] { o.size() } + assertThrows[SparkException] { o.toByteBuffer } + + // Using `b` after `o` is closed may crash. + // val arr = new Array[Byte](size) + // b.get(arr) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 35ef0587b9b4c..4497ea1b2b798 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -54,11 +54,11 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter { val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) // The `header` should not be covered val header = "Add header" - Files.write(header, testFile, StandardCharsets.UTF_8) + Files.asCharSink(testFile, StandardCharsets.UTF_8).write(header) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, StandardCharsets.UTF_8) === header + testString) + assert(Files.asCharSource(testFile, StandardCharsets.UTF_8).read() === header + testString) } test("SPARK-35027: basic file appender - close stream") { @@ -392,7 +392,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter { IOUtils.closeQuietly(inputStream) } } else { - Files.toString(file, StandardCharsets.UTF_8) + Files.asCharSource(file, StandardCharsets.UTF_8).read() } }.mkString("") assert(allText === expectedText) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3eae1b3278e74..30c9693e6dee3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.internal.config._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.{DeterministicLevel, RDDOperationScope} import org.apache.spark.resource._ @@ -276,7 +277,8 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo backward compatibility (details, accumulables)") { val info = makeStageInfo(1, 2, 3, 4L, 5L) - val newJson = toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true)) + val newJson = toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true)) // Fields added after 1.0.0. assert(info.details.nonEmpty) @@ -294,7 +296,8 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo resourceProfileId") { val info = makeStageInfo(1, 2, 3, 4L, 5L, 5) - val json = toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true)) + val json = toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true)) // Fields added after 1.0.0. 
assert(info.details.nonEmpty) @@ -471,7 +474,7 @@ class JsonProtocolSuite extends SparkFunSuite { stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) - val oldEvent = toJsonString(JsonProtocol.jobStartToJson(jobStart, _)).removeField("Stage Infos") + val oldEvent = sparkEventToJsonString(jobStart).removeField("Stage Infos") val expectedJobStart = SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) @@ -483,8 +486,7 @@ class JsonProtocolSuite extends SparkFunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) - val oldStartEvent = toJsonString(JsonProtocol.jobStartToJson(jobStart, _)) - .removeField("Submission Time") + val oldStartEvent = sparkEventToJsonString(jobStart).removeField("Submission Time") val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent)) @@ -519,8 +521,9 @@ class JsonProtocolSuite extends SparkFunSuite { val stageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq(1, 2, 3), "details", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val oldStageInfo = - toJsonString(JsonProtocol.stageInfoToJson(stageInfo, _, includeAccumulables = true)) - .removeField("Parent IDs") + toJsonString( + JsonProtocol.stageInfoToJson(stageInfo, _, defaultOptions, includeAccumulables = true) + ).removeField("Parent IDs") val expectedStageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq.empty, "details", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) @@ -785,6 +788,87 @@ class JsonProtocolSuite extends SparkFunSuite { assert(JsonProtocol.sparkEventFromJson(unknownFieldsJson) === expected) } + test("SPARK-42204: spark.eventLog.includeTaskMetricsAccumulators config") { + val includeConf = new JsonProtocolOptions( + new SparkConf().set(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS, true)) + val excludeConf = new JsonProtocolOptions( + new SparkConf().set(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS, false)) + + val taskMetricsAccumulables = TaskMetrics + .empty + .nameToAccums + .view + .filterKeys(!JsonProtocol.accumulableExcludeList.contains(_)) + .values + .map(_.toInfo(Some(1), None)) + .toSeq + + val taskInfoWithTaskMetricsAccums = makeTaskInfo(222L, 333, 1, 333, 444L, false) + taskInfoWithTaskMetricsAccums.setAccumulables(taskMetricsAccumulables) + val taskInfoWithoutTaskMetricsAccums = makeTaskInfo(222L, 333, 1, 333, 444L, false) + taskInfoWithoutTaskMetricsAccums.setAccumulables(Seq.empty) + + val stageInfoWithTaskMetricsAccums = makeStageInfo(100, 200, 300, 400L, 500L) + stageInfoWithTaskMetricsAccums.accumulables.clear() + stageInfoWithTaskMetricsAccums.accumulables ++= taskMetricsAccumulables.map(x => (x.id, x)) + val stageInfoWithoutTaskMetricsAccums = makeStageInfo(100, 200, 300, 400L, 500L) + stageInfoWithoutTaskMetricsAccums.accumulables.clear() + + // Test events which should be impacted by the config. 
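For reference, the option exercised by this test maps to the user-facing key spelled out in the test title; a minimal sketch, assuming only that key string plus the standard spark.eventLog.enabled setting, of how a job could opt out of logging task-metric accumulators:

    import org.apache.spark.SparkConf

    // Sketch only: disable per-task-metric accumulators in logged events.
    // The key name is taken from the test title above; other values are illustrative.
    val conf = new SparkConf()
      .set("spark.eventLog.enabled", "true")
      .set("spark.eventLog.includeTaskMetricsAccumulators", "false")
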
+ + // TaskEnd + { + val originalEvent = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, + taskInfoWithTaskMetricsAccums, + new ExecutorMetrics(Array(12L, 23L, 45L, 67L, 78L, 89L, + 90L, 123L, 456L, 789L, 40L, 20L, 20L, 10L, 20L, 10L, 301L)), + makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, 0, + hasHadoopInput = false, hasOutput = false)) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(taskInfo = taskInfoWithoutTaskMetricsAccums) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // StageCompleted + { + val originalEvent = SparkListenerStageCompleted(stageInfoWithTaskMetricsAccums) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(stageInfo = stageInfoWithoutTaskMetricsAccums) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // JobStart + { + val originalEvent = + SparkListenerJobStart(1, 1, Seq(stageInfoWithTaskMetricsAccums), properties) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(stageInfos = Seq(stageInfoWithoutTaskMetricsAccums)) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // ExecutorMetricsUpdate events should be unaffected by the config: + val executorMetricsUpdate = + SparkListenerExecutorMetricsUpdate("0", Seq((0, 0, 0, taskMetricsAccumulables))) + assert( + sparkEventToJsonString(executorMetricsUpdate, includeConf) === + sparkEventToJsonString(executorMetricsUpdate, excludeConf)) + assertEquals( + JsonProtocol.sparkEventFromJson(sparkEventToJsonString(executorMetricsUpdate, includeConf)), + executorMetricsUpdate) + } + test("SPARK-42403: properly handle null string values") { // Null string values can appear in a few different event types, // so we test multiple known cases here: @@ -966,7 +1050,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def testStageInfo(info: StageInfo): Unit = { val newInfo = JsonProtocol.stageInfoFromJson( - toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true))) + toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true))) assertEquals(info, newInfo) } @@ -990,7 +1075,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def testTaskInfo(info: TaskInfo): Unit = { val newInfo = JsonProtocol.taskInfoFromJson( - toJsonString(JsonProtocol.taskInfoToJson(info, _, includeAccumulables = true))) + toJsonString( + JsonProtocol.taskInfoToJson(info, _, defaultOptions, includeAccumulables = true))) assertEquals(info, newInfo) } @@ -1453,6 +1539,8 @@ private[spark] object JsonProtocolSuite extends Assertions { t.setExecutorRunTime(b) t.setExecutorCpuTime(b) t.setPeakExecutionMemory(c) + t.setPeakOnHeapExecutionMemory(c) + t.setPeakOffHeapExecutionMemory(c) t.setResultSize(c) t.setJvmGCTime(d) t.setResultSerializationTime(a + b) @@ -1731,6 +1819,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Executor Run Time": 400, | "Executor CPU Time": 400, | "Peak Execution Memory": 500, + | "Peak On Heap Execution Memory": 500, + | "Peak Off Heap Execution Memory": 500, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ 
-1872,6 +1962,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Executor Run Time": 400, | "Executor CPU Time": 400, | "Peak Execution Memory": 500, + | "Peak On Heap Execution Memory": 500, + | "Peak Off Heap Execution Memory": 500, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -2013,6 +2105,8 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Executor Run Time": 400, | "Executor CPU Time": 400, | "Peak Execution Memory": 500, + | "Peak On Heap Execution Memory": 500, + | "Peak Off Heap Execution Memory": 500, | "Result Size": 500, | "JVM GC Time": 600, | "Result Serialization Time": 700, @@ -2683,6 +2777,20 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 10, + | "Name": "$PEAK_ON_HEAP_EXECUTION_MEMORY", + | "Update": 500, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 11, + | "Name": "$PEAK_OFF_HEAP_EXECUTION_MEMORY", + | "Update": 500, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 12, | "Name": "$UPDATED_BLOCK_STATUSES", | "Update": [ | { @@ -2704,175 +2812,175 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Count Failed Values": true | }, | { - | "ID": 11, + | "ID": 13, | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 12, + | "ID": 14, | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 13, + | "ID": 15, | "Name": "${shuffleRead.REMOTE_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 14, + | "ID": 16, | "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 15, + | "ID": 17, | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 16, + | "ID": 18, | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 17, + | "ID": 19, | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 18, + | "ID": 20, | "Name": "${shuffleRead.CORRUPT_MERGED_BLOCK_CHUNKS}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 19, + | "ID": 21, | "Name": "${shuffleRead.MERGED_FETCH_FALLBACK_COUNT}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID" : 20, + | "ID" : 22, | "Name" : "${shuffleRead.REMOTE_MERGED_BLOCKS_FETCHED}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID" : 21, + | "ID" : 23, | "Name" : "${shuffleRead.LOCAL_MERGED_BLOCKS_FETCHED}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID" : 22, + | "ID" : 24, | "Name" : "${shuffleRead.REMOTE_MERGED_CHUNKS_FETCHED}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID" : 23, + | "ID" : 25, | "Name" : "${shuffleRead.LOCAL_MERGED_CHUNKS_FETCHED}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID" : 24, + | "ID" : 26, | "Name" : "${shuffleRead.REMOTE_MERGED_BYTES_READ}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID" : 25, + | "ID" : 27, | "Name" : "${shuffleRead.LOCAL_MERGED_BYTES_READ}", | "Update" : 0, | "Internal" : true, | "Count 
Failed Values" : true | }, | { - | "ID" : 26, + | "ID" : 28, | "Name" : "${shuffleRead.REMOTE_REQS_DURATION}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID" : 27, + | "ID" : 29, | "Name" : "${shuffleRead.REMOTE_MERGED_REQS_DURATION}", | "Update" : 0, | "Internal" : true, | "Count Failed Values" : true | }, | { - | "ID": 28, + | "ID": 30, | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 29, + | "ID": 31, | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 30, + | "ID": 32, | "Name": "${shuffleWrite.WRITE_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 31, + | "ID": 33, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 32, + | "ID": 34, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 33, + | "ID": 35, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 34, + | "ID": 36, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 35, + | "ID": 37, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala b/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala new file mode 100644 index 0000000000000..79c07f8fbfead --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class LazyTrySuite extends SparkFunSuite{ + test("LazyTry should initialize only once") { + var count = 0 + val lazyVal = LazyTry { + count += 1 + count + } + assert(count == 0) + assert(lazyVal.get == 1) + assert(count == 1) + assert(lazyVal.get == 1) + assert(count == 1) + } + + test("LazyTry should re-throw exceptions") { + val lazyVal = LazyTry { + throw new RuntimeException("test") + } + intercept[RuntimeException] { + lazyVal.get + } + intercept[RuntimeException] { + lazyVal.get + } + } + + test("LazyTry should re-throw exceptions with current caller stack-trace") { + val fileName = Thread.currentThread().getStackTrace()(1).getFileName + val lineNo = Thread.currentThread().getStackTrace()(1).getLineNumber + val lazyVal = LazyTry { + throw new RuntimeException("test") + } + + val e1 = intercept[RuntimeException] { + lazyVal.get // lineNo + 6 + } + assert(e1.getStackTrace + .exists(elem => elem.getFileName == fileName && elem.getLineNumber == lineNo + 6)) + + val e2 = intercept[RuntimeException] { + lazyVal.get // lineNo + 12 + } + assert(e2.getStackTrace + .exists(elem => elem.getFileName == fileName && elem.getLineNumber == lineNo + 12)) + } + + test("LazyTry does not lock containing object") { + class LazyContainer() { + @volatile var aSet = 0 + + val a: LazyTry[Int] = LazyTry { + aSet = 1 + aSet + } + + val b: LazyTry[Int] = LazyTry { + val t = new Thread(new Runnable { + override def run(): Unit = { + assert(a.get == 1) + } + }) + t.start() + t.join() + aSet + } + } + val container = new LazyContainer() + // Nothing is lazy initialized yet + assert(container.aSet == 0) + // This will not deadlock, thread t will initialize a, and update aSet + assert(container.b.get == 1) + assert(container.aSet == 1) + } + + // Scala lazy val tests are added to test for potential changes in the semantics of scala lazy val + + test("Scala lazy val initializing multiple times on error") { + class LazyValError() { + var counter = 0 + lazy val a = { + counter += 1 + throw new RuntimeException("test") + } + } + val lazyValError = new LazyValError() + intercept[RuntimeException] { + lazyValError.a + } + assert(lazyValError.counter == 1) + intercept[RuntimeException] { + lazyValError.a + } + assert(lazyValError.counter == 2) + } + + test("Scala lazy val locking containing object and deadlocking") { + // Note: this will change in scala 3, with different lazy vals not deadlocking with each other. + // https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html + class LazyValContainer() { + @volatile var aSet = 0 + @volatile var t: Thread = _ + + lazy val a = { + aSet = 1 + aSet + } + + lazy val b = { + t = new Thread(new Runnable { + override def run(): Unit = { + assert(a == 1) + } + }) + t.start() + t.join(1000) + aSet + } + } + val container = new LazyValContainer() + // Nothing is lazy initialized yet + assert(container.aSet == 0) + // This will deadlock, because b will take monitor on LazyValContainer, and then thread t + // will wait on that monitor, not able to initialize a. + // b will therefore see aSet == 0. + assert(container.b == 0) + // However, after b finishes initializing, the monitor will be released, and then thread t + // will finish initializing a, and set aSet to 1. 
+ container.t.join() + assert(container.aSet == 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala b/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala index 4843409661554..20b3fa7e7d753 100644 --- a/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala @@ -39,7 +39,15 @@ class SparkUncaughtExceptionHandlerSuite extends SparkFunSuite { (ThrowableTypes.SparkFatalRuntimeException, true, SparkExitCode.UNCAUGHT_EXCEPTION), (ThrowableTypes.SparkFatalRuntimeException, false, 0), (ThrowableTypes.SparkFatalOutOfMemoryError, true, SparkExitCode.OOM), - (ThrowableTypes.SparkFatalOutOfMemoryError, false, SparkExitCode.OOM) + (ThrowableTypes.SparkFatalOutOfMemoryError, false, SparkExitCode.OOM), + (ThrowableTypes.NestedOOMError, true, SparkExitCode.OOM), + (ThrowableTypes.NestedOOMError, false, SparkExitCode.OOM), + (ThrowableTypes.NestedSparkFatalException, true, SparkExitCode.OOM), + (ThrowableTypes.NestedSparkFatalException, false, SparkExitCode.OOM), + (ThrowableTypes.NonFatalNestedErrors, true, SparkExitCode.UNCAUGHT_EXCEPTION), + (ThrowableTypes.NonFatalNestedErrors, false, 0), + (ThrowableTypes.DeepNestedOOMError, true, SparkExitCode.UNCAUGHT_EXCEPTION), + (ThrowableTypes.DeepNestedOOMError, false, 0) ).foreach { case (throwable: ThrowableTypes.ThrowableTypesVal, exitOnUncaughtException: Boolean, expectedExitCode) => @@ -74,6 +82,46 @@ object ThrowableTypes extends Enumeration { val SparkFatalOutOfMemoryError = ThrowableTypesVal("SparkFatalException(OutOfMemoryError)", new SparkFatalException(new OutOfMemoryError)) + // SPARK-50034: If there is a fatal error in the cause chain, + // we should also identify that fatal error and exit with the + // correct exit code. 
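To make the comment above concrete, here is a minimal sketch of a bounded cause-chain walk; the helper name and signature are assumptions for illustration, not the actual SparkUncaughtExceptionHandler code, and the depth bound stands in for the killOnFatalErrorDepth idea referenced in the fixtures below:

    import scala.annotation.tailrec

    // Hypothetical helper: true if an OutOfMemoryError appears within maxDepth
    // links of the cause chain; an OOM nested deeper than that is not reported.
    @tailrec
    def oomWithinDepth(t: Throwable, maxDepth: Int): Boolean =
      if (t == null || maxDepth <= 0) false
      else if (t.isInstanceOf[OutOfMemoryError]) true
      else oomWithinDepth(t.getCause, maxDepth - 1)
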
+ val NestedOOMError = ThrowableTypesVal( + "NestedFatalError", + new RuntimeException("Nonfatal Level 1", + new RuntimeException("Nonfatal Level 2", + new RuntimeException("Nonfatal Level 3", + new OutOfMemoryError()))) + ) + + val NestedSparkFatalException = ThrowableTypesVal( + "NestedSparkFatalException", + new RuntimeException("Nonfatal Level 1", + new RuntimeException("Nonfatal Level 2", + new SparkFatalException(new OutOfMemoryError()))) + ) + + // Nested exception with non-fatal errors only + val NonFatalNestedErrors = ThrowableTypesVal( + "NonFatalNestedErrors", + new RuntimeException("Nonfatal Level 1", + new RuntimeException("Nonfatal Level 2", + new RuntimeException("Nonfatal Level 3", + new RuntimeException("Nonfatal Level 4"))) + ) + ) + + // Should not report as OOM when its depth is greater than killOnFatalErrorDepth + val DeepNestedOOMError = ThrowableTypesVal( + "DeepNestedOOMError", + new RuntimeException("Nonfatal Level 1", + new RuntimeException("Nonfatal Level 2", + new RuntimeException("Nonfatal Level 3", + new RuntimeException("Nonfatal Level 4", + new RuntimeException("Nonfatal Level 5", + new OutOfMemoryError())))) + ) + ) + // returns the actual Throwable by its name def getThrowableByName(name: String): Throwable = { super.withName(name).asInstanceOf[ThrowableTypesVal].t diff --git a/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala b/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala new file mode 100644 index 0000000000000..c0754ee063d67 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util + +import java.io.{ByteArrayOutputStream, ObjectOutputStream} + +import org.apache.spark.SparkFunSuite + +class TransientLazySuite extends SparkFunSuite { + + test("TransientLazy val works") { + var test: Option[Object] = None + + val lazyval = new TransientLazy({ + test = Some(new Object()) + test + }) + + // Ensure no initialization happened before the lazy value was dereferenced + assert(test.isEmpty) + + // Ensure the first invocation creates a new object + assert(lazyval() == test && test.isDefined) + + // Ensure the subsequent invocation serves the same object + assert(lazyval() == test && test.isDefined) + } + + test("TransientLazy val is serializable") { + val lazyval = new TransientLazy({ + new Object() + }) + + // Ensure serializable before the dereference + val oos = new ObjectOutputStream(new ByteArrayOutputStream()) + oos.writeObject(lazyval) + + val dereferenced = lazyval() + + // Ensure serializable after the dereference + val oos2 = new ObjectOutputStream(new ByteArrayOutputStream()) + oos2.writeObject(lazyval) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4fe6fcf17f49f..a6e3345fc600c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -28,7 +28,7 @@ import java.util.concurrent.TimeUnit import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer -import scala.util.Random +import scala.util.{Random, Try} import com.google.common.io.Files import org.apache.commons.io.IOUtils @@ -735,8 +735,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { withTempDir { tmpDir => val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) System.setProperty("spark.test.fileNameLoadB", "2") - Files.write("spark.test.fileNameLoadA true\n" + - "spark.test.fileNameLoadB 1\n", outFile, UTF_8) + Files.asCharSink(outFile, UTF_8).write("spark.test.fileNameLoadA true\n" + + "spark.test.fileNameLoadB 1\n") val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) properties .filter { case (k, v) => k.startsWith("spark.")} @@ -765,7 +765,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") - Files.write("some text", sourceFile, UTF_8) + Files.asCharSink(sourceFile, UTF_8).write("some text") val path = if (Utils.isWindows) { @@ -1523,6 +1523,116 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { conf.set(SERIALIZER, "org.apache.spark.serializer.JavaSerializer") assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false) } + + + private def throwException(): String = { + throw new Exception("test") + } + + private def callDoTry(): Try[String] = { + Utils.doTryWithCallerStacktrace { + throwException() + } + } + + private def callGetTry(t: Try[String]): String = { + Utils.getTryWithCallerStacktrace(t) + } + + private def callGetTryAgain(t: Try[String]): String = { + Utils.getTryWithCallerStacktrace(t) + } + + test("doTryWithCallerStacktrace and getTryWithCallerStacktrace") { + val t = callDoTry() + + val e1 = intercept[Exception] { + callGetTry(t) + } + // Uncomment for manual inspection + // e1.printStackTrace() + // Example: + // java.lang.Exception: test + // at 
org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1639) + // at org.apache.spark.util.UtilsSuite.callGetTry(UtilsSuite.scala:1650) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1661) + // <- callGetTry is seen as calling getTryWithCallerStacktrace + + val st1 = e1.getStackTrace + // throwException should be on the stack trace + assert(st1.exists(_.getMethodName == "throwException")) + // callDoTry shouldn't be on the stack trace, but callGetTry should be. + assert(!st1.exists(_.getMethodName == "callDoTry")) + assert(st1.exists(_.getMethodName == "callGetTry")) + + // The original stack trace with callDoTry should be in the suppressed exceptions. + // Example: + // scalastyle:off line.size.limit + // Suppressed: java.lang.Exception: Full stacktrace of original doTryWithCallerStacktrace caller + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.UtilsSuite.callDoTry(UtilsSuite.scala:1645) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1658) + // ... 56 more + // scalastyle:on line.size.limit + val origSt = e1.getSuppressed.find( + _.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE) + assert(origSt.isDefined) + assert(origSt.get.getStackTrace.exists(_.getMethodName == "throwException")) + assert(origSt.get.getStackTrace.exists(_.getMethodName == "callDoTry")) + + // The stack trace under Try should be in the suppressed exceptions. + // Example: + // Suppressed: java.lang.Exception: Stacktrace under doTryWithCallerStacktrace + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala: 1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala: 1645) + // at scala.util.Try$.apply(Try.scala: 213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala: 1586) + val trySt = e1.getSuppressed.find( + _.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE) + assert(trySt.isDefined) + // calls under callDoTry should be present. + assert(trySt.get.getStackTrace.exists(_.getMethodName == "throwException")) + // callDoTry should be removed. 
+ assert(!trySt.get.getStackTrace.exists(_.getMethodName == "callDoTry")) + + val e2 = intercept[Exception] { + callGetTryAgain(t) + } + // Uncomment for manual inspection + // e2.printStackTrace() + // Example: + // java.lang.Exception: test + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1639) + // at org.apache.spark.util.UtilsSuite.callGetTryAgain(UtilsSuite.scala:1654) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1711) + // <- callGetTryAgain is seen as calling getTryWithCallerStacktrace + + val st2 = e2.getStackTrace + // throwException should be on the stack trace + assert(st2.exists(_.getMethodName == "throwException")) + // callDoTry shouldn't be on the stack trace, but callGetTryAgain should be. + assert(!st2.exists(_.getMethodName == "callDoTry")) + assert(st2.exists(_.getMethodName == "callGetTryAgain")) + // callGetTry that we called before shouldn't be on the stack trace. + assert(!st2.exists(_.getMethodName == "callGetTry")) + + // Unfortunately, this utility is not able to clone the exception, but modifies it in place, + // so now e1 is also pointing to "callGetTryAgain" instead of "callGetTry". + val st1Again = e1.getStackTrace + assert(st1Again.exists(_.getMethodName == "callGetTryAgain")) + assert(!st1Again.exists(_.getMethodName == "callGetTry")) + } } private class SimpleExtension diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 8bad50951a78f..d8c9196293950 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -31,12 +31,13 @@ graphlib-dot.min.js sorttable.js vis-timeline-graph2d.min.js vis-timeline-graph2d.min.css -dataTables.bootstrap4.*.min.css -dataTables.bootstrap4.*.min.js +dataTables.bootstrap4.min.css +dataTables.bootstrap4.min.js dataTables.rowsGroup.js jquery.blockUI.min.js jquery.cookies.2.2.0.min.js -jquery.dataTables.*.min.js +jquery.dataTables.min.css +jquery.dataTables.min.js jquery.mustache.js .*\.avsc .*\.txt @@ -125,9 +126,6 @@ exported_table/* node_modules spark-events-broken/* SqlBaseLexer.tokens -# Spark Connect related files with custom licence -any.proto -empty.proto .*\.explain .*\.proto.bin LimitedInputStream.java @@ -138,4 +136,6 @@ people.xml ui-test/package.json ui-test/package-lock.json core/src/main/resources/org/apache/spark/ui/static/package.json +testCommitLog .*\.har +.nojekyll diff --git a/dev/connect-check-protos.py b/dev/check-protos.py similarity index 73% rename from dev/connect-check-protos.py rename to dev/check-protos.py index 12a499caf866c..bfca8b27be21c 100755 --- a/dev/connect-check-protos.py +++ b/dev/check-protos.py @@ -18,7 +18,7 @@ # # Utility for checking whether generated codes in PySpark are out of sync. 
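Returning to the UtilsSuite test above: doTryWithCallerStacktrace and getTryWithCallerStacktrace combine two stack traces so that a Failure produced in one place reports the frames of whichever caller eventually unwraps it, while the original frames survive as suppressed exceptions. The sketch below shows only the simpler half of that idea, using illustrative names (CallerStacktraceSketch, doTry, get); the real Utils helpers additionally rewrite the exception's visible stack trace in place, as the assertions above document.

```scala
import scala.util.{Failure, Success, Try}

// Illustrative sketch only; Spark's Utils.doTryWithCallerStacktrace and
// Utils.getTryWithCallerStacktrace do more than this.
object CallerStacktraceSketch {
  // Run the body and capture any failure, including the stack trace of the
  // code that actually threw.
  def doTry[T](body: => T): Try[T] = Try(body)

  // When the Try is finally unwrapped, record the unwrapping caller's frames
  // as a suppressed exception so both contexts show up in logs.
  def get[T](t: Try[T]): T = t match {
    case Success(v) => v
    case Failure(e) =>
      e.addSuppressed(new Exception("Stacktrace of the caller that unwrapped the Try"))
      throw e
  }

  def main(args: Array[String]): Unit = {
    val result = doTry[String] { throw new IllegalStateException("boom") }
    try {
      get(result)
    } catch {
      // Prints the original "boom" trace plus the suppressed caller trace.
      case e: IllegalStateException => e.printStackTrace()
    }
  }
}
```

Running the sketch prints the original failure with the caller context attached as a suppressed exception; the production helpers go further and splice the caller's frames into getStackTrace itself, which is why the test observes e1's trace being mutated again after the second unwrap.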
-# usage: ./dev/connect-check-protos.py +# usage: ./dev/check-protos.py import os import sys @@ -43,12 +43,12 @@ def run_cmd(cmd): return subprocess.check_output(cmd.split(" ")).decode("utf-8") -def check_connect_protos(): - print("Start checking the generated codes in pyspark-connect.") - with tempfile.TemporaryDirectory(prefix="check_connect_protos") as tmp: - run_cmd(f"{SPARK_HOME}/dev/connect-gen-protos.sh {tmp}") +def check_protos(module_name, cmp_path, proto_path): + print(f"Start checking the generated codes in pyspark-${module_name}.") + with tempfile.TemporaryDirectory(prefix=f"check_${module_name}__protos") as tmp: + run_cmd(f"{SPARK_HOME}/dev/gen-protos.sh {module_name} {tmp}") result = filecmp.dircmp( - f"{SPARK_HOME}/python/pyspark/sql/connect/proto/", + f"{SPARK_HOME}/{cmp_path}", tmp, ignore=["__init__.py", "__pycache__"], ) @@ -71,14 +71,17 @@ def check_connect_protos(): success = False if success: - print("Finish checking the generated codes in pyspark-connect: SUCCESS") + print(f"Finish checking the generated codes in pyspark-${module_name}: SUCCESS") else: fail( "Generated files for pyspark-connect are out of sync! " - "If you have touched files under connect/common/src/main/protobuf/, " - "please run ./dev/connect-gen-protos.sh. " + f"If you have touched files under ${proto_path}, " + f"please run ./dev/${module_name}-gen-protos.sh. " "If you haven't touched any file above, please rebase your PR against main branch." ) -check_connect_protos() +check_protos( + "connect", "python/pyspark/sql/connect/proto/", "sql/connect/common/src/main/protobuf/" +) +check_protos("streaming", "python/pyspark/sql/streaming/proto/", "sql/core/src/main/protobuf/") diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 677381704427c..9925ae406dbd9 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -68,4 +68,6 @@ files="src/main/java/org/apache/spark/network/util/LimitedInputStream.java" /> + diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index c4023a84ee3cf..4285028109419 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -206,6 +206,10 @@ + + + + diff --git a/dev/connect-gen-protos.sh b/dev/connect-gen-protos.sh index 546266c91f88a..8ed323cc42599 100755 --- a/dev/connect-gen-protos.sh +++ b/dev/connect-gen-protos.sh @@ -24,80 +24,4 @@ if [[ $# -gt 1 ]]; then exit -1 fi - -SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" -cd "$SPARK_HOME" - - -OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/connect/proto/ -if [[ $# -eq 1 ]]; then - rm -Rf $1 - mkdir -p $1 - OUTPUT_PATH=$1 -fi - -pushd connect/common/src/main - -LICENSE=$(cat <<'EOF' -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -EOF) -echo "$LICENSE" > /tmp/tmp_licence - - -# Delete the old generated protobuf files. 
-rm -Rf gen - -# Now, regenerate the new files -buf generate --debug -vvv - -# We need to edit the generate python files to account for the actual package location and not -# the one generated by proto. -for f in `find gen/proto/python -name "*.py*"`; do - # First fix the imports. - if [[ $f == *_pb2.py || $f == *_pb2_grpc.py ]]; then - sed -e 's/from spark.connect import/from pyspark.sql.connect.proto import/g' $f > $f.tmp - mv $f.tmp $f - # Now fix the module name in the serialized descriptor. - sed -e "s/DESCRIPTOR, 'spark.connect/DESCRIPTOR, 'pyspark.sql.connect.proto/g" $f > $f.tmp - mv $f.tmp $f - elif [[ $f == *.pyi ]]; then - sed -e 's/import spark.connect./import pyspark.sql.connect.proto./g' -e 's/spark.connect./pyspark.sql.connect.proto./g' -e '/ *@typing_extensions\.final/d' $f > $f.tmp - mv $f.tmp $f - fi - - # Prepend the Apache licence header to the files. - cp $f $f.bak - cat /tmp/tmp_licence $f.bak > $f - - LC=$(wc -l < $f) - echo $LC - if [[ $f == *_grpc.py && $LC -eq 20 ]]; then - rm $f - fi - rm $f.bak -done - -black --config $SPARK_HOME/dev/pyproject.toml gen/proto/python - -# Last step copy the result files to the destination module. -for f in `find gen/proto/python -name "*.py*"`; do - cp $f $OUTPUT_PATH -done - -# Clean up everything. -rm -Rf gen +./dev/gen-protos.sh connect "$@" diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 19589b951a6e1..8654c8317ae49 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -536,7 +536,7 @@ if [[ "$1" == "publish-release" ]]; then file_short=$(echo $file | sed -e "s/\.\///") dest_url="$nexus_upload/org/apache/spark/$file_short" echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + curl --retry 3 --retry-all-errors -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url done echo "Closing nexus staging repository" diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index ca9e10bebfc53..8a9890bf68dde 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -17,6 +17,11 @@ # Image for building Spark releases. Based on Ubuntu 22.04. 
FROM ubuntu:jammy-20240227 +LABEL org.opencontainers.image.authors="Apache Spark project " +LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.ref.name="Apache Spark Release Manager Image" +# Overwrite this label to avoid exposing the underlying Ubuntu OS version label +LABEL org.opencontainers.image.version="" ENV FULL_REFRESH_DATE 20240318 @@ -87,17 +92,17 @@ ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library RUN add-apt-repository ppa:pypy/ppa -RUN mkdir -p /usr/local/pypy/pypy3.9 && \ - curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 +RUN mkdir -p /usr/local/pypy/pypy3.10 && \ + curl -sqL https://downloads.python.org/pypy/pypy3.10-v7.3.17-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.10 --strip-components=1 && \ + ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3.10 && \ + ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml +RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 twine==3.4.1" +ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 twine==3.4.1" # Python deps for Spark Connect -ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4" +ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.28.3 googleapis-common-protos==1.65.0" # Install Python 3.10 packages RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 @@ -126,12 +131,13 @@ RUN python3.9 -m pip install --force $BASIC_PIP_PKGS unittest-xml-reporting $CON RUN python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ -'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ +'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' RUN python3.9 -m pip list RUN gem install --no-document "bundler:2.4.22" RUN ln -s "$(which python3.9)" "/usr/local/bin/python" +RUN ln -s "$(which python3.9)" "/usr/local/bin/python3" WORKDIR /opt/spark-rm/output diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 9d00f178ef9b5..9f51a70745f27 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -1,10 +1,10 @@ HikariCP/2.5.1//HikariCP-2.5.1.jar JLargeArrays/1.5//JLargeArrays-1.5.jar JTransforms/3.1//JTransforms-3.1.jar 
-RoaringBitmap/1.1.0//RoaringBitmap-1.1.0.jar +RoaringBitmap/1.3.0//RoaringBitmap-1.3.0.jar ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar -aircompressor/0.27//aircompressor-0.27.jar +aircompressor/2.0.2//aircompressor-2.0.2.jar algebra_2.13/2.8.0//algebra_2.13-2.8.0.jar aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar @@ -13,18 +13,18 @@ aliyun-sdk-oss/3.13.2//aliyun-sdk-oss-3.13.2.jar annotations/17.0.0//annotations-17.0.0.jar antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar antlr4-runtime/4.13.1//antlr4-runtime-4.13.1.jar -aopalliance-repackaged/3.0.3//aopalliance-repackaged-3.0.3.jar +aopalliance-repackaged/3.0.6//aopalliance-repackaged-3.0.6.jar arpack/3.0.3//arpack-3.0.3.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar -arrow-format/17.0.0//arrow-format-17.0.0.jar -arrow-memory-core/17.0.0//arrow-memory-core-17.0.0.jar -arrow-memory-netty-buffer-patch/17.0.0//arrow-memory-netty-buffer-patch-17.0.0.jar -arrow-memory-netty/17.0.0//arrow-memory-netty-17.0.0.jar -arrow-vector/17.0.0//arrow-vector-17.0.0.jar +arrow-format/18.1.0//arrow-format-18.1.0.jar +arrow-memory-core/18.1.0//arrow-memory-core-18.1.0.jar +arrow-memory-netty-buffer-patch/18.1.0//arrow-memory-netty-buffer-patch-18.1.0.jar +arrow-memory-netty/18.1.0//arrow-memory-netty-18.1.0.jar +arrow-vector/18.1.0//arrow-vector-18.1.0.jar audience-annotations/0.12.0//audience-annotations-0.12.0.jar -avro-ipc/1.11.3//avro-ipc-1.11.3.jar -avro-mapred/1.11.3//avro-mapred-1.11.3.jar -avro/1.11.3//avro-1.11.3.jar +avro-ipc/1.12.0//avro-ipc-1.12.0.jar +avro-mapred/1.12.0//avro-mapred-1.12.0.jar +avro/1.12.0//avro-1.12.0.jar azure-data-lake-store-sdk/2.3.9//azure-data-lake-store-sdk-2.3.9.jar azure-keyvault-core/1.0.0//azure-keyvault-core-1.0.0.jar azure-storage/7.0.1//azure-storage-7.0.1.jar @@ -33,51 +33,54 @@ breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar bundle/2.24.6//bundle-2.24.6.jar cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar +checker-qual/3.42.0//checker-qual-3.42.0.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.13/0.10.0//chill_2.13-0.10.0.jar -commons-cli/1.8.0//commons-cli-1.8.0.jar +commons-cli/1.9.0//commons-cli-1.9.0.jar commons-codec/1.17.1//commons-codec-1.17.1.jar commons-collections/3.2.2//commons-collections-3.2.2.jar commons-collections4/4.4//commons-collections4-4.4.jar commons-compiler/3.1.9//commons-compiler-3.1.9.jar -commons-compress/1.26.2//commons-compress-1.26.2.jar +commons-compress/1.27.1//commons-compress-1.27.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.16.1//commons-io-2.16.1.jar +commons-io/2.18.0//commons-io-2.18.0.jar commons-lang/2.6//commons-lang-2.6.jar -commons-lang3/3.15.0//commons-lang3-3.15.0.jar +commons-lang3/3.17.0//commons-lang3-3.17.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar commons-text/1.12.0//commons-text-1.12.0.jar compress-lzf/1.1.2//compress-lzf-1.1.2.jar -curator-client/5.7.0//curator-client-5.7.0.jar -curator-framework/5.7.0//curator-framework-5.7.0.jar -curator-recipes/5.7.0//curator-recipes-5.7.0.jar +curator-client/5.7.1//curator-client-5.7.1.jar +curator-framework/5.7.1//curator-framework-5.7.1.jar +curator-recipes/5.7.1//curator-recipes-5.7.1.jar datanucleus-api-jdo/4.2.4//datanucleus-api-jdo-4.2.4.jar datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar 
-datasketches-java/6.0.0//datasketches-java-6.0.0.jar -datasketches-memory/2.2.0//datasketches-memory-2.2.0.jar +datasketches-java/6.1.1//datasketches-java-6.1.1.jar +datasketches-memory/3.0.2//datasketches-memory-3.0.2.jar derby/10.16.1.1//derby-10.16.1.1.jar derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar derbytools/10.16.1.1//derbytools-10.16.1.1.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar +error_prone_annotations/2.26.1//error_prone_annotations-2.26.1.jar esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar +failureaccess/1.0.2//failureaccess-1.0.2.jar flatbuffers-java/24.3.25//flatbuffers-java-24.3.25.jar -gcs-connector/hadoop3-2.2.21/shaded/gcs-connector-hadoop3-2.2.21-shaded.jar +gcs-connector/hadoop3-2.2.25/shaded/gcs-connector-hadoop3-2.2.25-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar -gson/2.2.4//gson-2.2.4.jar -guava/14.0.1//guava-14.0.1.jar -hadoop-aliyun/3.4.0//hadoop-aliyun-3.4.0.jar -hadoop-annotations/3.4.0//hadoop-annotations-3.4.0.jar -hadoop-aws/3.4.0//hadoop-aws-3.4.0.jar -hadoop-azure-datalake/3.4.0//hadoop-azure-datalake-3.4.0.jar -hadoop-azure/3.4.0//hadoop-azure-3.4.0.jar -hadoop-client-api/3.4.0//hadoop-client-api-3.4.0.jar -hadoop-client-runtime/3.4.0//hadoop-client-runtime-3.4.0.jar -hadoop-cloud-storage/3.4.0//hadoop-cloud-storage-3.4.0.jar -hadoop-huaweicloud/3.4.0//hadoop-huaweicloud-3.4.0.jar -hadoop-shaded-guava/1.2.0//hadoop-shaded-guava-1.2.0.jar +gson/2.11.0//gson-2.11.0.jar +guava/33.2.1-jre//guava-33.2.1-jre.jar +hadoop-aliyun/3.4.1//hadoop-aliyun-3.4.1.jar +hadoop-annotations/3.4.1//hadoop-annotations-3.4.1.jar +hadoop-aws/3.4.1//hadoop-aws-3.4.1.jar +hadoop-azure-datalake/3.4.1//hadoop-azure-datalake-3.4.1.jar +hadoop-azure/3.4.1//hadoop-azure-3.4.1.jar +hadoop-client-api/3.4.1//hadoop-client-api-3.4.1.jar +hadoop-client-runtime/3.4.1//hadoop-client-runtime-3.4.1.jar +hadoop-cloud-storage/3.4.1//hadoop-cloud-storage-3.4.1.jar +hadoop-huaweicloud/3.4.1//hadoop-huaweicloud-3.4.1.jar +hadoop-shaded-guava/1.3.0//hadoop-shaded-guava-1.3.0.jar hive-beeline/2.3.10//hive-beeline-2.3.10.jar hive-cli/2.3.10//hive-cli-2.3.10.jar hive-common/2.3.10//hive-common-2.3.10.jar @@ -92,26 +95,27 @@ hive-shims-common/2.3.10//hive-shims-common-2.3.10.jar hive-shims-scheduler/2.3.10//hive-shims-scheduler-2.3.10.jar hive-shims/2.3.10//hive-shims-2.3.10.jar hive-storage-api/2.8.1//hive-storage-api-2.8.1.jar -hk2-api/3.0.3//hk2-api-3.0.3.jar -hk2-locator/3.0.3//hk2-locator-3.0.3.jar -hk2-utils/3.0.3//hk2-utils-3.0.3.jar +hk2-api/3.0.6//hk2-api-3.0.6.jar +hk2-locator/3.0.6//hk2-locator-3.0.6.jar +hk2-utils/3.0.6//hk2-utils-3.0.6.jar httpclient/4.5.14//httpclient-4.5.14.jar httpcore/4.4.16//httpcore-4.4.16.jar -icu4j/75.1//icu4j-75.1.jar +icu4j/76.1//icu4j-76.1.jar ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.2//ivy-2.5.2.jar -jackson-annotations/2.17.2//jackson-annotations-2.17.2.jar +j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar +jackson-annotations/2.18.1//jackson-annotations-2.18.1.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar -jackson-core/2.17.2//jackson-core-2.17.2.jar -jackson-databind/2.17.2//jackson-databind-2.17.2.jar -jackson-dataformat-cbor/2.17.2//jackson-dataformat-cbor-2.17.2.jar -jackson-dataformat-yaml/2.17.2//jackson-dataformat-yaml-2.17.2.jar +jackson-core/2.18.1//jackson-core-2.18.1.jar +jackson-databind/2.18.1//jackson-databind-2.18.1.jar +jackson-dataformat-cbor/2.18.1//jackson-dataformat-cbor-2.18.1.jar 
+jackson-dataformat-yaml/2.18.1//jackson-dataformat-yaml-2.18.1.jar jackson-datatype-jdk8/2.17.0//jackson-datatype-jdk8-2.17.0.jar -jackson-datatype-jsr310/2.17.2//jackson-datatype-jsr310-2.17.2.jar +jackson-datatype-jsr310/2.18.1//jackson-datatype-jsr310-2.18.1.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.13/2.17.2//jackson-module-scala_2.13-2.17.2.jar -jakarta.annotation-api/2.0.0//jakarta.annotation-api-2.0.0.jar +jackson-module-scala_2.13/2.18.1//jackson-module-scala_2.13-2.18.1.jar +jakarta.annotation-api/2.1.1//jakarta.annotation-api-2.1.1.jar jakarta.inject-api/2.0.1//jakarta.inject-api-2.0.1.jar jakarta.servlet-api/5.0.0//jakarta.servlet-api-5.0.0.jar jakarta.validation-api/3.0.2//jakarta.validation-api-3.0.2.jar @@ -120,28 +124,29 @@ jakarta.xml.bind-api/2.3.2//jakarta.xml.bind-api-2.3.2.jar janino/3.1.9//janino-3.1.9.jar java-diff-utils/4.12//java-diff-utils-4.12.jar java-xmlbuilder/1.2//java-xmlbuilder-1.2.jar -javassist/3.29.2-GA//javassist-3.29.2-GA.jar +javassist/3.30.2-GA//javassist-3.30.2-GA.jar javax.jdo/3.2.0-m3//javax.jdo-3.2.0-m3.jar javax.servlet-api/4.0.1//javax.servlet-api-4.0.1.jar javolution/5.5.1//javolution-5.5.1.jar jaxb-api/2.2.11//jaxb-api-2.2.11.jar jaxb-runtime/2.3.2//jaxb-runtime-2.3.2.jar -jcl-over-slf4j/2.0.13//jcl-over-slf4j-2.0.13.jar +jcl-over-slf4j/2.0.16//jcl-over-slf4j-2.0.16.jar jdo-api/3.0.1//jdo-api-3.0.1.jar jdom2/2.0.6//jdom2-2.0.6.jar -jersey-client/3.0.12//jersey-client-3.0.12.jar -jersey-common/3.0.12//jersey-common-3.0.12.jar -jersey-container-servlet-core/3.0.12//jersey-container-servlet-core-3.0.12.jar -jersey-container-servlet/3.0.12//jersey-container-servlet-3.0.12.jar -jersey-hk2/3.0.12//jersey-hk2-3.0.12.jar -jersey-server/3.0.12//jersey-server-3.0.12.jar +jersey-client/3.0.16//jersey-client-3.0.16.jar +jersey-common/3.0.16//jersey-common-3.0.16.jar +jersey-container-servlet-core/3.0.16//jersey-container-servlet-core-3.0.16.jar +jersey-container-servlet/3.0.16//jersey-container-servlet-3.0.16.jar +jersey-hk2/3.0.16//jersey-hk2-3.0.16.jar +jersey-server/3.0.16//jersey-server-3.0.16.jar jettison/1.5.4//jettison-1.5.4.jar -jetty-util-ajax/11.0.21//jetty-util-ajax-11.0.21.jar -jetty-util/11.0.21//jetty-util-11.0.21.jar +jetty-util-ajax/11.0.24//jetty-util-ajax-11.0.24.jar +jetty-util/11.0.24//jetty-util-11.0.24.jar +jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar -jline/3.25.1//jline-3.25.1.jar +jline/3.26.3//jline-3.26.3.jar jna/5.14.0//jna-5.14.0.jar -joda-time/2.12.7//joda-time-2.12.7.jar +joda-time/2.13.0//joda-time-2.13.0.jar jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar @@ -152,110 +157,111 @@ json4s-jackson_2.13/4.0.7//json4s-jackson_2.13-4.0.7.jar json4s-scalap_2.13/4.0.7//json4s-scalap_2.13-4.0.7.jar jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar -jul-to-slf4j/2.0.13//jul-to-slf4j-2.0.13.jar +jul-to-slf4j/2.0.16//jul-to-slf4j-2.0.16.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.13.1//kubernetes-client-api-6.13.1.jar -kubernetes-client/6.13.1//kubernetes-client-6.13.1.jar -kubernetes-httpclient-okhttp/6.13.1//kubernetes-httpclient-okhttp-6.13.1.jar -kubernetes-model-admissionregistration/6.13.1//kubernetes-model-admissionregistration-6.13.1.jar -kubernetes-model-apiextensions/6.13.1//kubernetes-model-apiextensions-6.13.1.jar -kubernetes-model-apps/6.13.1//kubernetes-model-apps-6.13.1.jar -kubernetes-model-autoscaling/6.13.1//kubernetes-model-autoscaling-6.13.1.jar 
-kubernetes-model-batch/6.13.1//kubernetes-model-batch-6.13.1.jar -kubernetes-model-certificates/6.13.1//kubernetes-model-certificates-6.13.1.jar -kubernetes-model-common/6.13.1//kubernetes-model-common-6.13.1.jar -kubernetes-model-coordination/6.13.1//kubernetes-model-coordination-6.13.1.jar -kubernetes-model-core/6.13.1//kubernetes-model-core-6.13.1.jar -kubernetes-model-discovery/6.13.1//kubernetes-model-discovery-6.13.1.jar -kubernetes-model-events/6.13.1//kubernetes-model-events-6.13.1.jar -kubernetes-model-extensions/6.13.1//kubernetes-model-extensions-6.13.1.jar -kubernetes-model-flowcontrol/6.13.1//kubernetes-model-flowcontrol-6.13.1.jar -kubernetes-model-gatewayapi/6.13.1//kubernetes-model-gatewayapi-6.13.1.jar -kubernetes-model-metrics/6.13.1//kubernetes-model-metrics-6.13.1.jar -kubernetes-model-networking/6.13.1//kubernetes-model-networking-6.13.1.jar -kubernetes-model-node/6.13.1//kubernetes-model-node-6.13.1.jar -kubernetes-model-policy/6.13.1//kubernetes-model-policy-6.13.1.jar -kubernetes-model-rbac/6.13.1//kubernetes-model-rbac-6.13.1.jar -kubernetes-model-resource/6.13.1//kubernetes-model-resource-6.13.1.jar -kubernetes-model-scheduling/6.13.1//kubernetes-model-scheduling-6.13.1.jar -kubernetes-model-storageclass/6.13.1//kubernetes-model-storageclass-6.13.1.jar +kubernetes-client-api/6.13.4//kubernetes-client-api-6.13.4.jar +kubernetes-client/6.13.4//kubernetes-client-6.13.4.jar +kubernetes-httpclient-okhttp/6.13.4//kubernetes-httpclient-okhttp-6.13.4.jar +kubernetes-model-admissionregistration/6.13.4//kubernetes-model-admissionregistration-6.13.4.jar +kubernetes-model-apiextensions/6.13.4//kubernetes-model-apiextensions-6.13.4.jar +kubernetes-model-apps/6.13.4//kubernetes-model-apps-6.13.4.jar +kubernetes-model-autoscaling/6.13.4//kubernetes-model-autoscaling-6.13.4.jar +kubernetes-model-batch/6.13.4//kubernetes-model-batch-6.13.4.jar +kubernetes-model-certificates/6.13.4//kubernetes-model-certificates-6.13.4.jar +kubernetes-model-common/6.13.4//kubernetes-model-common-6.13.4.jar +kubernetes-model-coordination/6.13.4//kubernetes-model-coordination-6.13.4.jar +kubernetes-model-core/6.13.4//kubernetes-model-core-6.13.4.jar +kubernetes-model-discovery/6.13.4//kubernetes-model-discovery-6.13.4.jar +kubernetes-model-events/6.13.4//kubernetes-model-events-6.13.4.jar +kubernetes-model-extensions/6.13.4//kubernetes-model-extensions-6.13.4.jar +kubernetes-model-flowcontrol/6.13.4//kubernetes-model-flowcontrol-6.13.4.jar +kubernetes-model-gatewayapi/6.13.4//kubernetes-model-gatewayapi-6.13.4.jar +kubernetes-model-metrics/6.13.4//kubernetes-model-metrics-6.13.4.jar +kubernetes-model-networking/6.13.4//kubernetes-model-networking-6.13.4.jar +kubernetes-model-node/6.13.4//kubernetes-model-node-6.13.4.jar +kubernetes-model-policy/6.13.4//kubernetes-model-policy-6.13.4.jar +kubernetes-model-rbac/6.13.4//kubernetes-model-rbac-6.13.4.jar +kubernetes-model-resource/6.13.4//kubernetes-model-resource-6.13.4.jar +kubernetes-model-scheduling/6.13.4//kubernetes-model-scheduling-6.13.4.jar +kubernetes-model-storageclass/6.13.4//kubernetes-model-storageclass-6.13.4.jar lapack/3.0.3//lapack-3.0.3.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.16.0//libthrift-0.16.0.jar -log4j-1.2-api/2.22.1//log4j-1.2-api-2.22.1.jar -log4j-api/2.22.1//log4j-api-2.22.1.jar -log4j-core/2.22.1//log4j-core-2.22.1.jar -log4j-layout-template-json/2.22.1//log4j-layout-template-json-2.22.1.jar -log4j-slf4j2-impl/2.22.1//log4j-slf4j2-impl-2.22.1.jar 
+listenablefuture/9999.0-empty-to-avoid-conflict-with-guava//listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar +log4j-1.2-api/2.24.2//log4j-1.2-api-2.24.2.jar +log4j-api/2.24.2//log4j-api-2.24.2.jar +log4j-core/2.24.2//log4j-core-2.24.2.jar +log4j-layout-template-json/2.24.2//log4j-layout-template-json-2.24.2.jar +log4j-slf4j2-impl/2.24.2//log4j-slf4j2-impl-2.24.2.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -metrics-core/4.2.26//metrics-core-4.2.26.jar -metrics-graphite/4.2.26//metrics-graphite-4.2.26.jar -metrics-jmx/4.2.26//metrics-jmx-4.2.26.jar -metrics-json/4.2.26//metrics-json-4.2.26.jar -metrics-jvm/4.2.26//metrics-jvm-4.2.26.jar +metrics-core/4.2.29//metrics-core-4.2.29.jar +metrics-graphite/4.2.29//metrics-graphite-4.2.29.jar +metrics-jmx/4.2.29//metrics-jmx-4.2.29.jar +metrics-json/4.2.29//metrics-json-4.2.29.jar +metrics-jvm/4.2.29//metrics-jvm-4.2.29.jar minlog/1.3.0//minlog-1.3.0.jar -netty-all/4.1.110.Final//netty-all-4.1.110.Final.jar -netty-buffer/4.1.110.Final//netty-buffer-4.1.110.Final.jar -netty-codec-http/4.1.110.Final//netty-codec-http-4.1.110.Final.jar -netty-codec-http2/4.1.110.Final//netty-codec-http2-4.1.110.Final.jar -netty-codec-socks/4.1.110.Final//netty-codec-socks-4.1.110.Final.jar -netty-codec/4.1.110.Final//netty-codec-4.1.110.Final.jar -netty-common/4.1.110.Final//netty-common-4.1.110.Final.jar -netty-handler-proxy/4.1.110.Final//netty-handler-proxy-4.1.110.Final.jar -netty-handler/4.1.110.Final//netty-handler-4.1.110.Final.jar -netty-resolver/4.1.110.Final//netty-resolver-4.1.110.Final.jar -netty-tcnative-boringssl-static/2.0.65.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-aarch_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-x86_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-aarch_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-x86_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-windows-x86_64.jar -netty-tcnative-classes/2.0.65.Final//netty-tcnative-classes-2.0.65.Final.jar -netty-transport-classes-epoll/4.1.110.Final//netty-transport-classes-epoll-4.1.110.Final.jar -netty-transport-classes-kqueue/4.1.110.Final//netty-transport-classes-kqueue-4.1.110.Final.jar -netty-transport-native-epoll/4.1.110.Final/linux-aarch_64/netty-transport-native-epoll-4.1.110.Final-linux-aarch_64.jar -netty-transport-native-epoll/4.1.110.Final/linux-riscv64/netty-transport-native-epoll-4.1.110.Final-linux-riscv64.jar -netty-transport-native-epoll/4.1.110.Final/linux-x86_64/netty-transport-native-epoll-4.1.110.Final-linux-x86_64.jar -netty-transport-native-kqueue/4.1.110.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.110.Final-osx-aarch_64.jar -netty-transport-native-kqueue/4.1.110.Final/osx-x86_64/netty-transport-native-kqueue-4.1.110.Final-osx-x86_64.jar -netty-transport-native-unix-common/4.1.110.Final//netty-transport-native-unix-common-4.1.110.Final.jar -netty-transport/4.1.110.Final//netty-transport-4.1.110.Final.jar +netty-all/4.1.114.Final//netty-all-4.1.114.Final.jar +netty-buffer/4.1.114.Final//netty-buffer-4.1.114.Final.jar +netty-codec-http/4.1.114.Final//netty-codec-http-4.1.114.Final.jar +netty-codec-http2/4.1.114.Final//netty-codec-http2-4.1.114.Final.jar 
+netty-codec-socks/4.1.114.Final//netty-codec-socks-4.1.114.Final.jar +netty-codec/4.1.114.Final//netty-codec-4.1.114.Final.jar +netty-common/4.1.114.Final//netty-common-4.1.114.Final.jar +netty-handler-proxy/4.1.114.Final//netty-handler-proxy-4.1.114.Final.jar +netty-handler/4.1.114.Final//netty-handler-4.1.114.Final.jar +netty-resolver/4.1.114.Final//netty-resolver-4.1.114.Final.jar +netty-tcnative-boringssl-static/2.0.69.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.69.Final-linux-aarch_64.jar +netty-tcnative-boringssl-static/2.0.69.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.69.Final-linux-x86_64.jar +netty-tcnative-boringssl-static/2.0.69.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.69.Final-osx-aarch_64.jar +netty-tcnative-boringssl-static/2.0.69.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.69.Final-osx-x86_64.jar +netty-tcnative-boringssl-static/2.0.69.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.69.Final-windows-x86_64.jar +netty-tcnative-classes/2.0.69.Final//netty-tcnative-classes-2.0.69.Final.jar +netty-transport-classes-epoll/4.1.114.Final//netty-transport-classes-epoll-4.1.114.Final.jar +netty-transport-classes-kqueue/4.1.114.Final//netty-transport-classes-kqueue-4.1.114.Final.jar +netty-transport-native-epoll/4.1.114.Final/linux-aarch_64/netty-transport-native-epoll-4.1.114.Final-linux-aarch_64.jar +netty-transport-native-epoll/4.1.114.Final/linux-riscv64/netty-transport-native-epoll-4.1.114.Final-linux-riscv64.jar +netty-transport-native-epoll/4.1.114.Final/linux-x86_64/netty-transport-native-epoll-4.1.114.Final-linux-x86_64.jar +netty-transport-native-kqueue/4.1.114.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.114.Final-osx-aarch_64.jar +netty-transport-native-kqueue/4.1.114.Final/osx-x86_64/netty-transport-native-kqueue-4.1.114.Final-osx-x86_64.jar +netty-transport-native-unix-common/4.1.114.Final//netty-transport-native-unix-common-4.1.114.Final.jar +netty-transport/4.1.114.Final//netty-transport-4.1.114.Final.jar objenesis/3.3//objenesis-3.3.jar okhttp/3.12.12//okhttp-3.12.12.jar -okio/1.15.0//okio-1.15.0.jar +okio/1.17.6//okio-1.17.6.jar opencsv/2.3//opencsv-2.3.jar opentracing-api/0.33.0//opentracing-api-0.33.0.jar opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar opentracing-util/0.33.0//opentracing-util-0.33.0.jar -orc-core/2.0.1/shaded-protobuf/orc-core-2.0.1-shaded-protobuf.jar +orc-core/2.0.3/shaded-protobuf/orc-core-2.0.3-shaded-protobuf.jar orc-format/1.0.0/shaded-protobuf/orc-format-1.0.0-shaded-protobuf.jar -orc-mapreduce/2.0.1/shaded-protobuf/orc-mapreduce-2.0.1-shaded-protobuf.jar -orc-shims/2.0.1//orc-shims-2.0.1.jar +orc-mapreduce/2.0.3/shaded-protobuf/orc-mapreduce-2.0.3-shaded-protobuf.jar +orc-shims/2.0.3//orc-shims-2.0.3.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar -parquet-column/1.14.1//parquet-column-1.14.1.jar -parquet-common/1.14.1//parquet-common-1.14.1.jar -parquet-encoding/1.14.1//parquet-encoding-1.14.1.jar -parquet-format-structures/1.14.1//parquet-format-structures-1.14.1.jar -parquet-hadoop/1.14.1//parquet-hadoop-1.14.1.jar -parquet-jackson/1.14.1//parquet-jackson-1.14.1.jar +parquet-column/1.14.4//parquet-column-1.14.4.jar +parquet-common/1.14.4//parquet-common-1.14.4.jar +parquet-encoding/1.14.4//parquet-encoding-1.14.4.jar +parquet-format-structures/1.14.4//parquet-format-structures-1.14.4.jar +parquet-hadoop/1.14.4//parquet-hadoop-1.14.4.jar +parquet-jackson/1.14.4//parquet-jackson-1.14.4.jar 
pickle/1.5//pickle-1.5.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar -rocksdbjni/9.2.1//rocksdbjni-9.2.1.jar +rocksdbjni/9.7.3//rocksdbjni-9.7.3.jar scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar -scala-compiler/2.13.14//scala-compiler-2.13.14.jar -scala-library/2.13.14//scala-library-2.13.14.jar +scala-compiler/2.13.15//scala-compiler-2.13.15.jar +scala-library/2.13.15//scala-library-2.13.15.jar scala-parallel-collections_2.13/1.0.4//scala-parallel-collections_2.13-1.0.4.jar scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar -scala-reflect/2.13.14//scala-reflect-2.13.14.jar +scala-reflect/2.13.15//scala-reflect-2.13.15.jar scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar -slf4j-api/2.0.13//slf4j-api-2.0.13.jar +slf4j-api/2.0.16//slf4j-api-2.0.16.jar snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar -snakeyaml/2.2//snakeyaml-2.2.jar -snappy-java/1.1.10.5//snappy-java-1.1.10.5.jar +snakeyaml/2.3//snakeyaml-2.3.jar +snappy-java/1.1.10.7//snappy-java-1.1.10.7.jar spire-macros_2.13/0.18.0//spire-macros_2.13-0.18.0.jar spire-platform_2.13/0.18.0//spire-platform_2.13-0.18.0.jar spire-util_2.13/0.18.0//spire-util_2.13-0.18.0.jar @@ -264,15 +270,14 @@ stax-api/1.0.1//stax-api-1.0.1.jar stream/2.9.8//stream-2.9.8.jar super-csv/2.2.0//super-csv-2.2.0.jar threeten-extra/1.7.1//threeten-extra-1.7.1.jar -tink/1.13.0//tink-1.13.0.jar +tink/1.15.0//tink-1.15.0.jar transaction-api/1.1//transaction-api-1.1.jar -txw2/3.0.2//txw2-3.0.2.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar wildfly-openssl/1.1.3.Final//wildfly-openssl-1.1.3.Final.jar -xbean-asm9-shaded/4.25//xbean-asm9-shaded-4.25.jar +xbean-asm9-shaded/4.26//xbean-asm9-shaded-4.26.jar xmlschema-core/2.3.1//xmlschema-core-2.3.1.jar -xz/1.9//xz-1.9.jar +xz/1.10//xz-1.10.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar -zookeeper-jute/3.9.2//zookeeper-jute-3.9.2.jar -zookeeper/3.9.2//zookeeper-3.9.2.jar -zstd-jni/1.5.6-4//zstd-jni-1.5.6-4.jar +zookeeper-jute/3.9.3//zookeeper-jute-3.9.3.jar +zookeeper/3.9.3//zookeeper-3.9.3.jar +zstd-jni/1.5.6-8//zstd-jni-1.5.6-8.jar diff --git a/dev/gen-protos.sh b/dev/gen-protos.sh new file mode 100755 index 0000000000000..d169964feb853 --- /dev/null +++ b/dev/gen-protos.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +OUTPUT_PATH="" +MODULE="" +SOURCE_MODULE="" +TARGET_MODULE="" + +function usage() { + echo "Illegal number of parameters." 
+ echo "Usage:./dev/gen-protos.sh [connect|streaming] [output_path]" + exit -1 +} + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage +fi + +if [[ $1 == "connect" ]]; then + MODULE="connect" + OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/connect/proto/ + SOURCE_MODULE="spark.connect" + TARGET_MODULE="pyspark.sql.connect.proto" +elif [[ $1 == "streaming" ]]; then + MODULE="streaming" + OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/streaming/proto/ + SOURCE_MODULE="org.apache.spark.sql.execution.streaming" + TARGET_MODULE="pyspark.sql.streaming.proto" +else + usage +fi + +if [[ $# -eq 2 ]]; then + rm -Rf $2 + mkdir -p $2 + OUTPUT_PATH=$2 +fi + +if [[ $MODULE == "connect" ]]; then + pushd sql/connect/common/src/main +elif [[ $MODULE == "streaming" ]]; then + pushd sql/core/src/main +fi + +LICENSE=$(cat <<'EOF' +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +EOF) +echo "$LICENSE" > /tmp/tmp_licence + +# Delete the old generated protobuf files. +rm -Rf gen + +# Now, regenerate the new files +buf generate --debug -vvv + +# We need to edit the generate python files to account for the actual package location and not +# the one generated by proto. +for f in `find gen/proto/python -name "*.py*"`; do + # First fix the imports. + if [[ $f == *_pb2.py || $f == *_pb2_grpc.py ]]; then + sed -e "s/from ${SOURCE_MODULE} import/from ${TARGET_MODULE} import/g" $f > $f.tmp + mv $f.tmp $f + # Now fix the module name in the serialized descriptor. + sed -e "s/DESCRIPTOR, '${SOURCE_MODULE}/DESCRIPTOR, '${TARGET_MODULE}/g" $f > $f.tmp + mv $f.tmp $f + elif [[ $f == *.pyi ]]; then + sed -e "s/import ${SOURCE_MODULE}./import ${TARGET_MODULE}./g" -e "s/${SOURCE_MODULE}./${TARGET_MODULE}./g" -e '/ *@typing_extensions\.final/d' $f > $f.tmp + mv $f.tmp $f + fi + + # Prepend the Apache licence header to the files. + cp $f $f.bak + cat /tmp/tmp_licence $f.bak > $f + + LC=$(wc -l < $f) + echo $LC + if [[ $f == *_grpc.py && $LC -eq 20 ]]; then + rm $f + fi + rm $f.bak +done + +black --config $SPARK_HOME/dev/pyproject.toml gen/proto/python + +# Last step copy the result files to the destination module. +for f in `find gen/proto/python -name "*.py*"`; do + cp $f $OUTPUT_PATH +done + +# Clean up everything. +rm -Rf gen diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 6ba9be87552ab..2817818cbc4e3 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -17,9 +17,14 @@ # Image for building and testing Spark branches. Based on Ubuntu 22.04. 
# See also in https://hub.docker.com/_/ubuntu -FROM ubuntu:jammy-20240227 +FROM ubuntu:jammy-20240911.1 +LABEL org.opencontainers.image.authors="Apache Spark project " +LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" +# Overwrite this label to avoid exposing the underlying Ubuntu OS version label +LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20240318 +ENV FULL_REFRESH_DATE 20241119 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -81,17 +86,17 @@ ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library RUN add-apt-repository ppa:pypy/ppa -RUN mkdir -p /usr/local/pypy/pypy3.9 && \ - curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 +RUN mkdir -p /usr/local/pypy/pypy3.10 && \ + curl -sqL https://downloads.python.org/pypy/pypy3.10-v7.3.17-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.10 --strip-components=1 && \ + ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3.10 && \ + ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml +RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect -ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" +ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.28.3 googleapis-common-protos==1.65.0 graphviz==0.20.3" # Install Python 3.10 packages RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 @@ -130,14 +135,23 @@ RUN apt-get update && apt-get install -y \ python3.12 \ && rm -rf /var/lib/apt/lists/* RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 -# TODO(SPARK-46647) Add unittest-xml-reporting into Python 3.12 image when it supports Python 3.12 RUN python3.12 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this -RUN python3.12 -m pip install $BASIC_PIP_PKGS $CONNECT_PIP_PKGS lxml && \ +RUN python3.12 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \ python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \ python3.12 -m pip install torcheval && \ python3.12 -m pip cache purge +# Install Python 3.13 at the last stage to avoid breaking the existing Python installations +RUN apt-get update && apt-get install -y \ + python3.13 \ + && rm -rf /var/lib/apt/lists/* +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13 +# TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13 +RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this +RUN python3.13 -m pip install numpy>=2.1 pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy coverage 
matplotlib openpyxl grpcio==1.67.0 grpcio-status==1.67.0 lxml jinja2 && \ + python3.13 -m pip cache purge + # Remove unused installation packages to free up disk space -RUN apt-get remove --purge -y 'gfortran-11' 'humanity-icon-theme' 'nodejs-doc' || true +RUN apt-get remove --purge -y 'humanity-icon-theme' 'nodejs-doc' RUN apt-get autoremove --purge -y RUN apt-get clean diff --git a/dev/lint-scala b/dev/lint-scala index d44eb02ef9094..23df146a8d1b4 100755 --- a/dev/lint-scala +++ b/dev/lint-scala @@ -29,16 +29,17 @@ ERRORS=$(./build/mvn \ -Dscalafmt.skip=false \ -Dscalafmt.validateOnly=true \ -Dscalafmt.changedOnly=false \ - -pl connect/common \ - -pl connect/server \ + -pl sql/api \ + -pl sql/connect/common \ + -pl sql/connect/server \ -pl connector/connect/client/jvm \ 2>&1 | grep -e "Unformatted files found" \ ) if test ! -z "$ERRORS"; then - echo -e "The scalafmt check failed on connect or connector/connect at following occurrences:\n\n$ERRORS\n" + echo -e "The scalafmt check failed on sql/connect or connector/connect at following occurrences:\n\n$ERRORS\n" echo "Before submitting your change, please make sure to format your code using the following command:" - echo "./build/mvn scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connect/common -pl connect/server -pl connector/connect/client/jvm" + echo "./build/mvn scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl sql/api -pl sql/connect/common -pl sql/connect/server -pl connector/connect/client/jvm" exit 1 else echo -e "Scalafmt checks passed." diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index dd60231fa83eb..53c317a39ea1e 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -187,7 +187,7 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI echo "Build flags: $@" >> "$DISTDIR/RELEASE" # Copy jars -cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" +cp -r "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" # Only create the hive-jackson directory if they exist. if [ -f "$DISTDIR"/jars/jackson-core-asl-1.9.13.jar ]; then diff --git a/dev/protobuf-breaking-changes-check.sh b/dev/protobuf-breaking-changes-check.sh index cb0a508f62ec8..327e54be63e62 100755 --- a/dev/protobuf-breaking-changes-check.sh +++ b/dev/protobuf-breaking-changes-check.sh @@ -21,7 +21,7 @@ set -ex if [[ $# -gt 1 ]]; then echo "Illegal number of parameters." echo "Usage: ./dev/protobuf-breaking-changes-check.sh [branch]" - echo "the default branch is 'master', available options are 'master', 'branch-3.4', etc" + echo "the default branch is 'master'" exit -1 fi @@ -33,7 +33,7 @@ if [[ $# -eq 1 ]]; then BRANCH=$1 fi -pushd connect/common/src/main && +pushd sql/connect/common/src/main && echo "Start protobuf breaking changes checking against $BRANCH" && buf breaking --against "https://github.com/apache/spark.git#branch=$BRANCH,subdir=connector/connect/common/src/main" && echo "Finsh protobuf breaking changes checking: SUCCESS" diff --git a/connector/docker/spark-test/worker/default_cmd b/dev/py-cleanup similarity index 75% rename from connector/docker/spark-test/worker/default_cmd rename to dev/py-cleanup index 1f2aac95ed699..6a2edd1040171 100755 --- a/connector/docker/spark-test/worker/default_cmd +++ b/dev/py-cleanup @@ -17,9 +17,15 @@ # limitations under the License. 
# -IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }') -echo "CONTAINER_IP=$IP" -export SPARK_LOCAL_IP=$IP -export SPARK_PUBLIC_DNS=$IP +# Utility for temporary files cleanup in 'python'. +# usage: ./dev/py-cleanup -/opt/spark/bin/spark-class org.apache.spark.deploy.worker.Worker $1 +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +rm -rf python/target +rm -rf python/lib/pyspark.zip +rm -rf python/docs/build +rm -rf python/docs/source/reference/*/api diff --git a/dev/requirements.txt b/dev/requirements.txt index e0216a63ba790..a9874f77113ab 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -3,11 +3,11 @@ py4j>=0.10.9.7 # PySpark dependencies (optional) numpy>=1.21 -pyarrow>=4.0.0 +pyarrow>=10.0.0 six==1.16.0 -pandas>=1.4.4 +pandas>=2.0.0 scipy -plotly +plotly>=4.8 mlflow>=2.3.1 scikit-learn matplotlib @@ -58,9 +58,10 @@ black==23.9.1 py # Spark Connect (required) -grpcio>=1.62.0 -grpcio-status>=1.62.0 -googleapis-common-protos>=1.56.4 +grpcio>=1.67.0 +grpcio-status>=1.67.0 +googleapis-common-protos>=1.65.0 +protobuf==5.28.3 # Spark Connect python proto generation plugin (optional) mypy-protobuf==3.3.0 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py deleted file mode 100755 index aa82b28e38217..0000000000000 --- a/dev/run-tests-jenkins.py +++ /dev/null @@ -1,236 +0,0 @@ -#!/usr/bin/env python3 - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py
deleted file mode 100755
index aa82b28e38217..0000000000000
--- a/dev/run-tests-jenkins.py
+++ /dev/null
@@ -1,236 +0,0 @@
-#!/usr/bin/env python3
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import os
-import sys
-import json
-import functools
-import subprocess
-from urllib.request import urlopen
-from urllib.request import Request
-from urllib.error import HTTPError, URLError
-
-from sparktestsupport import SPARK_HOME, ERROR_CODES
-from sparktestsupport.shellutils import run_cmd
-
-
-def print_err(msg):
-    """
-    Given a set of arguments, will print them to the STDERR stream
-    """
-    print(msg, file=sys.stderr)
-
-
-def post_message_to_github(msg, ghprb_pull_id):
-    print("Attempting to post to GitHub...")
-
-    api_url = os.getenv("GITHUB_API_BASE", "https://api.github.com/repos/apache/spark")
-    url = api_url + "/issues/" + ghprb_pull_id + "/comments"
-    github_oauth_key = os.environ["GITHUB_OAUTH_KEY"]
-
-    posted_message = json.dumps({"body": msg})
-    request = Request(
-        url,
-        headers={
-            "Authorization": "token %s" % github_oauth_key,
-            "Content-Type": "application/json",
-        },
-        data=posted_message.encode("utf-8"),
-    )
-    try:
-        response = urlopen(request)
-
-        if response.getcode() == 201:
-            print(" > Post successful.")
-    except HTTPError as http_e:
-        print_err("Failed to post message to GitHub.")
-        print_err(" > http_code: %s" % http_e.code)
-        print_err(" > api_response: %s" % http_e.read())
-        print_err(" > data: %s" % posted_message)
-    except URLError as url_e:
-        print_err("Failed to post message to GitHub.")
-        print_err(" > urllib_status: %s" % url_e.reason[1])
-        print_err(" > data: %s" % posted_message)
-
-
-def pr_message(
-    build_display_name, build_url, ghprb_pull_id, short_commit_hash, commit_url, msg, post_msg=""
-):
-    # align the arguments properly for string formatting
-    str_args = (
-        build_display_name,
-        msg,
-        build_url,
-        ghprb_pull_id,
-        short_commit_hash,
-        commit_url,
-        str(" " + post_msg + ".") if post_msg else ".",
-    )
-    return "**[Test build %s %s](%stestReport)** for PR %s at commit [`%s`](%s)%s" % str_args
-
-
-def run_pr_checks(pr_tests, ghprb_actual_commit, sha1):
-    """
-    Executes a set of pull request checks to ease development and report issues with various
-    components such as style, linting, dependencies, compatibilities, etc.
-    @return a list of messages to post back to GitHub
-    """
-    # Ensure we save off the current HEAD to revert to
-    current_pr_head = run_cmd(["git", "rev-parse", "HEAD"], return_output=True).strip()
-    pr_results = list()
-
-    for pr_test in pr_tests:
-        test_name = pr_test + ".sh"
-        pr_results.append(
-            run_cmd(
-                [
-                    "bash",
-                    os.path.join(SPARK_HOME, "dev", "tests", test_name),
-                    ghprb_actual_commit,
-                    sha1,
-                ],
-                return_output=True,
-            ).rstrip()
-        )
-        # Ensure, after each test, that we're back on the current PR
-        run_cmd(["git", "checkout", "-f", current_pr_head])
-    return pr_results
-
-
-def run_tests(tests_timeout):
-    """
-    Runs the `dev/run-tests` script and responds with the correct error message
-    under the various failure scenarios.
-    @return a tuple containing the test result code and the result note to post to GitHub
-    """
-
-    test_result_code = subprocess.Popen(
-        ["timeout", tests_timeout, os.path.join(SPARK_HOME, "dev", "run-tests")]
-    ).wait()
-
-    failure_note_by_errcode = {
-        # error to denote run-tests script failures:
-        1: "executing the `dev/run-tests` script",
-        ERROR_CODES["BLOCK_GENERAL"]: "some tests",
-        ERROR_CODES["BLOCK_RAT"]: "RAT tests",
-        ERROR_CODES["BLOCK_SCALA_STYLE"]: "Scala style tests",
-        ERROR_CODES["BLOCK_JAVA_STYLE"]: "Java style tests",
-        ERROR_CODES["BLOCK_PYTHON_STYLE"]: "Python style tests",
-        ERROR_CODES["BLOCK_R_STYLE"]: "R style tests",
-        ERROR_CODES["BLOCK_DOCUMENTATION"]: "to generate documentation",
-        ERROR_CODES["BLOCK_BUILD"]: "to build",
-        ERROR_CODES["BLOCK_BUILD_TESTS"]: "build dependency tests",
-        ERROR_CODES["BLOCK_MIMA"]: "MiMa tests",
-        ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: "Spark unit tests",
-        ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: "PySpark unit tests",
-        ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: "PySpark pip packaging tests",
-        ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: "SparkR unit tests",
-        ERROR_CODES["BLOCK_TIMEOUT"]: "from timeout after a configured wait of `%s`"
-        % (tests_timeout),
-    }
-
-    if test_result_code == 0:
-        test_result_note = " * This patch passes all tests."
-    else:
-        note = failure_note_by_errcode.get(
-            test_result_code, "due to an unknown error code, %s" % test_result_code
-        )
-        test_result_note = " * This patch **fails %s**." % note
-
-    return [test_result_code, test_result_note]
-
-
-def main():
-    # Important Environment Variables
-    # ---
-    # $ghprbActualCommit
-    #   This is the hash of the most recent commit in the PR.
-    #   The merge-base of this and master is the commit from which the PR was branched.
-    # $sha1
-    #   If the patch merges cleanly, this is a reference to the merge commit hash
-    #   (e.g. "origin/pr/2606/merge").
-    #   If the patch does not merge cleanly, it is equal to $ghprbActualCommit.
-    #   The merge-base of this and master in the case of a clean merge is the most recent commit
-    #   against master.
-    ghprb_pull_id = os.environ["ghprbPullId"]
-    ghprb_actual_commit = os.environ["ghprbActualCommit"]
-    ghprb_pull_title = os.environ["ghprbPullTitle"].lower()
-    sha1 = os.environ["sha1"]
-
-    # Marks this build as a pull request build.
-    os.environ["SPARK_JENKINS_PRB"] = "true"
-    # Switch to a Maven-based build if the PR title contains "test-maven":
-    if "test-maven" in ghprb_pull_title:
-        os.environ["SPARK_JENKINS_BUILD_TOOL"] = "maven"
-    if "test-hadoop3" in ghprb_pull_title:
-        os.environ["SPARK_JENKINS_BUILD_PROFILE"] = "hadoop3"
-    # Switch the Scala profile based on the PR title:
-    if "test-scala2.13" in ghprb_pull_title:
-        os.environ["SPARK_JENKINS_BUILD_SCALA_PROFILE"] = "scala2.13"
-
-    build_display_name = os.environ["BUILD_DISPLAY_NAME"]
-    build_url = os.environ["BUILD_URL"]
-
-    project_url = os.getenv("SPARK_PROJECT_URL", "https://github.com/apache/spark")
-    commit_url = project_url + "/commit/" + ghprb_actual_commit
-
-    # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
-    short_commit_hash = ghprb_actual_commit[0:7]
-
-    # format: http://linux.die.net/man/1/timeout
-    # must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher
-    # then this. Please consult with the build manager or a committer when it should be increased.
-    tests_timeout = "500m"
-
-    # Array to capture all test names to run on the pull request. These tests are represented
-    # by their file equivalents in the dev/tests/ directory.
-    #
-    # To write a PR test:
-    #   * the file must reside within the dev/tests directory
-    #   * be an executable bash script
-    #   * accept three arguments on the command line, the first being the GitHub PR long commit
-    #     hash, the second the GitHub SHA1 hash, and the final the current PR hash
-    #   * and, lastly, return string output to be included in the pr message output that will
-    #     be posted to GitHub
-    pr_tests = ["pr_merge_ability", "pr_public_classes"]
-
-    # `bind_message_base` returns a function to generate messages for GitHub posting
-    github_message = functools.partial(
-        pr_message, build_display_name, build_url, ghprb_pull_id, short_commit_hash, commit_url
-    )
-
-    # post start message
-    post_message_to_github(github_message("has started"), ghprb_pull_id)
-
-    pr_check_results = run_pr_checks(pr_tests, ghprb_actual_commit, sha1)
-
-    test_result_code, test_result_note = run_tests(tests_timeout)
-
-    # post end message
-    result_message = github_message("has finished")
-    result_message += "\n" + test_result_note + "\n"
-    result_message += "\n".join(pr_check_results)
-
-    post_message_to_github(result_message, ghprb_pull_id)
-
-    sys.exit(test_result_code)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/dev/spark-test-image-util/docs/build-docs b/dev/spark-test-image-util/docs/build-docs
new file mode 100644
index 0000000000000..6ff9c7cd9455c
--- /dev/null
+++ b/dev/spark-test-image-util/docs/build-docs
@@ -0,0 +1,71 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+if ! [ -x "$(command -v docker)" ]; then
+  echo "Error: Docker is not installed." >&2
+  exit 1
+fi
+
+DOCKER_CACHE_IMG="ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:master"
+REPO_OWNER="apache/spark"
+REPOSITORY="apache-spark-ci-image-docs"
+IMG_TAG=$(date +%s)
+IMG_NAME="${REPOSITORY}:${IMG_TAG}"
+IMG_URL="$REPO_OWNER/$IMG_NAME"
+DOCKER_MOUNT_SPARK_HOME="/__w/spark/spark"
+BUILD_DOCS_SCRIPT_PATH="${DOCKER_MOUNT_SPARK_HOME}/dev/spark-test-image-util/docs/run-in-container"
+
+FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)"
+SPARK_HOME="$(cd "`dirname "${BASH_SOURCE[0]}"`"/../../..; pwd)"
+
+# 1. Compile Spark outside the container so that the documentation can be generated inside it.
+build/sbt -Phive -Pkinesis-asl clean unidoc package
+
+# 2. Build the container image.
+docker buildx build \
+  --cache-from type=registry,ref="${DOCKER_CACHE_IMG}" \
+  --tag "${IMG_URL}" "${FWDIR}" \
+  --file "${SPARK_HOME}/dev/spark-test-image/docs/Dockerfile"
+
+# 3. Build docs inside the container: `error docs`, `scala doc`, `python doc`, `sql doc`.
+docker run \
+  --mount type=bind,source="${SPARK_HOME}",target="${DOCKER_MOUNT_SPARK_HOME}" \
+  --interactive --tty "${IMG_URL}" \
+  /bin/bash -c "sh ${BUILD_DOCS_SCRIPT_PATH}"
+
+# 4. Build docs on the host: `r doc`.
+#
+# Why do the R docs need to be built outside the container?
+# When they are built inside the container, the directory
+# `/__w/spark/spark/R/pkg/docs` that `Rscript` creates automatically gets the
+# permissions `dr-xr--r-x`, so subsequent writes into it fail with an error such as:
+# `! [EACCES] Failed to copy '/usr/local/lib/R/site-library/pkgdown/BS5/assets/katex-auto.js'
+# to '/__w/spark/spark/R/pkg/docs/katex-auto.js': permission denied`
+export SKIP_ERRORDOC=1
+export SKIP_SCALADOC=1
+export SKIP_PYTHONDOC=1
+export SKIP_SQLDOC=1
+cd docs
+bundle exec jekyll build
+
+# 5. Remove the container image.
+IMG_ID=$(docker images | grep "${IMG_TAG}" | awk '{print $3}')
+docker image rm --force "${IMG_ID}"
+
+echo "Build doc done."
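For context, a minimal local run of the helper above might look like the following sketch (not part of the patch; it assumes Docker, a JDK and SBT are available on the host and that the current directory is the root of a Spark checkout):

    # The file is added with mode 100644, so invoke it through bash explicitly.
    bash dev/spark-test-image-util/docs/build-docs
    # Jekyll typically writes the generated site under docs/_site on the host.
    ls docs/_site | head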
diff --git a/dev/run-tests-jenkins b/dev/spark-test-image-util/docs/run-in-container
old mode 100755
new mode 100644
similarity index 51%
rename from dev/run-tests-jenkins
rename to dev/spark-test-image-util/docs/run-in-container
index c5bf160380b56..1d43c602f7c72
--- a/dev/run-tests-jenkins
+++ b/dev/spark-test-image-util/docs/run-in-container
@@ -1,5 +1,3 @@
-#!/usr/bin/env bash
-
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements.  See the NOTICE file distributed with
@@ -17,21 +15,23 @@
 # limitations under the License.
 #
 
-# Wrapper script that runs the Spark tests then reports QA results
-# to github via its API.
-# Environment variables are populated by the code here:
-# https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139
-
-FWDIR="$( cd "$( dirname "$0" )/.." && pwd )"
-cd "$FWDIR"
+# 1. Set environment variables.
+export JAVA_HOME=/usr/lib/jvm/java-17-openjdk-arm64
+export PATH=$JAVA_HOME/bin:$PATH
+export SPARK_DOCS_IS_BUILT_ON_HOST=1
+# The R docs are expected to be built on the host.
+export SKIP_RDOC=1
 
-export PATH=/home/anaconda/envs/py36/bin:$PATH
-export LANG="en_US.UTF-8"
+# 2. Install bundler.
+gem install bundler -v 2.4.22
+cd /__w/spark/spark/docs
+bundle install
 
-PYTHON_VERSION_CHECK=$(python3 -c 'import sys; print(sys.version_info < (3, 8, 0))')
-if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then
-  echo "Python versions prior to 3.8 are not supported."
-  exit -1
-fi
+# 3. Build the docs: `error docs`, `scala doc`, `python doc` and `sql doc`; `r doc` is excluded.
+# We need this link to make sure `python3` points to `python3.9` which contains the prerequisite packages.
+ln -s "$(which python3.9)" "/usr/local/bin/python3"
 
-exec python3 -u ./dev/run-tests-jenkins.py "$@"
+# Build docs first with SKIP_API to ensure they are buildable without requiring any
+# language docs to be built beforehand.
+cd /__w/spark/spark/docs
+bundle exec jekyll build
diff --git a/dev/spark-test-image/docs/Dockerfile b/dev/spark-test-image/docs/Dockerfile
new file mode 100644
index 0000000000000..2db7e0717cdfd
--- /dev/null
+++ b/dev/spark-test-image/docs/Dockerfile
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Image for building and testing Spark branches. Based on Ubuntu 22.04.
+# See also in https://hub.docker.com/_/ubuntu
+FROM ubuntu:jammy-20240911.1
+LABEL org.opencontainers.image.authors="Apache Spark project "
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for Documentation"
+# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE 20241029
+
+ENV DEBIAN_FRONTEND noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN true
+
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    ca-certificates \
+    curl \
+    gfortran \
+    git \
+    gnupg \
+    libcurl4-openssl-dev \
+    libfontconfig1-dev \
+    libfreetype6-dev \
+    libfribidi-dev \
+    libgit2-dev \
+    libharfbuzz-dev \
+    libjpeg-dev \
+    liblapack-dev \
+    libopenblas-dev \
+    libpng-dev \
+    libpython3-dev \
+    libssl-dev \
+    libtiff5-dev \
+    libxml2-dev \
+    nodejs \
+    npm \
+    openjdk-17-jdk-headless \
+    pandoc \
+    pkg-config \
+    qpdf \
+    r-base \
+    ruby \
+    ruby-dev \
+    software-properties-common \
+    wget \
+    zlib1g-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+
+# See more in SPARK-39959, roxygen2 < 7.2.1
+RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', 'rmarkdown', 'testthat'), repos='https://cloud.r-project.org/')" && \
+    Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \
+    Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \
+    Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')"
+
+# See more in SPARK-39735
+ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library"
+
+# Install Python 3.9
+RUN add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get update && apt-get install -y python3.9 python3.9-distutils \
+    && rm -rf /var/lib/apt/lists/*
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9
+
+# Should unpin 'sphinxcontrib-*' after upgrading sphinx>5
+# See 'ipython_genutils' in SPARK-38517
+# See 'docutils<0.18.0' in SPARK-39421
+RUN python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
+    ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \
+    'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \
+    'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
+    'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' \
+    && python3.9 -m pip cache purge
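The documentation image above can also be built and inspected on its own; a sketch under the assumption of an arbitrary local tag name (the CI job names and registry paths are not used here):

    docker buildx build \
      --tag spark-ci-docs:local \
      --file dev/spark-test-image/docs/Dockerfile \
      dev/spark-test-image/docs
    # Spot-check the doc toolchain baked into the image.
    docker run --rm spark-ci-docs:local python3.9 -c "import sphinx; print(sphinx.__version__)"
    docker run --rm spark-ci-docs:local ruby --version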
diff --git a/dev/spark-test-image/lint/Dockerfile b/dev/spark-test-image/lint/Dockerfile
new file mode 100644
index 0000000000000..f9ea3124291b1
--- /dev/null
+++ b/dev/spark-test-image/lint/Dockerfile
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Image for building and testing Spark branches. Based on Ubuntu 22.04.
+# See also in https://hub.docker.com/_/ubuntu
+FROM ubuntu:jammy-20240911.1
+LABEL org.opencontainers.image.authors="Apache Spark project "
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for Linter"
+# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE 20241112
+
+ENV DEBIAN_FRONTEND noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN true
+
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    ca-certificates \
+    curl \
+    gfortran \
+    git \
+    gnupg \
+    libcurl4-openssl-dev \
+    libfontconfig1-dev \
+    libfreetype6-dev \
+    libfribidi-dev \
+    libgit2-dev \
+    libharfbuzz-dev \
+    libjpeg-dev \
+    libpng-dev \
+    libssl-dev \
+    libtiff5-dev \
+    libxml2-dev \
+    nodejs \
+    npm \
+    pkg-config \
+    qpdf \
+    r-base \
+    software-properties-common \
+    wget \
+    zlib1g-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', 'rmarkdown', 'testthat'), repos='https://cloud.r-project.org/')" \
+    && Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" \
+    && Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" \
+    && Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')" \
+
+# See more in SPARK-39735
+ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library"
+
+# Install Python 3.9
+RUN add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get update && apt-get install -y python3.9 python3.9-distutils \
+    && rm -rf /var/lib/apt/lists/*
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9
+
+RUN python3.9 -m pip install \
+    'black==23.9.1' \
+    'flake8==3.9.0' \
+    'googleapis-common-protos-stubs==2.2.0' \
+    'grpc-stubs==1.24.11' \
+    'grpcio-status==1.67.0' \
+    'grpcio==1.67.0' \
+    'ipython' \
+    'ipython_genutils' \
+    'jinja2' \
+    'matplotlib' \
+    'mypy==1.8.0' \
+    'numpy==2.0.2' \
+    'numpydoc' \
+    'pandas' \
+    'pandas-stubs==1.2.0.53' \
+    'plotly>=4.8' \
+    'pyarrow>=18.0.0' \
+    'pytest-mypy-plugins==1.9.3' \
+    'pytest==7.1.3' \
+    && python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu \
+    && python3.9 -m pip install torcheval \
+    && python3.9 -m pip cache purge
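A hedged sketch of how the linter image might be exercised locally; the tag and the mounted path are placeholders, and whether the CI jobs run exactly this command is not shown in this patch:

    docker buildx build -t spark-ci-lint:local -f dev/spark-test-image/lint/Dockerfile dev/spark-test-image/lint
    # Mount a Spark checkout and run the Python linters inside the container.
    docker run --rm -v "$PWD:/spark" -w /spark spark-ci-lint:local ./dev/lint-python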
diff --git a/dev/spark-test-image/sparkr/Dockerfile b/dev/spark-test-image/sparkr/Dockerfile
new file mode 100644
index 0000000000000..43260c714a550
--- /dev/null
+++ b/dev/spark-test-image/sparkr/Dockerfile
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Image for building and testing Spark branches. Based on Ubuntu 22.04.
+# See also in https://hub.docker.com/_/ubuntu
+FROM ubuntu:jammy-20240911.1
+LABEL org.opencontainers.image.authors="Apache Spark project "
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for SparkR"
+# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE 20241114
+
+ENV DEBIAN_FRONTEND noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN true
+
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    ca-certificates \
+    curl \
+    gfortran \
+    git \
+    gnupg \
+    libcurl4-openssl-dev \
+    libfontconfig1-dev \
+    libfreetype6-dev \
+    libfribidi-dev \
+    libgit2-dev \
+    libharfbuzz-dev \
+    libjpeg-dev \
+    liblapack-dev \
+    libopenblas-dev \
+    libpng-dev \
+    libpython3-dev \
+    libssl-dev \
+    libtiff5-dev \
+    libxml2-dev \
+    pandoc \
+    pkg-config \
+    qpdf \
+    r-base \
+    software-properties-common \
+    wget \
+    zlib1g-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN echo 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/' >> /etc/apt/sources.list
+RUN gpg --keyserver hkps://keyserver.ubuntu.com --recv-key E298A3A825C0D65DFD57CBB651716619E084DAB9
+RUN gpg -a --export E084DAB9 | apt-key add -
+RUN add-apt-repository 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/'
+
+# See more in SPARK-39959, roxygen2 < 7.2.1
+RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', \
+    'rmarkdown', 'testthat', 'devtools', 'e1071', 'survival', 'arrow', \
+    'ggplot2', 'mvtnorm', 'statmod', 'xml2'), repos='https://cloud.r-project.org/')" && \
+    Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \
+    Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')" && \
+    Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \
+    Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')"
+
+# See more in SPARK-39735
+ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library"
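Similarly, a sketch for the SparkR image (the tag is illustrative; the image only provides the R toolchain, so Spark itself still has to be built separately):

    docker buildx build -t spark-ci-sparkr:local -f dev/spark-test-image/sparkr/Dockerfile dev/spark-test-image/sparkr
    # Verify the pinned R packages are present.
    docker run --rm spark-ci-sparkr:local Rscript -e 'packageVersion("roxygen2"); packageVersion("lintr")'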
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 8e595707494cd..b8702113a26c7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -324,7 +324,7 @@ def __hash__(self):
     name="connect",
     dependencies=[hive, avro, protobuf],
     source_file_regexes=[
-        "connect",
+        "sql/connect",
         "connector/connect",
     ],
     sbt_test_goals=[
@@ -427,7 +427,7 @@ def __hash__(self):
 pyspark_core = Module(
     name="pyspark-core",
     dependencies=[core],
-    source_file_regexes=["python/(?!pyspark/(ml|mllib|sql|streaming))"],
+    source_file_regexes=["python/(?!pyspark/(ml|mllib|sql|streaming|pandas|resource|testing))"],
     python_test_goals=[
         # doctests
         "pyspark.conf",
@@ -500,6 +500,7 @@ def __hash__(self):
         "pyspark.sql.pandas.typehints",
         "pyspark.sql.pandas.utils",
         "pyspark.sql.observation",
+        "pyspark.sql.tvf",
         # unittests
         "pyspark.sql.tests.test_arrow",
         "pyspark.sql.tests.test_arrow_cogrouped_map",
@@ -520,10 +521,12 @@ def __hash__(self):
         "pyspark.sql.tests.test_errors",
         "pyspark.sql.tests.test_functions",
         "pyspark.sql.tests.test_group",
+        "pyspark.sql.tests.test_sql",
         "pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
         "pyspark.sql.tests.pandas.test_pandas_map",
+        "pyspark.sql.tests.pandas.test_pandas_transform_with_state",
         "pyspark.sql.tests.test_arrow_map",
         "pyspark.sql.tests.pandas.test_pandas_udf",
         "pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg",
@@ -542,12 +545,17 @@ def __hash__(self):
         "pyspark.sql.tests.streaming.test_streaming_foreach",
         "pyspark.sql.tests.streaming.test_streaming_foreach_batch",
         "pyspark.sql.tests.streaming.test_streaming_listener",
+        "pyspark.sql.tests.test_subquery",
         "pyspark.sql.tests.test_types",
         "pyspark.sql.tests.test_udf",
         "pyspark.sql.tests.test_udf_profiler",
         "pyspark.sql.tests.test_udtf",
+        "pyspark.sql.tests.test_tvf",
         "pyspark.sql.tests.test_utils",
         "pyspark.sql.tests.test_resources",
+        "pyspark.sql.tests.plot.test_frame_plot",
+        "pyspark.sql.tests.plot.test_frame_plot_plotly",
+        "pyspark.sql.tests.test_connect_compatibility",
     ],
 )
@@ -1008,6 +1016,7 @@ def __hash__(self):
         "pyspark.sql.connect.protobuf.functions",
         "pyspark.sql.connect.streaming.readwriter",
         "pyspark.sql.connect.streaming.query",
+        "pyspark.sql.connect.tvf",
         # sql unittests
         "pyspark.sql.tests.connect.test_connect_plan",
         "pyspark.sql.tests.connect.test_connect_basic",
@@ -1029,12 +1038,14 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_serde",
         "pyspark.sql.tests.connect.test_parity_functions",
         "pyspark.sql.tests.connect.test_parity_group",
+        "pyspark.sql.tests.connect.test_parity_sql",
         "pyspark.sql.tests.connect.test_parity_dataframe",
         "pyspark.sql.tests.connect.test_parity_collection",
         "pyspark.sql.tests.connect.test_parity_creation",
         "pyspark.sql.tests.connect.test_parity_observation",
         "pyspark.sql.tests.connect.test_parity_repartition",
         "pyspark.sql.tests.connect.test_parity_stat",
+        "pyspark.sql.tests.connect.test_parity_subquery",
         "pyspark.sql.tests.connect.test_parity_types",
         "pyspark.sql.tests.connect.test_parity_column",
         "pyspark.sql.tests.connect.test_parity_readwriter",
@@ -1042,6 +1053,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_udf_profiler",
         "pyspark.sql.tests.connect.test_parity_memory_profiler",
         "pyspark.sql.tests.connect.test_parity_udtf",
+        "pyspark.sql.tests.connect.test_parity_tvf",
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_map",
         "pyspark.sql.tests.connect.test_parity_arrow_map",
@@ -1051,6 +1063,8 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
         "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
+        "pyspark.sql.tests.connect.test_parity_frame_plot",
+        "pyspark.sql.tests.connect.test_parity_frame_plot_plotly",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_artifact_localcluster",
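The newly registered Python test goals can be run individually through python/run-tests; a sketch, assuming PySpark has already been built locally:

    # Run one of the test modules added above against a local build.
    python/run-tests --testnames 'pyspark.sql.tests.test_tvf'
    # The Connect parity suites can be targeted the same way.
    python/run-tests --testnames 'pyspark.sql.tests.connect.test_parity_tvf'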
"pyspark.sql.tests.connect.test_parity_frame_plot_plotly", "pyspark.sql.tests.connect.test_utils", "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_artifact_localcluster", diff --git a/connector/docker/spark-test/build b/dev/streaming-gen-protos.sh similarity index 82% rename from connector/docker/spark-test/build rename to dev/streaming-gen-protos.sh index 55dff4754b000..3d80bda4fb94e 100755 --- a/connector/docker/spark-test/build +++ b/dev/streaming-gen-protos.sh @@ -16,7 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +set -ex -docker build -t spark-test-base spark-test/base/ -docker build -t spark-test-master spark-test/master/ -docker build -t spark-test-worker spark-test/worker/ +if [[ $# -gt 1 ]]; then + echo "Illegal number of parameters." + echo "Usage: ./dev/streaming-gen-protos.sh [path]" + exit -1 +fi + +./dev/gen-protos.sh streaming "$@" diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh deleted file mode 100755 index ad1ad5e736594..0000000000000 --- a/dev/tests/pr_public_classes.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# -# This script follows the base format for testing pull requests against -# another branch and returning results to be published. More details can be -# found at dev/run-tests-jenkins. -# -# Arg1: The GitHub Pull Request Actual Commit -# known as `ghprbActualCommit` in `run-tests-jenkins` - -ghprbActualCommit="$1" - -# $ghprbActualCommit is an automatic merge commit generated by GitHub; its parents are some Spark -# master commit and the tip of the pull request branch. - -# By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only -# non-test files, we can get changes introduced in the PR and not anything else added to master -# since the PR was branched. 
diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh
deleted file mode 100755
index ad1ad5e736594..0000000000000
--- a/dev/tests/pr_public_classes.sh
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-#
-# This script follows the base format for testing pull requests against
-# another branch and returning results to be published. More details can be
-# found at dev/run-tests-jenkins.
-#
-# Arg1: The GitHub Pull Request Actual Commit
-#       known as `ghprbActualCommit` in `run-tests-jenkins`
-
-ghprbActualCommit="$1"
-
-# $ghprbActualCommit is an automatic merge commit generated by GitHub; its parents are some Spark
-# master commit and the tip of the pull request branch.
-
-# By diffing $ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only
-# non-test files, we can get changes introduced in the PR and not anything else added to master
-# since the PR was branched.
-
-# Handle differences between GNU and BSD sed
-if [[ $(uname) == "Darwin" ]]; then
-  SED='sed -E'
-else
-  SED='sed -r'
-fi
-
-source_files=$(
-  git diff $ghprbActualCommit^...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
-  | grep -v -e "\/test" `# ignore files in test directories` \
-  | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
-  | tr "\n" " "
-)
-
-new_public_classes=$(
-  git diff $ghprbActualCommit^...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \
-  | grep "^\+" `# filter in only added lines` \
-  | $SED -e "s/^\+//g" `# remove the leading +` \
-  | grep -e "trait " -e "class " `# filter in lines with these key words` \
-  | grep -e "{" -e "(" `# filter in lines with these key words, too` \
-  | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
-  | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
-  | $SED -e "s/\{.*//g" `# remove from the { onwards` \
-  | $SED -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
-  | $SED -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
-  | $SED -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
-  | $SED -e "s/^/  \* /g" `# prepend ' *' to start of line` \
-  | $SED -e "s/$/\\\n/g" `# append newline to end of line` \
-  | tr -d "\n" `# remove actual LF characters`
-)
-
-if [ -z "$new_public_classes" ]; then
-  echo " * This patch adds no public classes."
-else
-  public_classes_note=" * This patch adds the following public classes _(experimental)_:"
-  echo -e "${public_classes_note}\n${new_public_classes}"
-fi
diff --git a/dev/tox.ini b/dev/tox.ini
index 47b1b4a9d7832..05a6b16a03bd9 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -59,5 +59,6 @@ exclude =
     *python/pyspark/worker.pyi,
     *python/pyspark/java_gateway.pyi,
     *python/pyspark/sql/connect/proto/*,
+    *python/pyspark/sql/streaming/proto/*,
     */venv/*
 max-line-length = 100
diff --git a/docs/_config.yml b/docs/_config.yml
index e74eda0470417..089d6bf2097b8 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -22,7 +22,7 @@ include:
 SPARK_VERSION: 4.0.0-SNAPSHOT
 SPARK_VERSION_SHORT: 4.0.0
 SCALA_BINARY_VERSION: "2.13"
-SCALA_VERSION: "2.13.14"
+SCALA_VERSION: "2.13.15"
 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
 SPARK_GITHUB_URL: https://github.com/apache/spark
 # Before a new release, we should:
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index 01c8a8076958f..5fc1f3bcf9b5a 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -108,5 +108,7 @@
     url: sql-ref-syntax.html#data-retrieval-statements
   - text: Auxiliary Statements
     url: sql-ref-syntax.html#auxiliary-statements
+  - text: Pipe Syntax
+    url: sql-pipe-syntax.html
 - text: Error Conditions
   url: sql-error-conditions.html
diff --git a/docs/_data/menu-streaming.yaml b/docs/_data/menu-streaming.yaml
new file mode 100644
index 0000000000000..b1dd024451125
--- /dev/null
+++ b/docs/_data/menu-streaming.yaml
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+- text: Overview
+  url: streaming/index.html
+- text: Getting Started
+  url: streaming/getting-started.html
+  subitems:
+    - text: Quick Example
+      url: streaming/getting-started.html#quick-example
+    - text: Programming Model
+      url: streaming/getting-started.html#programming-model
+- text: APIs on DataFrames and Datasets
+  url: streaming/apis-on-dataframes-and-datasets.html
+  subitems:
+    - text: Creating Streaming DataFrames and Streaming Datasets
+      url: streaming/apis-on-dataframes-and-datasets.html#creating-streaming-dataframes-and-streaming-datasets
+    - text: Operations on Streaming DataFrames/Datasets
+      url: streaming/apis-on-dataframes-and-datasets.html#operations-on-streaming-dataframesdatasets
+    - text: Starting Streaming Queries
+      url: streaming/apis-on-dataframes-and-datasets.html#starting-streaming-queries
+    - text: Managing Streaming Queries
+      url: streaming/apis-on-dataframes-and-datasets.html#managing-streaming-queries
+    - text: Monitoring Streaming Queries
+      url: streaming/apis-on-dataframes-and-datasets.html#monitoring-streaming-queries
+    - text: Recovering from Failures with Checkpointing
+      url: streaming/apis-on-dataframes-and-datasets.html#recovering-from-failures-with-checkpointing
+    - text: Recovery Semantics after Changes in a Streaming Query
+      url: streaming/apis-on-dataframes-and-datasets.html#recovery-semantics-after-changes-in-a-streaming-query
+- text: Performance Tips
+  url: streaming/performance-tips.html
+  subitems:
+    - text: Asynchronous Progress Tracking
+      url: streaming/performance-tips.html#asynchronous-progress-tracking
+    - text: Continuous Processing
+      url: streaming/performance-tips.html#continuous-processing
+- text: Additional Information
+  url: streaming/additional-information.html
+  subitems:
+    - text: Miscellaneous Notes
+      url: streaming/additional-information.html#miscellaneous-notes
+    - text: Related Resources
+      url: streaming/additional-information.html#related-resources
+    - text: Migration Guide
+      url: streaming/additional-information.html#migration-guide
diff --git a/docs/_includes/nav-left-wrapper-streaming.html b/docs/_includes/nav-left-wrapper-streaming.html
new file mode 100644
index 0000000000000..82849f8140f5d
--- /dev/null
+++ b/docs/_includes/nav-left-wrapper-streaming.html
@@ -0,0 +1,22 @@
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% endcomment %}
+
+
+
+        Structured Streaming Programming Guide
+        {% include nav-left.html nav=include.nav-streaming %}
+
+
+
diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html
index 19d68fd191635..935ed0c732ee6 100644
--- a/docs/_includes/nav-left.html
+++ b/docs/_includes/nav-left.html
@@ -2,7 +2,7 @@
     {% for item in include.nav %}
-
+
         {% if navurl contains item.url %}
             {{ item.text }}
         {% else %}
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index c61c9349a6d7e..f5a20dd441b0e 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -1,3 +1,9 @@
+{% assign current_page_segments = page.dir | split: "/" | where_exp: "element","element != ''" %}
+{% assign rel_path_to_root = "" %}
+{% for i in (1..current_page_segments.size) %}
+  {% assign rel_path_to_root = rel_path_to_root | append: "../" %}
+{% endfor %}
+
@@ -21,12 +27,12 @@
-
-
+
+
-
+
-
+
 {% production %}
@@ -51,8 +57,8 @@